书接上回关于标签信息的导入:

jpg (bn, 3, 512, 512) 待输入进网络的内容
png (bn, 512, 512) (0-20为类别,21为白边,训练时忽略)
seg_labels (bn, 512, 512, 22)

交叉熵损失函数

用png标签信息:
inputs即网络输出的预测结果
target即从标签中得到的png
cls_weights即一个长为num_classes的np向量,是对每个类别计算交叉熵损失时的权重,可以通过提高权重更侧重某个类别的优化,比如背景类像素点太多,就把背景类的权重调低一点作为制衡。

def CE_Loss(inputs, target, cls_weights, num_classes=21):
    n, c, h, w = inputs.size()
    nt, ht, wt = target.size()
    if h != ht and w != wt:  
    # 如果网络的输出imputs与target就采样到相同尺寸
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # (n,c,h,w) => (n*h*w,c)
    temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    # (n,h,w) => (n*h*w)
    temp_target = target.view(-1)
    # 第num_classes类忽略,因为这玩意是标注时的白边不属于第0到第num_classes-1类
    CE_loss  = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target)
    return CE_loss

Focal损失函数

关于logpt:
首先假设一个三分类问题,网络输出分数经过softmax后(0.7,0.2,0.1),标签为0,那个交叉熵损失之后为关于这个样本的损失值为log0.7,这也就是logpt中存储的值。
关于pt:
pt = torch.exp(logpt),那0.7就是pt中存的值。
所以对于一个样本(像素点)foacl的公式就是:
loss = - ( ( 1 - pt * alpha ) ** gama ) * logpt
即原本是logpt,现在在前面加上一个由pt决定的系数。如果该样本点的pt越接近1,则说明对其的预测结果越好,则其在损失函数中的权重越低,突出一个更关注分类结果差的样本点,增大分类效果差的样本在损失函数中的权重。

def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2):
    n, c, h, w = inputs.size()
    nt, ht, wt = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    temp_target = target.view(-1)

    logpt  = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
    # (n*h*w,) 记录的是预测概率向量中,正确类的概率的log值
    pt = torch.exp(logpt)
    # e log x = x,即记录的是预测概率向量中,正确类的概率值
    if alpha is not None:
        logpt *= alpha
    # pt越大,说明这个像素的概率预测向量中,正确类的概率越接近1,说明对这个像素点的预测结果很好,那么在梯度回传时,这个样本点的梯度信息占比就小一点吧
    # 即更关注预测的效果更差的像素点
    loss = -((1 - pt) ** gamma) * logpt
    loss = loss.mean()
    return loss

Dice损失函数

以一个例子来说明这个损失函数的思想
举个可视化的例子,三个样本的三分类问题

temp_inputs = [[0.2, 0.2, 0.6] # 输出结果
 [0.8, 0.1, 0.1]
 [0.9, 0.1, 0.0]]
 temp_target = [[0.0, 0.0, 1.0] # 标签
 [1.0, 0.0, 0.0]
 [1.0, 0.0, 0.0]]
 则tp = [1.7, 0.0, 0.6] # 长度为类别数
 则fn = [2.0, 0.0, 1.0] - [1.7, 0.0, 0.6] = [0.3, 0.0, 0.4]
 则fp = [1.9, 0.4, 0.7] - [1.7, 0.0, 0.6] = [0.2, 0.4, 0.1]


tp[0]记录了所有真实类别为第0类的样本输入进神经网络后,网络输出的该样本为第0类的概率值,即temp_inputs中0.8 0.9即网络输出的第二第三个样本属于第0类的概率,这类似于一种tp的概念(原本是第n类,网络认为它是第n类的概率值)
fn记录的是原本是第n类,网络认为它不是第n类的概率值(假阴性)
fp记录的是原本不是第n类,网络认为它是第n类的概率值(假阳性)
所以损失函数值为1 - tp/(tp+fn+fp)
fn减小意味着对于原本是第n类的样本,网络输出的其他类别的概率减小
fp减小意味着对于原本不是第n类的样本,网络输出的第n类的概率减小
tp增大意味着对于原本是第n类的样本,网络输出的第n类的概率增大

def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    # (n, h*w, c)
    # (n, h*w, ct) 实际上ct=num_classes+1=c+1
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice loss
    #   temp_target[...,:-1] 实际上是(n, h*w, c)
    #   temp_target[...,:-1] * temp_inputs
    #   torch.sum(temp_target[...,:-1], axis=[0,1])  形状为(c,)即n*h*w个像素点中各类像素点有多少个
    #   tp值是上面一行的值的预测版本,tp值的理想值就是上面一行
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])  # 真阳性
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp  # 
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    # tp / (tp+fn+fp)  # 类似一种交并比的概念
    # 想让这个东西最大化,
    # fp变小,可以让预测结果中正确的类的概率值更接近1
    # fn变小,可以让预测结果中错误的类的概率值更接近0
    # 举个可视化的例子,三个样本的三分类问题
    # temp_inputs = [[0.2, 0.2, 0.6]
    #                [0.8, 0.1, 0.1]
    #                [0.9, 0.1, 0.0]]
    # temp_target = [[0.0, 0.0, 1.0]
    #                [1.0, 0.0, 0.0]
    #                [1.0, 0.0, 0.0]]
    # 则tp = [1.7, 0.0, 0.6]  
    # 则fn = [2.0, 0.0, 1.0] - [1.7, 0.0, 0.6] = [0.3, 0.0, 0.4]
    # 则fp = [1.9, 0.4, 0.7] - [1.7, 0.0, 0.6] = [0.2, 0.4, 0.1]
    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss