本文章是该视频的一部分,该部分的案例代码使用RNN做一个简单的实验,其余部分见作者的其他文章。
一、什么是循环神经网络
循环神经网络的来源是为了刻画一个序列当前的输出与之前信息的关系。从网络结构上,循环神经网络会记忆之前的信息,并利用之前的信息影响后面结点的输出。即:循环神经网络的隐藏层之间的结点是有连接的,隐藏层的输入不仅包括输入层的输出,还包括上一时刻隐藏层的输出。
循环神经网络对于每一个时刻的输入结合当前模型的状态给出一个输出。循环神经网络可以看做同一神经网络被无限复制的结果,出于优化考虑,现实生活中无法做到真正的无限循环。
循环神经网络中的参数在不同时刻是共享的。
为了将当前时刻的状态转换成最终的输出,循环神经网络需要另一个全连接神经网络完成此过程。不同时刻用于输出的全连接神经网络中的参数也是一致的。
循环神经网络的总损失:所有时刻(或部分时刻)上损失函数的总和。
循环神经网络可以更好地利用传统神经网络结构所不能建模的信息,但同时,带来更大的技术挑战——长期依赖(long-term dependencies)问题。
二、循环神经网络能干什么
RNNs已经被在实践中证明对NLP是非常成功的。如词向量表达、语句合法性检查、词性标注等。在RNNs中,目前使用最广泛最成功的模型便是LSTMs(Long Short-Term Memory,长短时记忆模型)模型,该模型通常比vanilla RNNs能够更好地对长短时依赖进行表达,该模型相对于一般的RNNs,只是在隐藏层做了手脚。
三、RNN使用案例,代码如下:
import torch
input_size = 4
hidden_size = 4
batch_size = 1
idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3] # hello
y_data = [3, 1, 2, 3, 2] # ohlol
one_hot_lookup = [
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
x_one_hot = [one_hot_lookup[x] for x in x_data]
inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)
labels = torch.LongTensor(y_data).view(-1, 1)
class Model(torch.nn.Module):
def __init__(self, input_size, hidden_size, batch_size):
super(Model, self).__init__()
self.batch_size = batch_size
self.input_size = input_size
self.hidden_size = hidden_size
self.rnncell = torch.nn.RNNCell(input_size=self.input_size,
hidden_size=self.hidden_size)
def forward(self, input, hidden):
hidden = self.rnncell(input, hidden) # ht = cell(xt,ht-1)
return hidden
def init_hidden(self): # 初始的隐层h0全为0
return torch.zeros(self.batch_size, self.hidden_size)
net = Model(input_size, hidden_size, batch_size)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
for epoch in range(15):
loss = 0
optimizer.zero_grad() # 优化器归零
hidden = net.init_hidden()
print('Predicted string: ', end='')
for input, label in zip(inputs, labels): # inputs(序列长度,batchsize,inputsize)
hidden = net(input, hidden) # input(batchsize,inputsize)
loss += criterion(hidden, label) # labels(序列长度seqsize,1)
_, idx = hidden.max(dim=1) # label(1)
print(idx2char[idx.item()], end='')
loss.backward() # 反向传播
optimizer.step() # 优化器更新
print(', Epoch [%d/15] loss=%.4f ' % (epoch+1, loss.item()))
运行结果如下:
视频截图如下: