【神经网络】传统​​​​​​​RNN到LSTM,带你理解手搓LSTM

目录

一、基础:RNN

1.1循环神经网络结构

1.2RNN模型如何训练的

二、LSTM

2.1在RNN是如何修改的呢

2.2计算过程

2.2.1遗忘门

2.2.2更新门

2.2.3输入门

三、matlab手搓LSTM

3.1待训练参数

3.2流程图

3.3matlab代码


你好!以下皆为我自己的理解,仔细阅读,希望能助于你理解神经网络。

提到LSTM,离不开他的起源模型传统RNN神经网络,作为循环神经网络,RNN理论上可学习时间序列的依赖关系,不管是短期还是更早之前的都可以,但实际训练过程中人们发现,RNN在学习当前时刻更早的依赖关系时,会遇到梯度爆炸的问题,这里如果笔者想知道为什么可自行搜索一下。提一下本文很多地方省略了核函数,不影响理解。

一、基础:RNN

在学习LSTM之前我们先要知道,传统RNN循环神经网络是怎么样

1.1循环神经网络结构

这里我绘制了两张图,上面这张是传统神经网络,第一张是前反馈神经网络的,第二张是循环神经网络的。

前反馈神经网络
RNN递归循环神经网络

从图中看,对于前反馈传统神经网络的,他们的区别为:隐藏层不只与输出层有关,还与隐藏层本身也有关。换句话说,就是每个时间点的隐藏层与其他时间点的隐藏层之间也有连接,在计算隐藏层某一节点时,我们不只考虑了输出层,还考虑了其他时间点的隐藏层。此外可能有人注意到了,图中所有输入层权重,输出层权重,以及隐藏层之间的权重,他们为什么是同样的字母呀?这里就必须提到循环神经网络的一个权重机制了:权重共享机制。为了避免训练的参数过多,训练量过大,为何可以这样、要这样呢?

1.2RNN模型如何训练的

可以看着图理解一下,在实际过程中x_{1}就为t_{1}时刻的特征数据,h_{0}就为为上一时刻的隐藏层。这里深入解释一下,在实际训练过程中,我以故障诊断为例,数据集会是这样的:不同时间节点的特征数据,及故障标签。我们会假设一个时间步,假如数据集的数据是每秒的数据,我们设置时间步为5,那就是每五秒的数据为一个样本进行训练,这个样本就为图中的x_{1}x_{1}。其中有五个特征的数据为x_{11}x_{12}x_{13}x_{14}x_{15},训练时,一个时间步中的一个时间节点是特征输入进去,到隐藏层,再到输出层,最后输出Y1为标签的预测值,下图为一个时间步中的节点中的计算过程。

那么一个时间节点具体怎么影响下一个节点呢,是通过影响下一个节点的隐藏层,从而影响输出的,具体如下图,在计算h_{2}时,不仅考虑了X_{2}也考虑了,值得注意的的是,每个h之间的权重与偏置也是共享的。

二、LSTM

2.1在RNN是如何修改的呢

在传统RNN的神经网络的中,我们考虑上一时间节点信息,只是通过下文这一个简单的公式进行计算,这不仅会遇到梯度爆炸的问题,并且在实践过程中,发现往往无法做到长期的记忆,无法记住更早的东西,这个你可以这样理解:一个信息乘了很多次相同的权重,这个权重又位于0-1之间,乘了很多次的那个信息肯定被无限削弱了,以至于无法保留甚至没有了,对比一个信息乘一次权重,信息量肯定无法与之比较了。

改变就是改变在隐藏层,在隐藏层中加入门控,如何根据输入h1与X2计算的h_{2},下面我引用一张图来解释,图片来源:了解 LSTM 网络 -- colah 的博客,如有侵权立删,接下来我会解释LSTM的h_{2}怎么计算的。

关于h2如何计算更新的,用三门控制,遗忘们,更新门(很多地方叫输入门),输出门(这里输出是输出h,不是输出预测值Y,读者请注意),接下来讲解具体的计算公式,图中一些基本的图注见下图。

