简介
整个项目采用attention ocr的思路进行中文场景文字识别,整个用resnet34采集特征,用transformer的方式做解码器。网络在训练的时候可以并行进行相较于现在常用的GRU做解码器训练速度会快很多。推理的时候还是采用和GRU做解码器同样的方式,将上一步的输出用于下一步的输出,速度并未提升。

1. 网络结构

transformerOcr_缩放
1.1 编码部分
特征编码部分主要采用resnet34作为主要结构,网络的输入为1285121灰度图片,经过卷积和池化之后输出的特征图为:8321024,其中,残差网络部分输出的特征维度为512维,经过1*1Conv升维到1024维。
说明:

  1. 通过实际的实验发现,图片的尺寸越大,识别的精度越好,所以将图片高度定为128,宽度512是取的所有样本集图片等比例缩放后的平均数。

将832的二维特征,展开成一维1256的一维向量之后送入解码器进行解码。

1.2 解码器
解码器采用并行的多头注意力机制的transformer结构。首先是目标输出的词嵌入+位置编码,然后是目标序列的自注意机制,然后多头注意力机制。
在多头注意力机制中,多头数量设置为16,目标序列词嵌入的维度为1024,前向传播中的隐含层的维度为2048.
每个目标序列的开始插入‘2’(表示开始),结束的位置加上‘3’(表示结束),然后所有的序列用‘0’填充到115(最长的目标序列长度)的长度。

1.3 损失函数
采用交叉熵作为损失函数

2. 实验细节

2.1 数据
中文识别的数据集较少,而且本次百度提供的训练集有20万张,测试集有7.6万,已经很大了,最好是能找一个量级更大或相当的数据集。
本次只是将ICDAR2017的数据集添加进来了,在最后的结果上并没有明显的提升。最后训练用图片为:222543张
采用有字典的推理方式,字符集共:4470个

2.2 应用细节
采用两块V100 GPU显卡,每块16GB的显存,batchsize为128,图片统一缩放为128*512的灰度图,初始学习速率为0.001,到第9代和13代的时候,分别减小为原来的0.5,在训练和测试的时候未作任何数据增强。

2.3 实验结果
2.3.1 图片尺寸对比

CNN Bone Input size Accuracy
Resnet34 32*128 70.808
Resnet34 64*128 74.454
Resnet 128*512 79.286

2.3.2 transformer 解码器部分参数比较

特征维度 多头数量 Accuracy
512 8 77.719
1024 16 79.2