在不平衡数据集上做图像分类 数据不平衡分类focaloss_实例分割


FocalLoss 的出现,主要是为了解决 anchor-based (one-stage) 目标检测网络的分类问题。后面实例分割也常使用。

注意
这里是 目标检测网络的分类问题,而不是单纯的分类问题,这两者是不一样的。
区别在于,对于分配问题,一个图片一定是属于某一确定的类的;而检测任务中的分类,是有大量的anchor无目标的(可以称为负样本)。




分类任务

正常的 K类分类任务 的标签,是用一个K长度的向量作为标签,用one-hot(或者+smooth,这里先不考虑)来进行编码,最终的标签是一个形如[1,…, 0, …, 0]这样的。那么如果想要将背景分离出,自然可以想到增加一个1维,如果目标检测任务有K类,这里只要用K+1维来表示分类,其中1维代表无目标即可。对于分类任务而言,最后一般使用 softmax 来归一,使得所有类别的输出加和为1。

在不平衡数据集上做图像分类 数据不平衡分类focaloss_数据不平衡_02

但是在检测任务中,对于无目标的anchor,我们并不希望最终结果加和为1,而是所有的概率输出都是0。 那么可以这样,我们将一个多分类任务看做多个二分类任务(sigmoid),针对每一个类别,我输出一个概率,如果接近0则代表非该类别,如果接近1,则代表这个anchor是该类别。

所以网络输出不需要用softmax来归一,而是对K长度向量的每一个分量进行sigmoid激活,让其输出值代表二分类的概率。对于无目标的anchor,gt中所有的分量都是0,代表属于每一类的概率是0,即标注为背景。

至此,FocalLoss解决的问题不是多分类问题,而是 多个二分类问题




公式解析

首先看公式:只有 标签在不平衡数据集上做图像分类 数据不平衡分类focaloss_实例分割_03时,公式/交叉熵才有意义,在不平衡数据集上做图像分类 数据不平衡分类focaloss_深度学习_04即为标签为1时对应的预测值/模型分类正确的概率在不平衡数据集上做图像分类 数据不平衡分类focaloss_目标检测_05

在不平衡数据集上做图像分类 数据不平衡分类focaloss_目标检测_06

  1. 参数p[公式3]:当 p->0时(概率很低/很难区分是那个类别),调制因子 (1-p)接近1,损失不被影响,当 p->1时,(1-p)接近0,从而减小易分样本对总 loss的贡献
  2. 参数在不平衡数据集上做图像分类 数据不平衡分类focaloss_深度学习_07:当在不平衡数据集上做图像分类 数据不平衡分类focaloss_数据不平衡_08 时,Focal loss就是传统的交叉熵,
    在不平衡数据集上做图像分类 数据不平衡分类focaloss_深度学习_07 增加时, 调节系数在不平衡数据集上做图像分类 数据不平衡分类focaloss_数据不平衡_10 也会增加。
    在不平衡数据集上做图像分类 数据不平衡分类focaloss_深度学习_07 为定值时,比如 在不平衡数据集上做图像分类 数据不平衡分类focaloss_实例分割_12 1⃣️对于easy example(p>0.5) p=0.9 的loss要比标准的交叉熵小 100倍,当 p=0.968时,要小1000+倍;2⃣️对于 hard example(p<0.5) loss要小4倍
    这样的话, hard example 的权重相对提升了很多,从而增加了哪些误分类的重要性。
    实验表明,在不平衡数据集上做图像分类 数据不平衡分类focaloss_目标检测_13时效果最好
  3. 在不平衡数据集上做图像分类 数据不平衡分类focaloss_实例分割_14 调节正负样本不平衡系数,在不平衡数据集上做图像分类 数据不平衡分类focaloss_深度学习_07



代码复现

在官方给的代码中,并没有 target = F.one_hot(target, num_clas) 这行代码,这是因为

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch
from torch.nn import functional as F


def sigmoid_focal_loss( inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1, 
    					gamma: float = 2, reduction: str = "none") -> torch.Tensor:

    inputs  = inputs.float()
    targets = targets.float()
    p 	    = torch.sigmoid(inputs)
    target  = F.one_hot(target, num_clas+1)
    # target = target[:,1:]
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t 	= p * targets + (1 - p) * (1 - targets)
    loss 	= ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss


sigmoid_focal_loss_jit: "torch.jit.ScriptModule" = torch.jit.script(sigmoid_focal_loss)

此外,torchvision 中也支持 focal loss