循环神经网络
在深度学习领域,循环神经网络具有记忆能力,它可以根据以前的记忆来处理新的任务。记忆力在很有任务上是很有用的,比如在一场电影中推断下一个时间点的场景,这个时候仅依赖于现在的情景并不够,还需要依赖于前面发生的情节,对于这样一些不仅依赖于当前情况,还依赖于过去情况的问题,传统的神经网络结构不能很好地处理,而基于记忆的网络模型却能够完成这个任务。
LSTM
LSTM是循环神经网络的变式,它能够很好的解决长时间依赖的问题,这种循环网络结构非常流行,使用很广泛。LSTM使用三个门来控制,这三个门分别是输入门、遗忘门和输出门。输入门控制着网络的输入,遗忘门控制着网络的记忆单元,输出门控制着网络的输出。这三个门中最重要的门就是遗忘门,正是因为遗忘门的存在,使得LSTM具体了长时记忆的功能,它能够自己学习到哪些记忆将被保留,而哪些记忆将会被去掉。
图片分类问题
一般来说卷积神经网络是处理图片的能手,但这并不意味着循环神经网络不具备处理图片的能力,我们依然可以使用循环神经网络完成图片的分类任务。
手写字体识别
手写字体识别是一个非常经典的图片分类问题,手写字体识别的数据集的图片都是单通道的,图片大小是28*28。我们知道循环神经网络是处理序列数据的,所以我们可以将图片看成是序列数据,将每张图片看作是长为 28 的序列,序列中的每个元素的特征维度是28,这样就将图片变成了一个序列。同时考虑到循环神经网络的记忆性,当图片从左往右按时间步输入网络的时候,网络可以记忆住前面观察到的东西,也就是说一张图片虽然被切割成了28 份,但是网络能够通过记住前面的部分,同时和后面的部分结合从而得到最终的分类结果。
代码实现
import torch
from torch import nn
import numpy as np
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision import transforms
from datetime import datetime
def get_acc(output, label):
total = output.shape[0]
_, pred_label = output.max(1)#求每行的最大就是最有可能的类别
num_correct = (pred_label == label).sum().float()
return num_correct / total
data_tf=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0.5],[0.5])
]
)
train_set =mnist.MNIST('./adata',train=True,transform=data_tf,download=True)
test_set =mnist.MNIST('./adata',train=False,transform=data_tf,download=True)
train_data=DataLoader(train_set,batch_size=64,shuffle=True)
test_data =DataLoader(test_set,batch_size=128,shuffle=True)
class RNN(nn.Module):
def __init__(self,in_dim,hidden_dim,n_layer,n_class):
super(RNN,self).__init__()
self.n_layer=n_layer
self.hidden_dim=hidden_dim
self.lstm=nn.LSTM(in_dim,hidden_dim,n_layer,batch_first=True)
self.classifier=nn.Linear(hidden_dim,n_class)
def forward(self,x):
out,_ =self.lstm(x)
out =out[:,-1,:]#取最后一个时间步
out =self.classifier(out)
return out
net=RNN(28,50,2,10)
criterion =nn.CrossEntropyLoss()#定义损失函数
optimizer =torch.optim.SGD(net.parameters(),1e-1)
prev_time=datetime.now()
for epoch in range(30):
train_loss=0
train_acc =0
net =net.train()
for im ,label in train_data:#im,label为一批数据,也就是64个样本
im = im.squeeze(1)
output =net(im)
loss =criterion(output ,label)
#反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss +=loss.data.float()
train_acc +=get_acc(output,label)
cur_time =datetime.now()
h,remainder =divmod((cur_time-prev_time).seconds,3600)
m,s=divmod(remainder,60)
time_str ="Time %02d:%02d:%02d"%(h,m,s)
valid_loss=0
valid_acc=0
net =net.eval()
for im,label in test_data:
im = im.squeeze(1)
output =net(im)
loss= criterion(output,label)
valid_loss +=loss.data.float()
valid_acc +=get_acc(output,label)
epoch_str=(
"Epoch %d. Train Loss %f,Train Acc:%f,Valid Loss: %f,Valid Acc: %f ,"
%(epoch,train_loss/len(train_data),
train_acc /len(train_data),
valid_loss/len(test_data),
valid_acc /len(test_data)))
prev_time=cur_time
print(epoch_str+time_str)
程序运行结果
可以看到,我们将手写字体识别的数据放入到LSTM模型中之后,训练准确率和测试准确率高达99.7%以上。
以上就是基于深度学习框架pytorch搭建循环神经网络LSTM完成手写字体识别的全部原理和实现