1. 引言:为什么需要LSTM?
循环神经网络(RNN)因其天然的时序结构,被广泛应用于自然语言处理、时间序列预测等任务。然而,传统RNN在处理长序列时容易遭遇梯度消失或梯度爆炸问题,导致模型难以捕捉远距离的语义依赖。例如,在“我出生在法国……我会说法语”中,“法语”依赖于远在前面的“法国”,传统RNN往往难以建立这种长距离关联。
为了解决这一问题,Sepp Hochreiter和Jürgen Schmidhuber于1997年提出了长短期记忆网络(Long Short-Term Memory, LSTM)。LSTM通过精巧的门控机制和细胞状态,选择性地记忆或遗忘信息,从而有效缓解了长序列训练中的梯度消失问题。后来(2000年左右),Gers等人又引入了遗忘门,进一步完善了LSTM结构。
2. LSTM的核心思想
LSTM与传统RNN最大的区别在于:它引入了一条细胞状态(Cell State)的“传送带”,信息可以在时间步上几乎无损地流动。同时,LSTM使用三个门控单元(遗忘门、输入门、输出门)来控制信息的遗忘、写入和读出。
细胞状态Ct:负责长期记忆,贯穿整个序列。
隐状态ht:负责短期记忆,也是每个时间步的输出。
门:使用sigmoid函数输出0~1之间的值,表示信息“通过”的比例(0表示完全阻断,1表示完全通过)。
3. LSTM内部结构详解(含公式)
下图示意了单个LSTM单元的内部结构(图中省略了偏置项,但在实际实现中存在)。
3.1 遗忘门(Forget Gate)
遗忘门决定上一时刻的细胞状态 Ct−1 中有多少信息需要被丢弃。它读取当前输入 xt和上一时刻隐状态 ht−1,输出一个0~1的向量 ft。
σ 为sigmoid函数。
[ht−1,xt] 表示将两个向量拼接。
Wf 和 bf 为可学习参数。
直观理解:如果 ft中的某个分量接近0,则对应的历史信息将被遗忘;接近1则保留。
3.2 输入门(Input Gate)
输入门决定将多少新信息写入细胞状态。它由两部分组成:
门控部分iti:决定哪些位置要更新。
候选细胞状态C~t:利用tanh层生成新的候选值向量。
tanh 将输出值压缩到-1到1之间,起到调节作用。
3.3 细胞状态更新
旧细胞状态 Ct−1经过遗忘门进行选择性遗忘,再与输入门筛选后的候选状态相加,得到新的细胞状态 Ct。
+表示逐元素相乘(Hadamard积)。
意义:这一步完美融合了“忘记过去不重要的”和“记住当前新的重要信息”。
3.4 输出门(Output Gate)
输出门决定当前时刻的隐状态 htht(同时也是该时刻的输出)。它基于更新后的细胞状态 CtCt,并经过一个门控筛选。
先用tanh将 Ct 的值缩放至-1~1,再通过输出门 ot 决定哪些信息最终输出。
总结:LSTM通过上述四个步骤,实现了对长序列信息的选择性存储和读取。其中遗忘门和输入门配合完成细胞状态的更新,输出门控制隐状态的表达。
4. PyTorch中的LSTM实现
PyTorch提供了便捷的torch.nn.LSTM模块,我们可以直接调用。
4.1 参数说明
nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False)input_size:输入特征维度(例如词向量的长度)。hidden_size:隐状态 htht 的维度。num_layers:LSTM堆叠的层数(大于1时为多层LSTM)。batch_first:若为True,输入形状为(batch, seq_len, input_size),否则为(seq_len, batch, input_size)。注意:
bidirectional参数在这里应保持False(本文不涉及Bi-LSTM)。
4.2 输入与输出形状
输入:
input:形状(seq_len, batch, input_size)h0(可选):初始隐状态,形状(num_layers, batch, hidden_size)c0(可选):初始细胞状态,形状(num_layers, batch, hidden_size)
输出:
output:所有时间步的隐状态,形状(seq_len, batch, hidden_size)(hn, cn):最后一个时间步的隐状态和细胞状态,形状均为(num_layers, batch, hidden_size)
4.3 完整示例
# 定义LSTM的参数含义: (input_size, hidden_size, num_layers) # 定义输入张量的参数含义: (sequence_length, batch_size, input_size) # 定义隐藏层初始张量和细胞初始状态张量的参数含义: # (num_layers * num_directions, batch_size, hidden_size) >>> import torch.nn as nn >>> import torch >>> rnn = nn.LSTM(5, 6, 2) >>> input = torch.randn(1, 3, 5) >>> h0 = torch.randn(2, 3, 6) >>> c0 = torch.randn(2, 3, 6) >>> output, (hn, cn) = rnn(input, (h0, c0)) >>> output tensor([[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416], [ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548], [-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]], grad_fn=<StackBackward>) >>> hn tensor([[[ 0.4647, -0.2364, 0.0645, -0.3996, -0.0500, -0.0152], [ 0.3852, 0.0704, 0.2103, -0.2524, 0.0243, 0.0477], [ 0.2571, 0.0608, 0.2322, 0.1815, -0.0513, -0.0291]], [[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416], [ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548], [-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]], grad_fn=<StackBackward>) >>> cn tensor([[[ 0.8083, -0.5500, 0.1009, -0.5806, -0.0668, -0.1161], [ 0.7438, 0.0957, 0.5509, -0.7725, 0.0824, 0.0626], [ 0.3131, 0.0920, 0.8359, 0.9187, -0.4826, -0.0717]], [[ 0.1240, -0.0526, 0.3035, 0.1099, 0.5915, 0.0828], [ 0.0203, 0.8367, 0.9832, -0.4454, 0.3917, -0.1983], [-0.2976, 0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]], grad_fn=<StackBackward>)在实际任务中(如情感分析),我们通常取output[:, -1, :]作为最后一个时间步的隐状态,再接入全连接层进行分类。
5. LSTM的优缺点
✅ 优势
长距离依赖建模能力强:相比传统RNN,LSTM通过门控机制有效缓解了梯度消失/爆炸,可以处理长达数百步的序列。
灵活性高:可以堆叠多层,也可以与其他网络(如CNN、Attention)结合。
工程成熟:各种深度学习框架均有高效实现,且有很多预训练变体。
❌ 缺点
计算复杂度高:每个时间步需要计算4个全连接层(遗忘门、输入门、输出门、候选状态),参数量约为传统RNN的4倍,训练和推理较慢。
难以并行:LSTM本质是递归结构,后一个时间步依赖前一步的输出,无法像Transformer那样进行大规模并行计算。
并非万能:在超长序列(数千步)上仍有信息衰减,且对随机打乱的序列不敏感。
6. 总结
LSTM是RNN家族中最经典、最成功的变体之一。它通过遗忘门、输入门、输出门和细胞状态实现了对长期记忆的精细控制,解决了原始RNN的梯度问题。虽然近年来Transformer等模型在多数NLP任务上取得了更好效果,但LSTM在时间序列预测、语音识别、小规模序列建模等场景中依然具有重要价值。掌握LSTM的内部原理和PyTorch实现,是深入理解序列模型的关键一步。
参考文献
Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.Neural computation, 9(8), 1735-1780.
Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget: Continual prediction with LSTM.Neural computation, 12(10), 2451-2471.
希望本文能帮助你彻底搞懂LSTM!如果有任何疑问,欢迎在评论区留言讨论。