《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_论文精读

  1. PP-OCR: A Practical Ultra Lightweight OCR System论文地址
  2. PP-OCRv2: Bag of Tricks for Ultra Lightweight OCR System论文地址
  3. PP-OCRv3: More Attempts for the Improvement of Ultra Lightweight OCR System论文地址
  4. PaddleOCR Github OCR工具库 43.5K个star
  • PP-OCRv1由百度发表于2020年9月,截止2024年10月份,引用数是197。是PP-OCR系列的第一篇论文。
  • PP-OCRv2由百度发表于2021年9月,截止2024年10月份,引用数是59。是PP-OCR系列的第二篇论文。
  • PP-OCRv3由百度发表于2022年6月,截止2024年10月份,引用数是83。是PP-OCR系列的第三篇论文。
  • PP-OCRv4由百度发表于2023年9月,截止2024年10月份,代码已发布,论文还没发布。


文章目录

  • 1. 论文摘要
  • 2. 实验结果
  • 2.1. 实验设置
  • 2.1.1. 数据集
  • 2.1.2. 实验细节
  • 2.2. 文本检测,Hmean=0.795,提升3.6%
  • 2.3. 文本识别,Acc=0.748,提升8.1%
  • 2.4. PP-OCRv2整体性能,Hmean 50.3% --> 57.6%
  • 2.4.1. 召回率、准确率、Hmean和F-score复习
  • 3. 引言介绍
  • 4. 增强策略
  • 4.1. 文本检测
  • 4.1.1. Collaborative Mutual Learning (CML)蒸馏方法
  • 4.1.2. CopyPaste数据增强
  • 4.2. Text Recognition文本识别
  • 4.2.1. Lightweight CPU Network (PP-LCNet)轻量化CPU网络
  • 4.2.2. Unified-Deep Mutual Learning (U-DML)
  • 4.2.3. Enhanced CTCLoss 增强CTC损失


1. 论文摘要

论文摘要描述了对先前提出的光学字符识别(OCR)系统的一个改进版本。

  1. 研究背景
  • OCR系统在多种应用场合中被广泛使用。
  • 设计一个准确且高效的OCR系统仍然是一个挑战。
  • 在先前的工作中,研究团队提出了一个实用的超轻量级OCR系统(称为PP-OCR),旨在平衡准确性和效率。
  1. 新版本介绍
  • 为了提高PP-OCR系统的准确性并保持其高效性,研究团队在此文中提出了一种更强大的OCR系统,命名为PP-OCRv2。
  1. 技术改进
  • PP-OCRv2引入了一系列技巧来训练更好的文本检测器和识别器。这些技巧包括:
  • 协同互学习(Collaborative Mutual Learning, CML)
  • CopyPaste:一种数据增强技术
  • 轻量级CPU网络(PP-LCNet):专门设计用于CPU环境下的轻量级神经网络架构。
  • 统一深度互学习(Unified-Deep Mutual Learning, U-DML)
  • 增强版CTC损失(Enhanced CTCLoss)
  1. 实验结果
  • 在真实数据上的实验表明,PP-OCRv2的精度比PP-OCRv1高出7%,并且在相同的推理成本下,其表现与使用ResNet系列作为骨干网络的服务器模型相当。

2. 实验结果

2.1. 实验设置

2.1.1. 数据集

我们在与我们之前的工作PP-OCRv1(Du et al. 2020)中使用的相同数据集上进行实验,如表1所示。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_ocr_02

对于文本检测,有9.7万张训练图像和500张验证图像。训练图像由6.8万张真实场景图像和2.9万张合成图像组成。从百度图像搜索和公共数据集包括LSVT(发布于2019)、RCTW-17(发布于2017)、MTWI 2018(发布于2018)、CASIA-10K(发布于2018)、SROIE(发布于2019)、MLT 2019(发布于2019)、BDI(发布于2011)、MSRATD500(发布于2012)和CCPD 2019(发布于2018)。合成图像主要关注长文本、多方向文本和表中文本的场景。验证图像均来自真实场景。

对于文本识别,有1.79千万张训练图像和1.87万张验证图像。在训练图像中,有190万张图像是真实的场景图像,它们来自于一些公共数据集和百度图像搜索。所使用的公共数据集包括LSVT、RCTW-17、MTWI 2018和CCPD 2019。剩下的1.6千万张合成图像主要关注于不同背景的场景、旋转、透视转换、噪声、垂直文本等。合成图像的语料库来自于真实的场景图像。所有的验证图像也都来自于真实的场景。

此外,我们还收集了300张不同真实应用场景的图像,以评估整个OCR系统,包括合同样本、车牌、铭牌、火车票、测试表、表格、证书、街景图像、名片、数字仪表等。图7和图8显示了测试集的一些图像。

用于文本检测和文本识别的数据合成工具从text render中进行了修改得到(Sanster 2018)。

2.1.2. 实验细节

我们采用了PP-OCRv1(Du et al. 2020)中使用的大多数策略,如图2所示。我们使用Adam优化器来训练所有的模型,将初始学习率设置为0.001。不同之处在于,我们采用余弦学习率衰减作为检测模型训练的学习率计划,但只采用piece-wise decay分段衰减用于识别模型训练。检测和识别模型训练都用上了开始时几个Epoch的热身训练

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_论文精读

对于文本检测,该模型共训练了700个Epoch,并进行了2个Epoch的热身训练。批量大小设置为每张卡8个。

对于文本识别,模型预热5个Epoch,初始学习率0.001训练700个Epoch,然后再训练100个Epoch,这100个Epoch期间学习率衰减到0.0001。每张卡的批量大小为128个。

在推理期间,Hmean用于评估文本检测器和端到端OCR系统的性能。句子的准确性Accuracy用于评价文本识别器的性能。GPU推理时间在一个T4 GPU上进行了测试。CPU推断时间在Intel ® Xeon ® Gold 6148上进行了测试。

2.2. 文本检测,Hmean=0.795,提升3.6%

表2显示了DML、CML和CopyPaste对文本检测的消融研究。基线模型为PP-OCR轻量级检测模型。在测试期间,输入图像的长边的大小被调整为960。数据显示:

  • DML可以将Hmean度量提高近2%
  • CML可以提高3%。
  • 数据增强方法coppaste,Hmean进一步提高0.6%。

因此,在相同的速度下,由于模型结构保持不变,PP-OCRv2检测模型比PP-OCR提高了3.6%。推理时间是包括预处理和后处理在内所消耗的总时间。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据_04

2.3. 文本识别,Acc=0.748,提升8.1%

表3显示了PP-LCNet、U-DML和Enhanced CTC loss 增强CTC损失的消融研究。将PP-LCNet与MV3相比,准确率可提高2.6%。虽然PP-LCNet的模型尺寸大了3M,但由于网络结构的合理设计,推理时间从7.7 ms减少到了6.2 ms。U-DML方法可以提高4.6%,这是一个显著的提高。此外,通过提高CTC损失,精度可以提高0.9%。因此,所有这些策略的准确率提高了8.1%,模型尺寸大了3M,但平均推断时间快了1.5 ms。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据_05

为了测试PP-LCNet的泛化能力,我们在整个模型设计过程中使用了ImageNet-1k等具有挑战性的数据集。表4显示了PP-LCNet和我们在ImageNet上选择的其他不同轻量级模型之间的精度-速度比较。很明显,PP-LCNet在速度和准确性方面都取得了更好的性能,即使与像MobileNetV3等非常有竞争力的网络相比。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_召回率_06

2.4. PP-OCRv2整体性能,Hmean 50.3% --> 57.6%

在表5中,我们比较了所提出的PP-OCRv2与之前的超轻量级和大规模的PP-OCR系统之间的性能。大规模的PP-OCR系统使用ResNet18 vd作为文本检测器主干,ResNet34 vd作为文本识别器主干,可以比超轻量级版本实现更高的Hmean但更慢的推理速度。可以发现,在相同的推理成本下,PPOCRv2的Hmean比PP-OCR mobile模型高7.3%,且可与PP-OCR server模型相媲美。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_论文精读_07

2.4.1. 召回率、准确率、Hmean和F-score复习

Hmean和F-score的计算方式一样

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_论文精读_08

召回率Recall

让我们用小学生都能容易理解的方式来解释召回率(Recall)。

假设你在玩一个游戏,这个游戏叫做“找宝藏”。你的任务是在一堆沙子里找出所有的金币。这里,“金币”就像是我们要找出来的东西(也就是真正的正面例子)。

