Focal Loss for Dense Object Detection

论文链接:https://arxiv.org/abs/1708.02002 代码链接:集成到Detectron中,github上有大量三方实现
这周补了一下经典的focal loss,也就是RetinaNet,很多人应该也比较熟悉这篇文章了。Focal Loss是何恺明团队在2017年推出的作品,属于single stage的算法。在当时其精度甚至可以超过two stages的算法,且仍然有很高的检测速度,还不了解的同学可以一起来看看是怎么做的。

1. Background

我之前在refinenet的那篇方法里提到过,two stages的方法比one stage更好的一个重要原因是不需要面临单阶段检测框架中严重的类别不平衡问题。
具体来说,检测问题为了尽可能多地涵盖可能出现在不同位置的不同形状的物体,必须使用大量不同尺度、长宽比的anchor,在这之中,只有极少数的和gt重叠较大,会被归为正样本,剩下大量的都是负样本。多阶段的检测算法可以通过在RPN阶段之后筛选Roi来解决这个问题,但是单阶段算法就没这么幸运了。也因此,在Focal Loss中,作者很明确地把这个原因当做单阶段算法无法达到state-of-the-art准确性的根本原因。
有人会觉得,我像SSD一样做个采样,不用那么多负样本不就可以了吗?但是这样会带来两个问题:

  1. 负样本中,easy negative占大多数。所谓easy negative是指很容易被归为负样本的候选框,而那些困难的,比如包含了不容易区分的背景类物体、和gt有小部分重叠的候选框往往是少数。大量训练easy negative本身并不高效,且会主导网络的训练过程,但实际上训练应该更关注困难样本。
  2. 类似OHEM那种也是采样,虽然一定程度克服了1的问题,但是采样本身抛弃了很多可以利用的信息和数据,训练会更慢。

所以很自然地,作者提出了一个想法,就是把所有训练数据都利用起来,同时,给每个候选框一个动态的尺度因子。easy negative的尺度因子往往非常小,尺度因子的引入可以在利用所有候选框的情况下平衡正负样本对loss的贡献,同时划分难易样本、使训练更关注困难样本。

2. Focal loss & RetinaNet

损失函数定义

一般的交叉熵函数定义如下:

python 绘制L1 Loss损失函数 focal loss损失函数_函数定义

为了方便表示,定义:

python 绘制L1 Loss损失函数 focal loss损失函数_初始化_02

这样,交叉熵损失函数可以简写为-log(pt)。虽然easy examples的损失很小,但是巨大的数量会使它们占据主导,为了平衡数量,有人提出了拓展的交叉熵函数:

python 绘制L1 Loss损失函数 focal loss损失函数_ios_03

而作者更进一步提出了focal loss,表示如下:

python 绘制L1 Loss损失函数 focal loss损失函数_初始化_04

FL中新增的因子是可变的,当example被错误分类的时候,pt会很小,因子接近1;而当example被正确分类接近1的时候,因子则会接近0;γ参数也是一个变量,不同γ的效果如下图所示:

python 绘制L1 Loss损失函数 focal loss损失函数_ios_05

这张图也可以很直观地看到,well-classified exaples的loss几乎为0,作者最后使用的是γ为2. 另外,之前提到的拓展也可以用到FL上,最后的表达形式如下:

python 绘制L1 Loss损失函数 focal loss损失函数_函数定义_06

这种α平衡可以少许提升性能。

网络结构

python 绘制L1 Loss损失函数 focal loss损失函数_初始化_07

可以看到,特征抽取网络是ResNet+FPN,整体框架也并没有什么特别突出的地方。主要还是在损失函数的设计上下功夫。
分类和回归的subnets都是小型的FCN网络,且回归是class-agnostic的。关于ratios或者pyramids levels的使用这些都是一些比较常见的参数设置。
比较特别的一点是,在计算了所有anchor的和,并用anchors的数量做normalized的时候,并没有使用所有的anchor而是只用了匹配为正样本的anchor的数量,原因很有可能是大量负样本anchor的损失已经几乎为0.

初始化问题

模型随机初始化后,Binary Classification区分-1和1的两类概率基本相同,这会导致早期的训练中,损失主要来源于frequent class,比如检测问题中的负类。
为了处理这个问题,作者采用了一种特殊的初始化方法,主要是改变bias的值,让初始化输出rare class的概率约为非常小,从而和检测问题当中anchor被分为正类的概率小、负类概率大的客观事实所吻合,这样就不会导致大量分类错误的anchor引起的损失过大问题。具体的公式可以到文章中去看。

3. Experiment

这个部分主要内容是对比实验,包括:

  1. CE loss、带α平衡的CE loss与focal loss及带α平衡的focal loss的对比
  2. anchors不同设置(例如ratios)的对比
  3. focal loss不同参数(例如γ)的对比
  4. focal loss与OHEM的对比,作者通过这一点证明了focal loss这种使用了所有anchor的方法确实更具备优越性
    最后的结果如下所示:

    当然,FL最大的优势之一仍然是它的速度: