图像分割常见Loss

最近在研究图像分割,由于自己之前没学习过,只能好好鼓捣,好久没写了。忙完分割项目总结一下。

1、基于分类损失

①:binary cross entropy

二分类的交叉熵损失函数

inception v3图像分类过程 图像分类loss_损失函数


当类别数M等于2的时候,这个损失就是二元交叉熵Loss。

交叉熵Loss可以用在大多数语义分割场景中,但它有个缺点,那就是对于只用分割前景和背景的时候,当前景像素的数量远远小于背景像素的数量时,即y=0的数量远大于y=1的数量,损失函数中y=0的成分就会占据主导,使得模型严重偏向背景,导致效果不好。

#二值交叉熵,这里输入要经过sigmoid处理

import torch
import torch.nn as nn
import torch.nn.functional as F
nn.BCELoss(F.sigmoid(input), target)

#多分类交叉熵, 用这个 loss 前面不需要加 Softmax 层
nn.CrossEntropyLoss(input, target)

②:Weighted cross Entropy

WCE是对BCE的改进,用于缓解类别不平衡的问题

inception v3图像分类过程 图像分类loss_计算机视觉_02

③:Focal loss
用来解决难以样本数量不平衡。

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1,2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1,1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

2、基于评价函数

上面的损失函数都可以看做是对像素进行分类的损失函数,而在图像分割中经常使用IOU对分割效果进行评价,只要损失函数是可导的,可以进行反向传播,那这个损失函数就可以对模型进行优化。下面几个损失函数都是以IOU来设计的Loss。

①:Dice Loss

DIce Loss:公式定义为:

inception v3图像分类过程 图像分类loss_计算机视觉_03

Dice Loss使用与样本极度不均衡的情况,如果一般情况下使用Dice Loss会会反向传播有不理影响,使得训练不稳定。

import torch.nn as nn
import torch.nn.functional as F
class SoftDiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(SoftDiceLoss, self).__init__()
 
    def forward(self, logits, targets):
        num = targets.size(0)
        // 为了防止除0的发生
        smooth = 1
        
        probs = F.sigmoid(logits)
        m1 = probs.view(num, -1)
        m2 = targets.view(num, -1)
        intersection = (m1 * m2)
 
        score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        score = 1 - score.sum() / num
        return score

②:Generalized Dice Loss

Dice Loss对小目标的预测是十分不利的,因为一旦小目标有部分像素预测错误,就可能会引起Dice系数大幅度波动,导致梯度变化大训练不稳定。Dice Loss针对的是某一个特定类别的分割的损失。当类似于病灶分割有多个场景的时候一般都会使用多个Dice loss,故Generalized Dice loss就是将多个类别的Dice Loss进行整个,使用一个指标作为分割结果的量化标准。GDL Loss在类别数为2时公式如下:

inception v3图像分类过程 图像分类loss_深度学习_04


其中rln表示类别l在第n个位置的真实像素类别,而pln表示相应的预测概率值,wl表示每个类别的权重,wl的公式为:

inception v3图像分类过程 图像分类loss_深度学习_05

③:BCE + Dice Loss
即将BCE Loss和Dice Loss 进行组合,在数据较为均衡的情况下有所改善,但是在数据极度不均衡的情况下交叉熵Loss会在迭代几个Epoch远远小于Dice Loss,这个组合Loss 会退化为Dice Loss

④:Focal Loss + Dice Loss

⑤:Lovasz-Softmax Loss

且行且珍惜…