现在,我们来看看“召回率”是怎么回事:

  1. 找到的金币:你在游戏中找到了一些金币。这些找到的金币就相当于你正确识别出来的正面例子。
  2. 总共的金币:实际上在沙子里隐藏的所有金币的数量。这相当于所有实际存在的正面例子。
  3. 召回率:就是你找到的金币数量占总金币数量的比例。换句话说,召回率告诉你,你找到了多少比例的真正宝藏。

举个例子:

  • 如果沙子里总共有10枚金币,而你找到了7枚,那么你的召回率就是7/10=0.7 ,或者说是70%。

所以,召回率越高,说明你找到的真正宝藏越多,也就意味着你做得越好!

简单来说,召回率就是一个衡量你有多成功地找到所有你应该找到的东西的指标。 在OCR任务中,“金币”就是需要正确识别的文字或字符,“沙子”则是整个图片或文档。

准确率Precision

让我们继续用“找宝藏”的游戏来解释准确率(Precision)。

想象一下,你正在参加一个寻宝比赛。这次比赛的目标是从沙子里找到所有的金币。但是,你不仅找到了金币,还捡了一些石头。这里的“金币”是我们希望找到的东西(正确的识别结果),而“石头”是我们不小心捡到的错误物品(错误的识别结果)。

现在,让我们看看“准确率”是什么意思:

  1. 找到的金币:你找到了一些金币。这些金币是你正确识别出来的正面例子。
  2. 找到的石头:你也捡到了一些石头。这些石头是你错误地认为是金币的东西,实际上是负面例子却被错误地标记为正面例子。
  3. 准确率:就是你找到的金币数量占你找到的所有东西(金币+石头)的比例。换句话说,准确率告诉你,你找到的所有东西中有多少是真的金币。

举个例子:

  • 如果你一共找到了10件东西,其中有8枚是金币,2块是石头,那么你的准确率就是 8/10=0.8,或者说是80%。

因此,准确率越高,说明你找到的东西中真正的“金币”越多,错误的“石头”就越少,也就是说你找到的东西质量更高!

总结一下,准确率是一个衡量你找到的东西中有多少是真正正确的东西的指标。在OCR任务中,“金币”代表正确识别的字符或单词,“石头”则代表错误地被认为是字符或单词的非字符部分。

F-score是什么

可以更好地衡量不平衡测试数据集OCR系统的性能。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据_09

3. 引言介绍

光学字符识别OCR,在过去20年得到了广泛的研究,有各种应用场景,如文件电子化、身份认证、数字金融系统和车牌识别。

在实际构建OCR系统时,不仅考虑了精度,还考虑了计算效率。在之前,我们提出了一个实用的超轻量级OCR系统(PP-OCRv1)(Du et al. 2020),以平衡精度和效率。它由文本检测、检测框修正和文本识别三部分组成。可微二值化(DB)(Liao et al. 2020a)用于文本检测,CRNN(Shi,Bai,andaYao2016)用于文本识别。该系统采用了19种有效的策略来优化和缩小模型的尺寸。为了提高PP-OCR的精度和保持效率,本文提出了一种更鲁棒的OCR系统,即PP-OCRv2。它引入了一系列的技巧来训练一个更好的文本检测器和一个更好的文本识别器。

图2展示了PP-OCRv2的框架。大多数策略都遵循PP-OCR,如绿框所示。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_ocr_10

图2中,橙色方框中的策略是PP-OCRv2中的附加策略。在文本检测中,加入了协同互学习Collaborative(CML)和CopyPaste。

  • CML利用两个学生网络和一个教师网络来学习一个更稳健的文本检测器。
  • CopyPaste是一种新的数据增强技巧,已被证明可以有效地提高目标检测和实例分割任务的性能(Ghiasi et al. 2021)。我们表明,它也可以很好地用于文本检测任务。

在文本识别中,介绍了轻量级CPU网络(PP-LCNet)(Cui et al. 2021)、统一深度互学习Unified-Depp Mutual Learning(U-DML)和CenterLoss。

  • PP-LCNet是一种新设计的基于Intelcpu的轻量级骨干网,源自MobileNetV1(Howard et al. 2017)。
  • U-DML利用两个学生网络来学习一个更准确的文本识别器。
  • 中心损失CenterLoss的角色是放松相似角色的错误。我们进行了一系列的消融实验来验证上述策略的有效性。

