简介
整个项目采用attention ocr的思路进行中文场景文字识别,整个用resnet34采集特征,用transformer的方式做解码器。网络在训练的时候可以并行进行相较于现在常用的GRU做解码器训练速度会快很多。推理的时候还是采用和GRU做解码器同样的方式,将上一步的输出用于下一步的输出,速度并未提升。
1. 网络结构
1.1 编码部分
特征编码部分主要采用resnet34作为主要结构,网络的输入为1285121灰度图片,经过卷积和池化之后输出的特征图为:8321024,其中,残差网络部分输出的特征维度为512维,经过1*1Conv升维到1024维。
说明:
- 通过实际的实验发现,图片的尺寸越大,识别的精度越好,所以将图片高度定为128,宽度512是取的所有样本集图片等比例缩放后的平均数。
将832的二维特征,展开成一维1256的一维向量之后送入解码器进行解码。
1.2 解码器
解码器采用并行的多头注意力机制的transformer结构。首先是目标输出的词嵌入+位置编码,然后是目标序列的自注意机制,然后多头注意力机制。
在多头注意力机制中,多头数量设置为16,目标序列词嵌入的维度为1024,前向传播中的隐含层的维度为2048.
每个目标序列的开始插入‘2’(表示开始),结束的位置加上‘3’(表示结束),然后所有的序列用‘0’填充到115(最长的目标序列长度)的长度。
1.3 损失函数
采用交叉熵作为损失函数
2. 实验细节
2.1 数据
中文识别的数据集较少,而且本次百度提供的训练集有20万张,测试集有7.6万,已经很大了,最好是能找一个量级更大或相当的数据集。
本次只是将ICDAR2017的数据集添加进来了,在最后的结果上并没有明显的提升。最后训练用图片为:222543张
采用有字典的推理方式,字符集共:4470个
2.2 应用细节
采用两块V100 GPU显卡,每块16GB的显存,batchsize为128,图片统一缩放为128*512的灰度图,初始学习速率为0.001,到第9代和13代的时候,分别减小为原来的0.5,在训练和测试的时候未作任何数据增强。
2.3 实验结果
2.3.1 图片尺寸对比
CNN Bone | Input size | Accuracy |
---|---|---|
Resnet34 | 32*128 | 70.808 |
Resnet34 | 64*128 | 74.454 |
Resnet | 128*512 | 79.286 |
2.3.2 transformer 解码器部分参数比较
特征维度 | 多头数量 | Accuracy |
---|---|---|
512 | 8 | 77.719 |
1024 | 16 | 79.2 |