2.2计算过程

就以C2的更新为例吧

2.2.1遗忘门

{C_{2}}' = C_{1}*\sigma (U_{f}h_{1}+W_{f}X_{2}+b_{f})

\sigma为sigmoid函数,输出为0-1之间的数,越接近1则记住的越多,越接近0则遗忘的越多,遗忘门一句话就是C_{1}乘上一个概率,遗忘多少。

2.2.2更新门

C_{2} = {C_{2}}'+tanh (U_{i2}h_{1}+W_{i2}X_{2}+b_{2i})* \sigma (U_{i1}h_{1}+W_{i1}X_{2}+b_{i1})

经过遗忘的C_{1},加上-----由h_{1}X_{2}产生的备选项乘概率的细胞状态。

2.2.3输入门

h_{2} =\sigma (U_{o}h_{1}+W_{o}X_{2}+b_{o}) *tanh(C_{2})

输出门也是一个概率乘上经过处理的最终细胞状态C_{2},就得到了最终的h_{2}

这里得到的输出还不是真正的输出,h2还要经过输出层才能计算为y2。

三、matlab手搓LSTM

其实本文讲解的是最传统的LSTN,还有一些变体,可以看看本文插入的那个链接

3.1待训练参数

其中要训练的参数有:

遗忘门:W_{fh} |W_{fh}|b_{f}

更新门:U_{i2}\mid W_{i2}\mid b_{2i}\mid U_{i1}\mid W_{i1}\mid b_{i1}

输出门:U_{o}\mid W_{o}\mid b_{o}

此外还有输出层:w_{y}\mid b_{y}

一般优化算法为梯度下降算法,这个我前面的文章有提到,读者可以去找找。

3.2流程图

下面有个流程图,整个训练过程:

3.3matlab代码

LSTM神经网络 - MATLAB实现
功能: 完整的LSTM前向传播、反向传播和训练