此外,图2灰框中的策略在PP-OCR中是有效的。但这些在本文中没有得到验证。在未来,我们将采用它们来加快PP-OCRv2-tinty的推理过程。

4. 增强策略

4.1. 文本检测

4.1.1. Collaborative Mutual Learning (CML)蒸馏方法

我们提出了CML方法(Zhang et al. 2017)来解决文本检测蒸馏的问题。蒸馏有两个问题:

  • 1.如果教师模型的精度接近学生模型,一般蒸馏方法带来的改进将受到限制。
  • 2.如果教师模型的结构和学生模型的结构有很大的不同,那么一般的蒸馏方法所带来的改进也非常有限。

该框架是由多个模型组成的超级网络,分别命名为学生模型和教师模型,如图3所示。而CML方法可以实现在文本检测中,学生蒸馏后的准确性超过教师模型的准确性。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_ocr_11

CML旨在优化sub-student模型。冻结了教师模型的参数,只对子学生模型进行了设计损失训练。一般来说,子学生sub-student模型的监督信息包括ground truth label、另一个学生模型的后验熵 posterior entropy和教师模型的输出三个部分。相应地,有三个损失函数,包括地面真实损失ground truth loss Lgt,来自学生模型的同辈损失peer loss Ls和来自教师模型的蒸馏损失distill loss Lt

(1)地面真实损失Ground Truth Loss,称为GTLoss,是确保训练是由真实标签监督。我们使用DB算法(Liao et al. 2020b)来训练子学生sub-student模型。因此,地面真实损失Lgt是一个组合损失,它包括概率映射probability map lp的损失、二进制映射binary map lb的损失以及DB的阈值映射threshold map lt的损失。GTLoss的公式如下,其中lp、lb和lt分别为二元交叉熵binary cross-entropy损失、Dice损失和L1损失。α、β分别为默认值为5和10的超参数

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据_12

(1.1)二元交叉熵binary cross-entropy损失
二元交叉熵损失(Binary Cross Entropy Loss,有时也被称作Log Loss)是用于二分类问题的一种损失函数。它衡量的是预测概率分布与真实标签之间的差异。对于单个样本,损失函数定义如下:

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_ocr_13

如果要计算多个样本的总损失,则需要对每个样本的损失求平均或求和。下面是Python代码的一个简单实现,假设输入是两个一维数组,分别表示真实的标签和预测的概率:

import numpy as np

def binary_cross_entropy(y_true, y_pred, epsilon=1e-15):
    # 将预测概率限制在epsilon到1-epsilon之间以避免取对数时出现无穷大的值
    y_pred = np.clip(y_pred, epsilon, 1. - epsilon)
    # 计算损失
    loss = -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return loss

# 示例数据
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0.1, 0.9, 0.8, 0.2, 0.3])

# 计算损失
loss = binary_cross_entropy(y_true, y_pred)
print("Binary cross entropy loss:", loss)

在这个例子中,y_true 是真实标签的向量,而 y_pred 是模型预测的概率向量。我们使用了numpy的clip函数来确保预测值不会超出合法范围,从而避免了对数函数在边界值处的不稳定性。此外,我们使用了平均值来计算总损失,而不是简单的求和,这样可以使得损失值不受批量大小的影响。

(1.2)Dice 损失

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_ocr_14

在文本分割任务中,标签通常是与原始图像相同尺寸的灰度图或索引图。每个像素点的值代表了该位置属于哪一个类别。例如,在文本检测二分类问题中,标签图的每个像素可能是0(背景)或1(前景);而在多分类问题中,每个像素可能是一个整数值,对应于类别索引。

扩展而言,对于多类别分割问题,标签图中的每个像素值对应一个类别ID。例如,如果有一个包含建筑物、道路、树木等类别的分割任务,那么标签图中不同的像素值就代表了不同的类别。在进行模型训练时,这些标签会被用来计算损失函数,比如Dice Loss。

(1.3)L1 损失
L1损失(也称为绝对误差损失或L1范数损失)是一种损失函数,用于衡量预测值和真实值之间的差异。它计算的是预测值与真实值之间绝对差值的平均或总和。L1损失的一个优点是它对异常值(outliers)不太敏感,因为它直接采用绝对值而不是平方值来衡量误差。

