PyTorch:一个简单的RNN模型训练(含对RNN的讲解以及代码)

一、简介

循环神经网络(Recurrent Neural Networks, RNN)是一类神经网络架构,专门用于处理序列数据,RNN具有"记忆"能力,能够捕捉时间序列或有序数据的动态信息,能够处理序列数据,如文本、时间序列或音频。

RNN 在自然语言处理(NLP)、语音识别、时间序列预测等任务中有着广泛的应用。

RNN 的关键特性是其能够保持隐状态(hidden state),使得网络能够记住先前时间步的信息,这对于处理序列数据至关重要。

二、RNN的基本结构

本图是取自B站耿直哥
通俗来讲,CNN就比前馈神经网络多了时间上的维度

在这里插入图片描述

参考:B站耿直哥

1. 核心组件

RNN由以下几个核心组件构成:

  • 输入层 :接收当前时间步的输入数据
  • 隐藏层 :包含循环连接,保存历史信息
  • 输出层 :产生预测结果
  • 隐藏状态 :作为"记忆",将信息从一个时间步传递到下一个时间步

2. 数学表示

RNN的基本计算可以表示为:

h_t = tanh(W_xh * x_t + W_hh * h_{t-1} + b_h)
y_t = W_hy * h_t + b_y

其中:

  • h_t :当前时间步的隐藏状态
  • x_t :当前时间步的输入
  • h_{t-1} :前一时间步的隐藏状态
  • y_t :当前时间步的输出
  • W_xh :输入到隐藏层的权重矩阵
  • W_hh :隐藏层到隐藏层的权重矩阵(循环连接)
  • W_hy :隐藏层到输出层的权重矩阵
  • b_h 和 b_y :偏置项

3. 循环结构

RNN的关键特点是其循环结构:

  • 每个时间步使用相同的权重矩阵
  • 隐藏状态在时间步之间传递
  • 可以处理任意长度的序列

三、RNN的工作流程

  1. 初始化 :设置初始隐藏状态(通常为零向量)
  2. 前向传播 :
    • 接收当前时间步的输入
    • 结合前一时间步的隐藏状态计算当前隐藏状态
    • 基于当前隐藏状态计算输出
  3. 重复 :对序列中的每个元素重复上述过程
  4. 反向传播 :使用称为"通过时间的反向传播"(BPTT)的算法进行训练

四、RNN的类型

1. 基本RNN(Vanilla RNN)

就是我们上面描述的最简单形式,存在梯度消失/爆炸问题,难以捕获长期依赖关系。

2. LSTM(长短期记忆网络)

通过引入门控机制(输入门、遗忘门、输出门)和单元状态来解决基本RNN的问题:

  • 遗忘门:决定丢弃哪些信息
  • 输入门:决定存储哪些新信息
  • 输出门:决定输出哪些信息

3. GRU(门控循环单元)

LSTM的简化版本,只有两个门(更新门和重置门),参数更少但效果通常相当。

五、RNN的优缺点

优点

  • 能够处理变长序列数据
  • 参数共享(在所有时间步使用相同的权重)
  • 能够捕获序列中的时间依赖关系

缺点

  • 基本RNN存在梯度消失/爆炸问题
  • 难以捕获长期依赖关系(可能只能记住十步以内)
  • 计算效率较低(难以并行化)

六、PyTorch代码(详细注释)

import torch  # PyTorch深度学习框架
import torch.nn as nn  # 神经网络模块
import torch.optim as optim  # 优化器
import numpy as np  # 数值计算库

# 1. 数据准备和预处理
# 数据集:字符序列预测(Hello -> Elloh)
# 定义字符集,这里只包含"hello"中的字符,注意有重复字符只保留一次
char_set = list("hello")  # 结果为['h', 'e', 'l', 'o'],'l'只保留一个

# 创建字符到索引的映射字典,enumerate返回索引和值的元组
char_to_idx = {
   
   c: i for i, c in enumerate
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值