news 2026/4/15 8:39:21

深入理解LSTM:从结构到PyTorch实践

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
深入理解LSTM:从结构到PyTorch实践

1. 引言:为什么需要LSTM?

循环神经网络(RNN)因其天然的时序结构,被广泛应用于自然语言处理、时间序列预测等任务。然而,传统RNN在处理长序列时容易遭遇梯度消失梯度爆炸问题,导致模型难以捕捉远距离的语义依赖。例如,在“我出生在法国……我会说法语”中,“法语”依赖于远在前面的“法国”,传统RNN往往难以建立这种长距离关联。

为了解决这一问题,Sepp HochreiterJürgen Schmidhuber于1997年提出了长短期记忆网络(Long Short-Term Memory, LSTM)。LSTM通过精巧的门控机制细胞状态,选择性地记忆或遗忘信息,从而有效缓解了长序列训练中的梯度消失问题。后来(2000年左右),Gers等人又引入了遗忘门,进一步完善了LSTM结构。


2. LSTM的核心思想

LSTM与传统RNN最大的区别在于:它引入了一条细胞状态(Cell State)的“传送带”,信息可以在时间步上几乎无损地流动。同时,LSTM使用三个门控单元(遗忘门、输入门、输出门)来控制信息的遗忘写入读出

  • 细胞状态Ct:负责长期记忆,贯穿整个序列。

  • 隐状态ht:负责短期记忆,也是每个时间步的输出。

  • :使用sigmoid函数输出0~1之间的值,表示信息“通过”的比例(0表示完全阻断,1表示完全通过)。


3. LSTM内部结构详解(含公式)

下图示意了单个LSTM单元的内部结构(图中省略了偏置项,但在实际实现中存在)。

3.1 遗忘门(Forget Gate)

遗忘门决定上一时刻的细胞状态 Ct−1 中有多少信息需要被丢弃。它读取当前输入 xt和上一时刻隐状态 ht−1,输出一个0~1的向量 ft​。

  • σ 为sigmoid函数。

  • [ht−1​,xt​] 表示将两个向量拼接。

  • Wf​ 和 bf​ 为可学习参数。

直观理解:如果 ft中的某个分量接近0,则对应的历史信息将被遗忘;接近1则保留。

3.2 输入门(Input Gate)

输入门决定将多少新信息写入细胞状态。它由两部分组成:

  • 门控部分iti:决定哪些位置要更新。

  • 候选细胞状态C~t:利用tanh层生成新的候选值向量。

  • tanh 将输出值压缩到-1到1之间,起到调节作用。

3.3 细胞状态更新

旧细胞状态 Ct−1经过遗忘门进行选择性遗忘,再与输入门筛选后的候选状态相加,得到新的细胞状态 Ct。

  • +表示逐元素相乘(Hadamard积)。

意义:这一步完美融合了“忘记过去不重要的”和“记住当前新的重要信息”。

3.4 输出门(Output Gate)

输出门决定当前时刻的隐状态 htht​(同时也是该时刻的输出)。它基于更新后的细胞状态 CtCt​,并经过一个门控筛选。

  • 先用tanh将 Ct 的值缩放至-1~1,再通过输出门 ot​ 决定哪些信息最终输出。

总结:LSTM通过上述四个步骤,实现了对长序列信息的选择性存储和读取。其中遗忘门输入门配合完成细胞状态的更新,输出门控制隐状态的表达。


4. PyTorch中的LSTM实现

PyTorch提供了便捷的torch.nn.LSTM模块,我们可以直接调用。

4.1 参数说明

nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)
  • input_size:输入特征维度(例如词向量的长度)。

  • hidden_size:隐状态 htht​ 的维度。

  • num_layers:LSTM堆叠的层数(大于1时为多层LSTM)。

  • batch_first:若为True,输入形状为(batch, seq_len, input_size),否则为(seq_len, batch, input_size)

  • 注意bidirectional参数在这里应保持False(本文不涉及Bi-LSTM)。

4.2 输入与输出形状

  • 输入

    • input:形状(seq_len, batch, input_size)

    • h0(可选):初始隐状态,形状(num_layers, batch, hidden_size)

    • c0(可选):初始细胞状态,形状(num_layers, batch, hidden_size)

  • 输出

    • output:所有时间步的隐状态,形状(seq_len, batch, hidden_size)

    • (hn, cn):最后一个时间步的隐状态和细胞状态,形状均为(num_layers, batch, hidden_size)

4.3 完整示例