classdef LSTM < handle
    properties
        input_size
        hidden_size
        sequence_length
        
        % 权重矩阵
        Wf, Wi, Wc, Wo  % 输入权重
        Uf, Ui, Uc, Uo  % 隐藏状态权重
        bf, bi, bc, bo  % 偏置
        
        % 训练参数
        learning_rate
        
        % 缓存变量(用于反向传播)
        cache
    end
    
    methods
        function obj = LSTM(input_size, hidden_size, learning_rate)
            % 构造函数
            obj.input_size = input_size;
            obj.hidden_size = hidden_size;
            obj.learning_rate = learning_rate;
            
            % 初始化权重和偏置
            obj.initialize_parameters();
        end
        
        function initialize_parameters(obj)
            % Xavier初始化权重
            scale = sqrt(2.0 / (obj.input_size + obj.hidden_size));
            
            % 输入权重 (hidden_size x input_size)
            obj.Wf = randn(obj.hidden_size, obj.input_size) * scale;
            obj.Wi = randn(obj.hidden_size, obj.input_size) * scale;
            obj.Wc = randn(obj.hidden_size, obj.input_size) * scale;
            obj.Wo = randn(obj.hidden_size, obj.input_size) * scale;
            
            % 隐藏状态权重 (hidden_size x hidden_size)
            obj.Uf = randn(obj.hidden_size, obj.hidden_size) * scale;
            obj.Ui = randn(obj.hidden_size, obj.hidden_size) * scale;
            obj.Uc = randn(obj.hidden_size, obj.hidden_size) * scale;
            obj.Uo = randn(obj.hidden_size, obj.hidden_size) * scale;
            
            % 偏置初始化为零
            obj.bf = zeros(obj.hidden_size, 1);
            obj.bi = zeros(obj.hidden_size, 1);
            obj.bc = zeros(obj.hidden_size, 1);
            obj.bo = zeros(obj.hidden_size, 1);
            
            % 遗忘门偏置设为1(有助于训练)
            obj.bf = ones(obj.hidden_size, 1);
        end
        
        function [h, c, cache] = lstm_cell_forward(obj, x, h_prev, c_prev)
            % LSTM单元前向传播
            % 输入:
            %   x: 当前输入 (input_size x 1)
            %   h_prev: 前一时刻隐藏状态 (hidden_size x 1)
            %   c_prev: 前一时刻细胞状态 (hidden_size x 1)
            
            % 1. 遗忘门
            f = obj.sigmoid(obj.Wf * x + obj.Uf * h_prev + obj.bf);
            
            % 2. 输入门
            i = obj.sigmoid(obj.Wi * x + obj.Ui * h_prev + obj.bi);
            
            % 3. 候选值
            c_tilde = tanh(obj.Wc * x + obj.Uc * h_prev + obj.bc);
            
            % 4. 更新细胞状态
            c = f .* c_prev + i .* c_tilde;
            
            % 5. 输出门
            o = obj.sigmoid(obj.Wo * x + obj.Uo * h_prev + obj.bo);
            
            % 6. 更新隐藏状态
            h = o .* tanh(c);
            
            % 保存中间变量用于反向传播
            cache.x = x;
            cache.h_prev = h_prev;
            cache.c_prev = c_prev;
            cache.f = f;
            cache.i = i;
            cache.c_tilde = c_tilde;
            cache.c = c;
            cache.o = o;
            cache.h = h;
        end
        
        function [outputs, final_h, final_c] = forward(obj, X)
            % LSTM网络前向传播
            % 输入: X (input_size x sequence_length)
            
            [~, sequence_length] = size(X);
            obj.sequence_length = sequence_length;
            
            % 初始化状态
            h = zeros(obj.hidden_size, 1);
            c = zeros(obj.hidden_size, 1);
            
            % 存储输出和缓存
            outputs = zeros(obj.hidden_size, sequence_length);
            obj.cache = cell(sequence_length, 1);
            
            % 前向传播每个时间步
            for t = 1:sequence_length
                [h, c, cache] = obj.lstm_cell_forward(X(:, t), h, c);
                outputs(:, t) = h;
                obj.cache{t} = cache;
            end
            
            final_h = h;
            final_c = c;
        end
        
        function [dWf, dWi, dWc, dWo, dUf, dUi, dUc, dUo, dbf, dbi, dbc, dbo] = ...
                lstm_cell_backward(obj, dh, dc, cache)
            % LSTM单元反向传播
            
            % 从缓存中提取变量
            x = cache.x;
            h_prev = cache.h_prev;
            c_prev = cache.c_prev;
            f = cache.f;
            i = cache.i;
            c_tilde = cache.c_tilde;
            c = cache.c;
            o = cache.o;
            
            % 输出门梯度
            do = dh .* tanh(c);
            do_raw = do .* o .* (1 - o);  % sigmoid导数
            
            % 细胞状态梯度
            dc = dc + dh .* o .* (1 - tanh(c).^2);
            
            % 候选值梯度
            dc_tilde = dc .* i;
            dc_tilde_raw = dc_tilde .* (1 - c_tilde.^2);  % tanh导数
            
            % 输入门梯度
            di = dc .* c_tilde;
            di_raw = di .* i .* (1 - i);  % sigmoid导数
            
            % 遗忘门梯度
            df = dc .* c_prev;
            df_raw = df .* f .* (1 - f);  % sigmoid导数
            
            % 权重梯度
            dWf = df_raw * x';
            dWi = di_raw * x';
            dWc = dc_tilde_raw * x';
            dWo = do_raw * x';
            
            dUf = df_raw * h_prev';
            dUi = di_raw * h_prev';
            dUc = dc_tilde_raw * h_prev';
            dUo = do_raw * h_prev';
            
            % 偏置梯度
            dbf = df_raw;
            dbi = di_raw;
            dbc = dc_tilde_raw;
            dbo = do_raw;
            
            % 传递给前一时刻的梯度
            dh_prev = obj.Uf' * df_raw + obj.Ui' * di_raw + ...
                     obj.Uc' * dc_tilde_raw + obj.Uo' * do_raw;
            dc_prev = dc .* f;
        end
        
        function backward(obj, dL_doutput)
            % LSTM网络反向传播
            % 输入: dL_doutput (hidden_size x sequence_length)
            
            % 初始化梯度累积器
            dWf_total = zeros(size(obj.Wf));
            dWi_total = zeros(size(obj.Wi));
            dWc_total = zeros(size(obj.Wc));
            dWo_total = zeros(size(obj.Wo));
            
            dUf_total = zeros(size(obj.Uf));
            dUi_total = zeros(size(obj.Ui));
            dUc_total = zeros(size(obj.Uc));
            dUo_total = zeros(size(obj.Uo));
            
            dbf_total = zeros(size(obj.bf));
            dbi_total = zeros(size(obj.bi));
            dbc_total = zeros(size(obj.bc));
            dbo_total = zeros(size(obj.bo));
            
            % 初始化梯度
            dh_next = zeros(obj.hidden_size, 1);
            dc_next = zeros(obj.hidden_size, 1);
            
            % 从最后一个时间步开始反向传播
            for t = obj.sequence_length:-1:1
                % 当前时刻的输出梯度
                dh = dL_doutput(:, t) + dh_next;
                dc = dc_next;
                
                % 反向传播通过LSTM单元
                [dWf, dWi, dWc, dWo, dUf, dUi, dUc, dUo, dbf, dbi, dbc, dbo] = ...
                    obj.lstm_cell_backward(dh, dc, obj.cache{t});
                
                % 累积梯度
                dWf_total = dWf_total + dWf;
                dWi_total = dWi_total + dWi;
                dWc_total = dWc_total + dWc;
                dWo_total = dWo_total + dWo;
                
                dUf_total = dUf_total + dUf;
                dUi_total = dUi_total + dUi;
                dUc_total = dUc_total + dUc;
                dUo_total = dUo_total + dUo;
                
                dbf_total = dbf_total + dbf;
                dbi_total = dbi_total + dbi;
                dbc_total = dbc_total + dbc;
                dbo_total = dbo_total + dbo;
                
                % 更新下一时刻的梯度
                if t > 1
                    cache_prev = obj.cache{t-1};
                    h_prev = cache_prev.h_prev;
                    c_prev = cache_prev.c_prev;
                    f = cache_prev.f;
                    
                    dh_next = obj.Uf' * (dh .* obj.cache{t}.f .* (1 - obj.cache{t}.f)) + ...
                             obj.Ui' * (dh .* obj.cache{t}.i .* (1 - obj.cache{t}.i)) + ...
                             obj.Uc' * (dh .* obj.cache{t}.c_tilde .* (1 - obj.cache{t}.c_tilde.^2)) + ...
                             obj.Uo' * (dh .* obj.cache{t}.o .* (1 - obj.cache{t}.o));
                    dc_next = dc .* f;
                end
            end
            
            % 更新参数
            obj.update_parameters(dWf_total, dWi_total, dWc_total, dWo_total, ...
                                 dUf_total, dUi_total, dUc_total, dUo_total, ...
                                 dbf_total, dbi_total, dbc_total, dbo_total);
        end
        
        function update_parameters(obj, dWf, dWi, dWc, dWo, dUf, dUi, dUc, dUo, ...
                                  dbf, dbi, dbc, dbo)
            % 使用梯度下降更新参数
            obj.Wf = obj.Wf - obj.learning_rate * dWf;
            obj.Wi = obj.Wi - obj.learning_rate * dWi;
            obj.Wc = obj.Wc - obj.learning_rate * dWc;
            obj.Wo = obj.Wo - obj.learning_rate * dWo;
            
            obj.Uf = obj.Uf - obj.learning_rate * dUf;
            obj.Ui = obj.Ui - obj.learning_rate * dUi;
            obj.Uc = obj.Uc - obj.learning_rate * dUc;
            obj.Uo = obj.Uo - obj.learning_rate * dUo;
            
            obj.bf = obj.bf - obj.learning_rate * dbf;
            obj.bi = obj.bi - obj.learning_rate * dbi;
            obj.bc = obj.bc - obj.learning_rate * dbc;
            obj.bo = obj.bo - obj.learning_rate * dbo;
        end
        
        function loss = train_step(obj, X, Y)
            % 单步训练
            % 输入:
            %   X: 输入序列 (input_size x sequence_length)
            %   Y: 目标输出 (hidden_size x sequence_length)
            
            % 前向传播
            [outputs, ~, ~] = obj.forward(X);
            
            % 计算损失 (均方误差)
            loss = 0.5 * sum(sum((outputs - Y).^2)) / size(Y, 2);
            
            % 计算输出梯度
            dL_doutput = (outputs - Y) / size(Y, 2);
            
            % 反向传播
            obj.backward(dL_doutput);
        end
        
        function train(obj, X_train, Y_train, epochs)
            % 训练LSTM网络
            % 输入:
            %   X_train: 训练输入 (input_size x sequence_length x num_samples)
            %   Y_train: 训练目标 (hidden_size x sequence_length x num_samples)
            %   epochs: 训练轮数
            
            [~, ~, num_samples] = size(X_train);
            losses = zeros(epochs, 1);
            
            fprintf('开始训练LSTM网络...\n');
            
            for epoch = 1:epochs
                total_loss = 0;
                
                % 遍历所有训练样本
                for sample = 1:num_samples
                    X = X_train(:, :, sample);
                    Y = Y_train(:, :, sample);
                    
                    loss = obj.train_step(X, Y);
                    total_loss = total_loss + loss;
                end
                
                avg_loss = total_loss / num_samples;
                losses(epoch) = avg_loss;
                
                % 每100轮打印一次损失
                if mod(epoch, 100) == 0 || epoch == 1
                    fprintf('Epoch %d/%d, 平均损失: %.6f\n', epoch, epochs, avg_loss);
                end
            end
            
            % 绘制损失曲线
            figure;
            plot(1:epochs, losses, 'b-', 'LineWidth', 2);
            xlabel('训练轮数');
            ylabel('损失');
            title('LSTM训练损失曲线');
            grid on;
        end
        
        function predictions = predict(obj, X)
            % 预测
            % 输入: X (input_size x sequence_length)
            % 输出: predictions (hidden_size x sequence_length)
            
            [predictions, ~, ~] = obj.forward(X);
        end
    end
    
    methods (Static)
        function y = sigmoid(x)
            % Sigmoid激活函数
            y = 1 ./ (1 + exp(-x));
        end
    end