L1损失的数学表达式为:

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_召回率_15

以下是一个用Python实现L1损失的简单示例:

import numpy as np

def l1_loss(y_true, y_pred):
    # 计算每个样本的绝对误差
    absolute_error = np.abs(y_true - y_pred)
    # 计算均值
    mean_absolute_error = np.mean(absolute_error)
    return mean_absolute_error

# 示例数据
y_true = np.array([1, 2, 3, 4, 5])  # 真实值
y_pred = np.array([1.5, 2.5, 3.5, 4.5, 5.5])  # 预测值

# 计算L1损失
l1_loss_value = l1_loss(y_true, y_pred)
print("L1 Loss:", l1_loss_value)

在这个例子中,我们定义了一个l1_loss函数,它接收两个参数:真实的标签值y_true和模型的预测值y_pred。函数内部首先计算了两个数组元素之间的绝对差值,然后计算了这些绝对差值的平均值。

如果你正在使用PyTorch框架,也可以非常方便地使用内置的L1损失函数:

import torch
import torch.nn as nn

# 使用PyTorch实现L1损失
def l1_loss_torch(y_true, y_pred):
    criterion = nn.L1Loss()
    loss = criterion(y_pred, y_true)
    return loss

# 示例数据
y_true = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)  # 真实值
y_pred = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5], dtype=torch.float32)  # 预测值

# 计算L1损失
l1_loss_value = l1_loss_torch(y_true, y_pred)
print("L1 Loss:", l1_loss_value.item())

这里,我们使用了torch.nn中的L1Loss类来创建损失计算对象,并通过调用该对象来计算给定真值和预测值之间的L1损失。注意,我们需要确保传递给损失函数的对象是torch.Tensor类型,并且可以通过.item()方法来提取出标量值。

(2)同辈损失peer loss。子学生sub-student模型参照DML方法相互学习(Zhang et al. 2017)。但与DML的不同之处在于,子学生sub-student模型在每次迭代中都会同时进行训练,以加快训练过程。KL散度用于计算学生模型之间的距离。学生模式之间的同辈损失peer loss如下:

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_论文精读_16

(2.1)KL

KL散度(Kullback-Leibler divergence),又称为相对熵,是用来衡量两个概率分布之间的差异的一种统计量。KL散度经常用于机器学习和信息理论领域,尤其是在处理概率分布估计的问题时。

KL散度的定义如下:

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据集_17

下面是一个使用Python和NumPy来实现KL散度损失的简单示例:

import numpy as np

def kl_divergence(P, Q):
    # 确保P和Q是非负的,并且至少有一个非零元素
    assert np.all(P >= 0), "P should be non-negative."
    assert np.all(Q >= 0), "Q should be non-negative."
    assert np.any(P > 0), "At least one element of P should be greater than 0."
    assert np.any(Q > 0), "At least one element of Q should be greater than 0."
    
    # 对于Q中为0的位置,设置一个非常小的正数来避免log(0)的问题
    Q = np.where(Q == 0, np.finfo(float).eps, Q)
    
    # 计算KL散度
    kl_div = np.sum(np.where(P != 0, P * np.log(P / Q), 0))
    return kl_div

# 示例数据
P = np.array([0.1, 0.2, 0.7])  # 真实分布
Q = np.array([0.2, 0.3, 0.5])  # 预测分布

# 计算KL散度
kl_div = kl_divergence(P, Q)
print("KL Divergence:", kl_div)

在这个例子中,我们定义了一个kl_divergence函数,该函数接收两个概率分布向量PQ,并返回它们之间的KL散度。注意,我们在计算过程中加入了检查以确保输入是非负的,并且为了避免log(0)导致的问题,我们将Q中为0的地方替换成了一个非常小的正数。

如果你正在使用深度学习框架如PyTorch,你可以利用其提供的torch.nn.functional.kl_div函数来实现KL散度:

import torch
import torch.nn.functional as F

def kl_divergence_torch(P, Q):
    # 将概率分布转换为logits形式
    P_logit = torch.log(P)
    Q_logit = torch.log(Q)
    
    # 使用F.kl_div函数计算KL散度
    kl_div = F.kl_div(P_logit, Q_logit, reduction='batchmean')
    return kl_div

