Recurrent Neural Networks (RNNs) and Sequence Modeling

A recurrent neural network (RNN) is a class of neural networks designed for sequential data, and it is widely used in natural language processing, time-series forecasting, and related fields.

RNN Fundamentals

The core idea of an RNN is to process a sequence step by step while carrying a hidden state that summarizes the history seen so far.
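
Concretely, at each time step the hidden state is updated from the current input and the previous hidden state, and the output is read off the final hidden state (the weight names below match the code that follows):

    h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h),        y = W_hy h_T + b_y

A minimal NumPy implementation of this recurrence, together with a backward pass via backpropagation through time: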

import numpy as np

class SimpleRNN:
    def __init__(self, input_size, hidden_size, output_size):
        self.hidden_size = hidden_size
        # Initialize parameters with small random values
        self.Wxh = np.random.randn(hidden_size, input_size) * 0.01
        self.Whh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.Why = np.random.randn(output_size, hidden_size) * 0.01
        self.bh = np.zeros((hidden_size, 1))
        self.by = np.zeros((output_size, 1))

    def forward(self, inputs):
        # inputs: list of column vectors, each of shape (input_size, 1)
        h = np.zeros((self.hidden_size, 1))
        self.last_inputs = inputs
        self.last_hs = {0: h}

        for i, x in enumerate(inputs):
            h = np.tanh(self.Wxh @ x + self.Whh @ h + self.bh)
            self.last_hs[i + 1] = h

        # Output is computed from the final hidden state only
        y = self.Why @ h + self.by
        return y, h

    def backward(self, d_y, learn_rate=0.001):
        # Backpropagation through time, starting from d_y, the gradient of the
        # loss with respect to the final output, shape (output_size, 1)
        n = len(self.last_inputs)
        d_Why = d_y @ self.last_hs[n].T
        d_by = d_y
        d_Whh = np.zeros_like(self.Whh)
        d_Wxh = np.zeros_like(self.Wxh)
        d_bh = np.zeros_like(self.bh)
        d_h = self.Why.T @ d_y

        for t in reversed(range(n)):
            # Derivative of tanh: 1 - tanh(x)^2
            temp = (1 - self.last_hs[t + 1] ** 2) * d_h
            d_Wxh += temp @ self.last_inputs[t].T
            d_Whh += temp @ self.last_hs[t].T
            d_bh += temp
            d_h = self.Whh.T @ temp

        # Gradient clipping to keep updates from exploding
        for d in [d_Wxh, d_Whh, d_Why, d_bh, d_by]:
            np.clip(d, -5, 5, out=d)

        # Plain SGD parameter update
        self.Wxh -= learn_rate * d_Wxh
        self.Whh -= learn_rate * d_Whh
        self.Why -= learn_rate * d_Why
        self.bh -= learn_rate * d_bh
        self.by -= learn_rate * d_by
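
As a hypothetical smoke test of the class above, we can fit a random 5-step sequence to a fixed target with a squared-error loss (all sizes and data here are made-up illustration values):

np.random.seed(0)
rnn = SimpleRNN(input_size=3, hidden_size=16, output_size=2)

inputs = [np.random.randn(3, 1) for _ in range(5)]   # 5 time steps, each a (3, 1) column vector
target = np.array([[1.0], [0.0]])

for step in range(200):
    y, _ = rnn.forward(inputs)
    d_y = 2 * (y - target)             # gradient of the squared error w.r.t. the output
    rnn.backward(d_y, learn_rate=0.01)

y, _ = rnn.forward(inputs)
print(float(np.sum((y - target) ** 2)))  # loss should be small after training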

Vanishing and Exploding Gradients

Standard RNNs struggle with long sequences because of vanishing or exploding gradients. During backpropagation through time the gradient is multiplied by the recurrent weight matrix (and the activation derivative) once per time step, so it tends to either (see the toy sketch after this list):

  • Vanish: shrink exponentially, so long-range dependencies cannot be learned
  • Explode: grow exponentially, making training unstable
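
To make the effect concrete, here is a toy NumPy sketch (an illustration added here, not from the original post) that repeatedly multiplies a gradient vector by the recurrent weight matrix, as backpropagation through time does once per step; the tanh derivative, which can only shrink the gradient further, is ignored for simplicity:

import numpy as np

np.random.seed(0)
hidden_size, steps = 64, 100

for scale in (0.5, 1.5):  # recurrent matrix with spectral radius roughly below / above 1
    W = np.random.randn(hidden_size, hidden_size) * scale / np.sqrt(hidden_size)
    grad = np.ones((hidden_size, 1))
    for _ in range(steps):
        grad = W.T @ grad  # one multiplication per unrolled time step
    print(f"scale={scale}: gradient norm after {steps} steps = {np.linalg.norm(grad):.3e}")

With the smaller scale the norm collapses toward zero (vanishing); with the larger one it blows up (exploding).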

LSTM: Long Short-Term Memory Networks

The LSTM introduces gating mechanisms that largely mitigate the vanishing-gradient problem:

import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Forget gate
        self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
        # Input gate
        self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)
        # Candidate cell state
        self.candidate_gate = nn.Linear(input_size + hidden_size, hidden_size)
        # Output gate
        self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x, state):
        h, c = state
        combined = torch.cat([x, h], dim=1)

        forget = torch.sigmoid(self.forget_gate(combined))
        input_gate = torch.sigmoid(self.input_gate(combined))
        candidate = torch.tanh(self.candidate_gate(combined))
        output_gate = torch.sigmoid(self.output_gate(combined))

        # Cell state: keep part of the old memory, write part of the new candidate
        c_new = forget * c + input_gate * candidate
        # Hidden state: a gated view of the cell state
        h_new = output_gate * torch.tanh(c_new)

        return h_new, c_new

The three key gates of the LSTM (exercised in the usage sketch after this list):

  1. Forget gate: decides which parts of the previous cell state to discard
  2. Input gate: decides which new information to write into the cell state
  3. Output gate: decides how much of the cell state to expose as the hidden state
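
As a minimal usage sketch, the hand-written cell above can be unrolled over a toy sequence one step at a time (the batch size, sequence length, and dimensions are made-up illustration values):

batch, seq_len, input_size, hidden_size = 4, 10, 8, 16
cell = LSTMCell(input_size, hidden_size)

x_seq = torch.randn(batch, seq_len, input_size)
h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)

for t in range(seq_len):
    h, c = cell(x_seq[:, t, :], (h, c))  # feed one time step, carry (h, c) forward

print(h.shape)  # torch.Size([4, 16])

In practice nn.LSTM or nn.LSTMCell would be used directly; the hand-written cell is only meant to make the gate computations explicit.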

GRU: Gated Recurrent Unit

The GRU is a simplified variant of the LSTM that merges the forget and input gates into a single update gate:

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, sequence length, input_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.gru(x, h0)
        # Classify based on the hidden state at the last time step
        out = self.fc(out[:, -1, :])
        return out
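
A quick smoke test of the model above (the sizes and the random input are made-up illustration values; the shapes mimic treating each row of a 28x28 image as one time step):

model = GRUModel(input_size=28, hidden_size=128, num_layers=2, num_classes=10)
dummy = torch.randn(32, 28, 28)  # (batch, sequence length, features per step)
logits = model(dummy)
print(logits.shape)              # torch.Size([32, 10])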

Bidirectional RNNs

A bidirectional RNN reads the sequence in both the forward and backward directions at once, which often improves results on NLP tasks:

class BiLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size,
                            batch_first=True, bidirectional=True)
        # Forward and backward outputs are concatenated, hence hidden_size * 2
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)       # (batch, seq, embedding_dim)
        lstm_out, _ = self.lstm(embedded)  # (batch, seq, hidden_size * 2)
        # Use the output at the last time step; note that at this position the
        # backward direction has only seen the final token, so pooling over time
        # or using the final hidden states of both directions is a common alternative
        out = self.fc(lstm_out[:, -1, :])
        return out
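
A minimal sketch of running the classifier above on a batch of token ids (the vocabulary size and all other numbers are made-up illustration values):

model = BiLSTM(vocab_size=10000, embedding_dim=100, hidden_size=128, num_classes=2)
tokens = torch.randint(0, 10000, (16, 50))  # (batch, sequence length) of token ids
print(model(tokens).shape)                  # torch.Size([16, 2])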

Applications of Sequence Modeling

RNNs and their variants are widely used in the following areas:

  • Machine translation: Seq2Seq models
  • Text generation: language models
  • Sentiment analysis: text classification
  • Time-series forecasting: stock prices, weather
  • Speech recognition: acoustic models

Summary

RNNs are the foundational architecture for sequence data, but the standard RNN suffers from vanishing gradients. LSTM and GRU mitigate this problem through gating mechanisms and became the mainstream choices for sequence modeling. Bidirectional RNNs further improve performance, especially on NLP tasks.
