图像分割常见Loss
最近在研究图像分割,由于自己之前没学习过,只能好好鼓捣,好久没写了。忙完分割项目总结一下。
1、基于分类损失
①:binary cross entropy
二分类的交叉熵损失函数
当类别数M等于2的时候,这个损失就是二元交叉熵Loss。
交叉熵Loss可以用在大多数语义分割场景中,但它有个缺点,那就是对于只用分割前景和背景的时候,当前景像素的数量远远小于背景像素的数量时,即y=0的数量远大于y=1的数量,损失函数中y=0的成分就会占据主导,使得模型严重偏向背景,导致效果不好。
#二值交叉熵,这里输入要经过sigmoid处理
import torch
import torch.nn as nn
import torch.nn.functional as F
nn.BCELoss(F.sigmoid(input), target)
#多分类交叉熵, 用这个 loss 前面不需要加 Softmax 层
nn.CrossEntropyLoss(input, target)
②:Weighted cross Entropy
WCE是对BCE的改进,用于缓解类别不平衡的问题
③:Focal loss
用来解决难以样本数量不平衡。
class FocalLoss(nn.Module):
def __init__(self, gamma=0, alpha=None, size_average=True):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
if isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])
if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
self.size_average = size_average
def forward(self, input, target):
if input.dim()>2:
input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W
input = input.transpose(1,2) # N,C,H*W => N,H*W,C
input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C
target = target.view(-1,1)
logpt = F.log_softmax(input)
logpt = logpt.gather(1,target)
logpt = logpt.view(-1)
pt = Variable(logpt.data.exp())
if self.alpha is not None:
if self.alpha.type()!=input.data.type():
self.alpha = self.alpha.type_as(input.data)
at = self.alpha.gather(0,target.data.view(-1))
logpt = logpt * Variable(at)
loss = -1 * (1-pt)**self.gamma * logpt
if self.size_average: return loss.mean()
else: return loss.sum()
2、基于评价函数
上面的损失函数都可以看做是对像素进行分类的损失函数,而在图像分割中经常使用IOU对分割效果进行评价,只要损失函数是可导的,可以进行反向传播,那这个损失函数就可以对模型进行优化。下面几个损失函数都是以IOU来设计的Loss。
①:Dice Loss
DIce Loss:公式定义为:
Dice Loss使用与样本极度不均衡的情况,如果一般情况下使用Dice Loss会会反向传播有不理影响,使得训练不稳定。
import torch.nn as nn
import torch.nn.functional as F
class SoftDiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(SoftDiceLoss, self).__init__()
def forward(self, logits, targets):
num = targets.size(0)
// 为了防止除0的发生
smooth = 1
probs = F.sigmoid(logits)
m1 = probs.view(num, -1)
m2 = targets.view(num, -1)
intersection = (m1 * m2)
score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
score = 1 - score.sum() / num
return score
②:Generalized Dice Loss
Dice Loss对小目标的预测是十分不利的,因为一旦小目标有部分像素预测错误,就可能会引起Dice系数大幅度波动,导致梯度变化大训练不稳定。Dice Loss针对的是某一个特定类别的分割的损失。当类似于病灶分割有多个场景的时候一般都会使用多个Dice loss,故Generalized Dice loss就是将多个类别的Dice Loss进行整个,使用一个指标作为分割结果的量化标准。GDL Loss在类别数为2时公式如下:
其中rln表示类别l在第n个位置的真实像素类别,而pln表示相应的预测概率值,wl表示每个类别的权重,wl的公式为:
③:BCE + Dice Loss
即将BCE Loss和Dice Loss 进行组合,在数据较为均衡的情况下有所改善,但是在数据极度不均衡的情况下交叉熵Loss会在迭代几个Epoch远远小于Dice Loss,这个组合Loss 会退化为Dice Loss
④:Focal Loss + Dice Loss
⑤:Lovasz-Softmax Loss
且行且珍惜…