end

%% 使用示例
function demo_lstm()
    % LSTM使用示例
    fprintf('=== LSTM神经网络演示 ===\n');
    
    % 设置参数
    input_size = 10;
    hidden_size = 20;
    sequence_length = 15;
    num_samples = 100;
    learning_rate = 0.01;
    epochs = 500;
    
    % 创建LSTM实例
    lstm = LSTM(input_size, hidden_size, learning_rate);
    
    % 生成模拟数据
    fprintf('生成训练数据...\n');
    X_train = randn(input_size, sequence_length, num_samples);
    
    % 简单的目标函数:输出是输入的简单变换
    Y_train = zeros(hidden_size, sequence_length, num_samples);
    for i = 1:num_samples
        for t = 1:sequence_length
            % 目标输出是输入的线性组合加噪声
            W_target = randn(hidden_size, input_size) * 0.1;
            Y_train(:, t, i) = W_target * X_train(:, t, i) + 0.01 * randn(hidden_size, 1);
        end
    end
    
    % 训练网络
    lstm.train(X_train, Y_train, epochs);
    
    % 测试预测
    fprintf('\n测试预测...\n');
    X_test = randn(input_size, sequence_length);
    predictions = lstm.predict(X_test);
    
    fprintf('预测完成! 输出维度: %dx%d\n', size(predictions, 1), size(predictions, 2));
    
    % 显示部分结果
    fprintf('\n前5个时间步的预测结果:\n');
    disp(predictions(:, 1:5));
end

% 运行演示
% demo_lstm();

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值