# 示例数据
P = torch.tensor([0.1, 0.2, 0.7], dtype=torch.float32)  # 真实分布
Q = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float32)  # 预测分布

# 计算KL散度
kl_div = kl_divergence_torch(P, Q)
print("KL Divergence:", kl_div.item())

请注意,在使用PyTorch的F.kl_div函数时,默认情况下它期望输入是对数概率(log-probabilities),而不是普通的概率分布。因此,在使用这个函数之前,通常需要将概率分布转换为对数概率的形式。

蒸馏损失distill loss反映了教师模型对子学生sub-student模型的监督。教师模式可以为学生模式提供丰富的知识,这对提高成绩很重要。为了获得更好的知识,我们扩展了教师模型的响应概率图response probability maps,以增加目标面积the object area。该操作可以略微提高教师模型的准确性。蒸馏损失distill loss如下,其中lp、lb分别为二进制交叉熵损失binary cross-entropy loss和Dice loss损失。而γ是超参数的默认值为5。fdila是膨胀函数,其核是矩阵[[1,1],[1,1]]。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_论文精读_18

最后,在训练PP-OCR检测模型的CML方法中所使用的损失函数如下。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据集_19

4.1.2. CopyPaste数据增强

CopyPaste是一种新的数据增强技巧,已被证明可以有效地提高目标检测和实例分割任务的性能(Ghiasi et al. 2021)。它可以合成文本实例,平衡训练集中正负样本的比例,这是传统的图像旋转、随机翻转和随机裁剪无法实现的。由于前景中的所有文本都是独立的,CopyPaste粘贴文本在随机选择的背景图像上没有重叠。图4是CopyPaste的一个例子。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据_20

4.2. Text Recognition文本识别

4.2.1. Lightweight CPU Network (PP-LCNet)轻量化CPU网络

为了在英特尔CPU上获得更好的精度-速度的权衡,我们设计了一个基于英特尔CPU的轻量级骨干网络,它提供了一个更快、更准确的OCR识别算法。整个网络的结构如图5所示。与MobileNetV3相比,由于MobileNetV1的结构使得在英特尔CPU上启用MKLDNN时更容易优化推理速度,因此该网络基于MobileNetV1(Howard等人,2017)。为了使MobileNetV1具有更强的特征提取能力,我们对它的网络结构进行了一些修改。改进策略将从以下四个方面进行说明。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_ocr_21

1. Better activation function
为了提高MobileNetV1的拟合能力,我们将原本网络中的激活函数ReLU替换成H-Swish,这可以显著提高精度,只略微增加推理时间。

2. SE modules at appropriate positions
SE(Hu,Shen,和Sun 2018)模块自被提出以来已被大量网络使用。这是一种很好的方式来加权网络通道,以获得更好的功能,并被用于许多轻量级网络,如MobileNetV3(Howard et al. 2019)。但是,在Intelcpu上,SE模块增加了推理时间,因此我们不能在整个网络中使用它。事实上,通过大量的实验,我们发现,如果它越接近网络的尾部,SE模块就越有效。所以我们只将SE模块添加到网络尾部的块中。这导致了一个更好的精度-速度平衡。SE模块中两层的激活函数分别为ReLU和H-Sigmoid。

3. Larger convolution kernels
卷积核的大小往往会影响网络的最终性能。在mixnet(Tan和Le 2019)中,作者分析了不同大小的卷积核对网络性能的影响,最后在网络的同一层中混合了不同大小的核。然而,这样的混合降低了模型的推理速度,所以我们试图在尽可能少地增加推理时间的情况下增加卷积核的大小。最后,我们将网络尾部的卷积核的大小设为5×5。

4. Larger dimensional 1x1 conv layer after GAP
在PPLCNet中,GAP后的网络输出维数较小,直接连接最终的分类层会失去特征的组合。为了使网络具有更强的拟合能力,我们将一个1280维大小的1x1 conv连接到最终的GAP层,这样可以在不增加推理时间的情况下增加模型的大小。通过这四个变化,我们的模型在ImageNet上表现良好,表4列出了在Intelcpu上的其他轻量级模型的指标。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据_22

4.2.2. Unified-Deep Mutual Learning (U-DML)

