前面的博文我们讲了LSTM的原理与分析,这一篇我们用PyTorch的nn.LSTM类做测试。

完整测试代码如下,用于进行MNIST数据集测试,主要学习LSTM类的输入输出维度。

这里定义的LSTM模型是一个三层的双向LSTM,并在输出端增加了一个线性分类层。

完整代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms

# step 1:===========================================定义LSTM结构
# 双向网络,三层网络,带上最后输出的线性层
class Rnn(nn.Module):
    """Stacked bidirectional LSTM followed by a linear classification head.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden size of each LSTM direction.
        n_layer: Number of stacked LSTM layers.
        n_classes: Number of output classes.
    """

    def __init__(self, input_dim, hidden_dim, n_layer, n_classes):
        super(Rnn, self).__init__()

        # batch_first=True puts batch_size on the first dimension of the
        # input/output tensors; bidirectional=True runs a forward and a
        # reverse pass in every layer.
        self.lstm = nn.LSTM(input_dim, hidden_dim, n_layer, batch_first=True, bidirectional=True)

        # Final linear layer. It consumes the concatenated final hidden
        # states of BOTH directions, hence 2 * hidden_dim input features.
        self.classifier = nn.Linear(hidden_dim * 2, n_classes)

    # Default nn.LSTM tensor layout (batch_first=False):
    #   input(seq_len, batch_size, input_size)
    #   h0(num_layers * num_directions, batch_size, hidden_size)
    #   c0(num_layers * num_directions, batch_size, hidden_size)
    # Default output layout:
    #   output(seq_len, batch_size, hidden_size * num_directions)
    #   hn(num_layers * num_directions, batch_size, hidden_size)
    #   cn(num_layers * num_directions, batch_size, hidden_size)
    #
    # With batch_first=True, batch_size moves to the first dimension of
    # input/output only; hn and cn keep the layout above.
    def forward(self, input):  # input e.g. [128, 28, 28]
        """Classify a batch of sequences.

        Args:
            input: Tensor laid out as (batch, seq_len, input_dim).

        Returns:
            Tensor of shape (batch, n_classes) with unnormalised class scores.
        """
        _, (h_n, _) = self.lstm(input)

        # h_n is ordered (layer0_fwd, layer0_rev, ..., last_fwd, last_rev).
        # Using h_n[-1] alone would keep only the reverse direction of the
        # last layer and leave its forward direction without any gradient,
        # so both final states are concatenated instead.
        x = torch.cat((h_n[-2], h_n[-1]), dim=-1)
        return self.classifier(x)

# Build the network: 28 features per time step, hidden size 10,
# 3 stacked layers, 10 output classes.
lstmNet = Rnn(input_dim=28, hidden_dim=10, n_layer=3, n_classes=10)



# step 2:===========================================加载MNIST数据,并形成批量数据
# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# download=False: the MNIST files are expected to already exist under ./data.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=False,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=128,
    shuffle=True,
)

testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=False,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=100,
    shuffle=False,
)


# step 3:===========================================定义损失函数和优化器
# Cross-entropy loss on the classifier scores; SGD with momentum over every
# parameter of the network.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=lstmNet.parameters(),
    lr=0.1,
    momentum=0.9,
)



# step 4:===========================================定义训练过程和测试过程
# Training
def train(epoch):
    """Run one training pass over ``trainloader`` and print running metrics.

    After every batch the batch index, the mean loss so far and the
    cumulative accuracy so far are printed.

    Args:
        epoch: Index of the current epoch; only used for the header line.
    """
    print('\nEpoch: %d' % epoch)
    lstmNet.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    for step, (images, labels) in enumerate(trainloader):
        # Drop the size-1 axis at dim 1 (per the original author's note the
        # batches arrive as [128, 1, 28, 28]) before feeding the network.
        sequences = images.squeeze(1)

        optimizer.zero_grad()
        logits = lstmNet(sequences)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        mean_loss = running_loss / (step + 1)
        accuracy = 100. * n_correct / n_seen
        print(step, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (mean_loss, accuracy, n_correct, n_seen))

def test(epoch):
    """Evaluate the network on ``testloader`` and print running metrics.

    Gradients are disabled for the whole pass. After every batch the batch
    index, the number of batches, the mean loss so far and the cumulative
    accuracy so far are printed.

    Args:
        epoch: Index of the current epoch. Currently unused; kept so the
            signature mirrors ``train``.
    """
    # NOTE: the former ``global best_acc`` declaration was removed: the name
    # is never assigned or read in this function.
    lstmNet.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        # Batch size follows testloader's configuration; the size-1 axis at
        # dim 1 is squeezed away before the forward pass.
        for batch_idx, (inputs, targets) in enumerate(testloader):
            outputs = lstmNet(torch.squeeze(inputs, 1))
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

# Run 100 epochs; each epoch is one training pass followed by one
# evaluation pass.
for epoch in range(0, 100):
    train(epoch)
    test(epoch)

训练结果还是比较好的,准确率甚至高达99.5%(节选自训练阶段的输出打印,各列依次为:批次序号、平均损失、累计准确率):
89 Loss: 0.015 | Acc: 99.540% (11467/11520)
90 Loss: 0.015 | Acc: 99.536% (11594/11648)
91 Loss: 0.015 | Acc: 99.533% (11721/11776)
92 Loss: 0.015 | Acc: 99.538% (11849/11904)
93 Loss: 0.015 | Acc: 99.535% (11976/12032)