# 定义LSTM的参数含义: (input_size, hidden_size, num_layers) # 定义输入张量的参数含义: (sequence_length, batch_size, input_size) # 定义隐藏层初始张量和细胞初始状态张量的参数含义: # (num_layers * num_directions, batch_size, hidden_size) >>> import torch.nn as nn >>> import torch >>> rnn = nn.LSTM(5, 6, 2) >>> input = torch.randn(1, 3, 5) >>> h0 = torch.randn(2, 3, 6) >>> c0 = torch.randn(2, 3, 6) >>> output, (hn, cn) = rnn(input, (h0, c0)) >>> output tensor([[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416], [ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548], [-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]], grad_fn=<StackBackward>) >>> hn tensor([[[ 0.4647, -0.2364, 0.0645, -0.3996, -0.0500, -0.0152], [ 0.3852, 0.0704, 0.2103, -0.2524, 0.0243, 0.0477], [ 0.2571, 0.0608, 0.2322, 0.1815, -0.0513, -0.0291]], [[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416], [ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548], [-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]], grad_fn=<StackBackward>) >>> cn tensor([[[ 0.8083, -0.5500, 0.1009, -0.5806, -0.0668, -0.1161], [ 0.7438, 0.0957, 0.5509, -0.7725, 0.0824, 0.0626], [ 0.3131, 0.0920, 0.8359, 0.9187, -0.4826, -0.0717]], [[ 0.1240, -0.0526, 0.3035, 0.1099, 0.5915, 0.0828], [ 0.0203, 0.8367, 0.9832, -0.4454, 0.3917, -0.1983], [-0.2976, 0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]], grad_fn=<StackBackward>)

在实际任务中(如情感分析),我们通常取output[:, -1, :]作为最后一个时间步的隐状态,再接入全连接层进行分类。


5. LSTM的优缺点

✅ 优势

  1. 长距离依赖建模能力强:相比传统RNN,LSTM通过门控机制有效缓解了梯度消失/爆炸,可以处理长达数百步的序列。

  2. 灵活性高:可以堆叠多层,也可以与其他网络(如CNN、Attention)结合。

  3. 工程成熟:各种深度学习框架均有高效实现,且有很多预训练变体。

❌ 缺点

  1. 计算复杂度高:每个时间步需要计算4个全连接层(遗忘门、输入门、输出门、候选状态),参数量约为传统RNN的4倍,训练和推理较慢。

  2. 难以并行:LSTM本质是递归结构,后一个时间步依赖前一步的输出,无法像Transformer那样进行大规模并行计算。

  3. 并非万能:在超长序列(数千步)上仍有信息衰减,且对随机打乱的序列不敏感。


6. 总结

LSTM是RNN家族中最经典、最成功的变体之一。它通过遗忘门、输入门、输出门细胞状态实现了对长期记忆的精细控制,解决了原始RNN的梯度问题。虽然近年来Transformer等模型在多数NLP任务上取得了更好效果,但LSTM在时间序列预测、语音识别、小规模序列建模等场景中依然具有重要价值。掌握LSTM的内部原理和PyTorch实现,是深入理解序列模型的关键一步。

参考文献

  • Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.Neural computation, 9(8), 1735-1780.

  • Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget: Continual prediction with LSTM.Neural computation, 12(10), 2451-2471.

希望本文能帮助你彻底搞懂LSTM!如果有任何疑问,欢迎在评论区留言讨论。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/15 8:38:17

从接线到调试:一份超详细的汇川PLC与MCGS触摸屏485通讯避坑指南

从零搭建工业控制系统&#xff1a;汇川PLC与MCGS触摸屏485通讯全流程实战 在工业自动化项目中&#xff0c;稳定可靠的通讯系统是确保设备高效运行的基础。本文将带您完整走通汇川Easy 301 PLC与MCGS触摸屏通过485总线建立Modbus RTU通讯的全过程&#xff0c;特别针对实际工程中…

作者头像 李华
网站建设 2026/4/15 8:38:15

AD5933阻抗测量芯片的驱动代码优化与分段PGA校准实践

1. AD5933阻抗测量芯片的核心原理 AD5933是ADI公司推出的一款高集成度阻抗测量芯片&#xff0c;内部集成了DDS频率发生器、12位ADC和DFT数字信号处理单元。它的核心工作原理可以概括为&#xff1a;通过内部DDS生成精确的正弦波激励信号&#xff0c;经过外部阻抗网络后&#xff…

作者头像 李华
网站建设 2026/4/15 8:36:12

Sunshine游戏串流终极指南:三步实现高画质低延迟游戏体验

Sunshine游戏串流终极指南&#xff1a;三步实现高画质低延迟游戏体验 【免费下载链接】Sunshine Self-hosted game stream host for Moonlight. 项目地址: https://gitcode.com/GitHub_Trending/su/Sunshine Sunshine是一款免费开源的自主托管游戏串流服务器&#xff0c…

作者头像 李华
网站建设 2026/4/15 8:31:23

工业肌肉:03 变频器到底改变了什么?为什么它能让电机“听话”

03 变频器到底改变了什么?为什么它能让电机“听话” 变频器不是控制电机,而是控制电机背后的“电磁节奏”。 上次把伺服舞王拆得七零八落,今天终于轮到咱们车间里最亲民的“大管家”——变频器了。工厂里风机、水泵、传送带、搅拌机……哪台大电机旁边没挂个铁箱子?别看它其…

作者头像 李华