书接上回关于标签信息的导入:
jpg (bn, 3, 512, 512) 待输入进网络的内容
png (bn, 512, 512) (0-20为类别,21为白边,训练时忽略)
seg_labels (bn, 512, 512, 22)
交叉熵损失函数
用png标签信息:
inputs即网络输出的预测结果
target即从标签中得到的png
cls_weights即一个长为num_classes的np向量,是对每个类别计算交叉熵损失时的权重,可以通过提高权重更侧重某个类别的优化,比如背景类像素点太多,就把背景类的权重调低一点作为制衡。
def CE_Loss(inputs, target, cls_weights, num_classes=21):
n, c, h, w = inputs.size()
nt, ht, wt = target.size()
if h != ht and w != wt:
# 如果网络的输出imputs与target就采样到相同尺寸
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
# (n,c,h,w) => (n*h*w,c)
temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
# (n,h,w) => (n*h*w)
temp_target = target.view(-1)
# 第num_classes类忽略,因为这玩意是标注时的白边不属于第0到第num_classes-1类
CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target)
return CE_loss
Focal损失函数
关于logpt:
首先假设一个三分类问题,网络输出分数经过softmax后(0.7,0.2,0.1),标签为0,那个交叉熵损失之后为关于这个样本的损失值为log0.7,这也就是logpt中存储的值。
关于pt:
pt = torch.exp(logpt),那0.7就是pt中存的值。
所以对于一个样本(像素点)foacl的公式就是:
loss = - ( ( 1 - pt * alpha ) ** gama ) * logpt
即原本是logpt,现在在前面加上一个由pt决定的系数。如果该样本点的pt越接近1,则说明对其的预测结果越好,则其在损失函数中的权重越低,突出一个更关注分类结果差的样本点,增大分类效果差的样本在损失函数中的权重。
def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2):
n, c, h, w = inputs.size()
nt, ht, wt = target.size()
if h != ht and w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
temp_target = target.view(-1)
logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, temp_target)
# (n*h*w,) 记录的是预测概率向量中,正确类的概率的log值
pt = torch.exp(logpt)
# e log x = x,即记录的是预测概率向量中,正确类的概率值
if alpha is not None:
logpt *= alpha
# pt越大,说明这个像素的概率预测向量中,正确类的概率越接近1,说明对这个像素点的预测结果很好,那么在梯度回传时,这个样本点的梯度信息占比就小一点吧
# 即更关注预测的效果更差的像素点
loss = -((1 - pt) ** gamma) * logpt
loss = loss.mean()
return loss
Dice损失函数
以一个例子来说明这个损失函数的思想
举个可视化的例子,三个样本的三分类问题
temp_inputs = [[0.2, 0.2, 0.6] # 输出结果
[0.8, 0.1, 0.1]
[0.9, 0.1, 0.0]]
temp_target = [[0.0, 0.0, 1.0] # 标签
[1.0, 0.0, 0.0]
[1.0, 0.0, 0.0]]
则tp = [1.7, 0.0, 0.6] # 长度为类别数
则fn = [2.0, 0.0, 1.0] - [1.7, 0.0, 0.6] = [0.3, 0.0, 0.4]
则fp = [1.9, 0.4, 0.7] - [1.7, 0.0, 0.6] = [0.2, 0.4, 0.1]
tp[0]记录了所有真实类别为第0类的样本输入进神经网络后,网络输出的该样本为第0类的概率值,即temp_inputs中0.8 0.9即网络输出的第二第三个样本属于第0类的概率,这类似于一种tp的概念(原本是第n类,网络认为它是第n类的概率值)
fn记录的是原本是第n类,网络认为它不是第n类的概率值(假阴性)
fp记录的是原本不是第n类,网络认为它是第n类的概率值(假阳性)
所以损失函数值为1 - tp/(tp+fn+fp)
fn减小意味着对于原本是第n类的样本,网络输出的其他类别的概率减小
fp减小意味着对于原本不是第n类的样本,网络输出的第n类的概率减小
tp增大意味着对于原本是第n类的样本,网络输出的第n类的概率增大
def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
n, c, h, w = inputs.size()
nt, ht, wt, ct = target.size()
if h != ht and w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
# (n, h*w, c)
# (n, h*w, ct) 实际上ct=num_classes+1=c+1
temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
temp_target = target.view(n, -1, ct)
#--------------------------------------------#
# 计算dice loss
# temp_target[...,:-1] 实际上是(n, h*w, c)
# temp_target[...,:-1] * temp_inputs
# torch.sum(temp_target[...,:-1], axis=[0,1]) 形状为(c,)即n*h*w个像素点中各类像素点有多少个
# tp值是上面一行的值的预测版本,tp值的理想值就是上面一行
#--------------------------------------------#
tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1]) # 真阳性
fp = torch.sum(temp_inputs , axis=[0,1]) - tp #
fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp
# tp / (tp+fn+fp) # 类似一种交并比的概念
# 想让这个东西最大化,
# fp变小,可以让预测结果中正确的类的概率值更接近1
# fn变小,可以让预测结果中错误的类的概率值更接近0
# 举个可视化的例子,三个样本的三分类问题
# temp_inputs = [[0.2, 0.2, 0.6]
# [0.8, 0.1, 0.1]
# [0.9, 0.1, 0.0]]
# temp_target = [[0.0, 0.0, 1.0]
# [1.0, 0.0, 0.0]
# [1.0, 0.0, 0.0]]
# 则tp = [1.7, 0.0, 0.6]
# 则fn = [2.0, 0.0, 1.0] - [1.7, 0.0, 0.6] = [0.3, 0.0, 0.4]
# 则fp = [1.9, 0.4, 0.7] - [1.7, 0.0, 0.6] = [0.2, 0.4, 0.1]
score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
dice_loss = 1 - torch.mean(score)
return dice_loss