Unified-Deep Mutual Learning (U-DML)深度相互学习(Zhang et al. 2017)是一种两个学生网络相互学习的方法,不需要预先训练权重的更大的教师网络进行知识蒸馏。在DML中,对于图像分类任务,损失函数包含两部分: (1)学生网络和Ground Truth之间的损失函数。(2)学生网络输出软标签之间的Kullback–Leibler divergence(KL-Div)损失。

Heo提出了OverHaul(Heo et al. 2019),其中学生网络和教师网络之间的特征图距离用于蒸馏过程。对学生网络特征图进行变换,以保持特征图的对齐性。

为了避免训练过程中教师模型过于耗时的问题,本文在DML的基础上提出了U-DML,即在蒸馏过程中也要监督特征映射。图6显示了U-DML的框架。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据集_23

蒸馏过程有两个网络:

学生网络和教师网络。它们具有完全相同的网络结构,但具有不同的初始化权值。目标是对于相同的输入图像,两个网络可以得到相同的输出形状,不仅是预测结果,而且是特征图。

总损失函数由三部分组成:(1) CTC损失。由于这两个网络都是从头开始训练的,所以CTC损失可以用于网络的收敛;(2) DML损失。预计两个网络的最终输出分布是相同的,因此需要DML损失来保证两个网络之间分布的一致性;(3)特征损失。这两个网络的架构是相同的,因此它们的特征图应该是相同的,特征损失可以用来约束两个网络的中间特征图的距离。

1. CTC loss
CRNN是本文中文本识别的基础体系结构,它集成了feature extraction特征提取和sequence modeling序列建模。它采用了Connectionist Temporal Classification (CTC) loss(Graves et al. 2006),以避免预测和Ground Truth之间的不一致。由于这两个子网都是从头开始训练的,所以这两个子网都采用了CTC损耗。损失函数如下。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_召回率_24

其中,Shout表示学生网络的头输出,Thout表示教师网络的头输出。gt提供了输入图像的组真值Ground Truth标签。

2. DML loss
在DML中,每个子网络的参数分别进行更新。在这里,为了简化训练过程,我们计算了两个子网之间的KL散度损失,并同时更新了所有参数。DML损失如下。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_召回率_25

其中,KL(p||q)表示p和q的KL散度。Spout和Tpout可以计算如下。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_数据集_26

3. Feature loss
在训练过程中,我们希望学生网络的主干输出与教师网络相同。因此,与Overhaul类似,在蒸馏过程中采用了特征损失Feature loss。损失可以计算如下。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_召回率_27

其中Sbout是指学生网络的骨干输出,Tbout是指教师网络的骨干输出。这里使用了均方误差损失。需要注意的是,对于特征损失,不需要进行特征图转换,因为用于计算损失的两个特征图完全相同。

最后,U-DML训练过程的总损失如下所示。

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_论文精读_28

在训练过程中,我们发现分段学习速率策略是一个更好的蒸馏选择。当使用特征损失Featrue loss时,模型需要较长的时间才能达到最佳精度,因此在文本识别精馏过程中采用了800个Epoch和分段策略。

此外,对于标准的CRNN体系结构,CTC-Head只使用了一个FC层,这对信息解码过程有点弱。因此,我们修改了CTC-Head部分,使用两个FC层,这导致了约1.5%的精度提高,而没有任何额外的推理时间成本。

4.2.3. Enhanced CTCLoss 增强CTC损失

在中文识别任务中存在着许多相似的特征。它们在外表上的差异很小,经常被错误地认识到。在PP-OCRv2中,我们设计了一个增强的CTCLoss,它结合了原始的CTCLoss和度量学习中的中心损失 CenterLoss的思想(Wen et al. 2016)。经过了一些改进,使其适合于序列识别任务。增强型CTCLoss的定义如下:

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_论文精读_29

其中,xt是时间戳t的特征,cyt是类yt的中心。由于CRNN(Shi,Bai,和Yao 2016)算法中的特征和标签之间的错位misalignment,我们没有针对xt的显式标签yt。我们采用贪婪解码策略得到yt greedy decoding strategy:

《PP-OCRv2》论文精读:蒸馏让PP-OCRv2获得了7%的OCR性能提升_召回率_30

W为CTC头的参数。实验表明: λ = 0.05是一个很好的选择。