一、网络搭建源码分析

core/models/CornerNet.py

import torch
import torch.nn as nn

from .py_utils import TopPool, BottomPool, LeftPool, RightPool #作者定义的C++4个扩展POOL操作

from .py_utils.utils import convolution, residual, corner_pool  #三大.py文件 utils losses modules
from .py_utils.losses import CornerNet_Loss
from .py_utils.modules import hg_module, hg, hg_net

def make_pool_layer(dim):
    """Return an identity "pooling" stage for CornerNet's hourglass.

    An empty ``nn.Sequential`` passes its input through unchanged, so no
    pooling happens here; in this file the spatial reduction is instead done
    by the ``stride=2`` residual that ``make_hg_layer`` places first.
    ``dim`` is unused and only kept to match the factory signature that
    ``hg_module`` calls.
    """
    passthrough = nn.Sequential()
    return passthrough

def make_hg_layer(inp_dim, out_dim, modules):
    """Build the entry stack of one hourglass level for CornerNet.

    The first ``residual`` maps ``inp_dim`` -> ``out_dim`` with ``stride=2``
    (presumably halving H and W, replacing the usual max pool); the remaining
    ``modules - 1`` residuals keep ``out_dim`` channels.
    """
    head = residual(inp_dim, out_dim, stride=2)
    tail = [residual(out_dim, out_dim) for _ in range(modules - 1)]
    return nn.Sequential(head, *tail)

class model(hg_net):  # Subclasses hg_net: __init__ only assembles the sub-modules and hands them to hg_net.__init__
    def _pred_mod(self, dim):
        """Return a prediction head.

        A 3x3 ``convolution`` (256 -> 256, built with ``with_bn=False``)
        followed by a 1x1 ``nn.Conv2d`` (256 -> ``dim``).
        NOTE(review): ``convolution``'s positional args look like
        (kernel, in_dim, out_dim); it lives in py_utils.utils — confirm there.
        """
        return nn.Sequential(
            convolution(3, 256, 256, with_bn=False),
            nn.Conv2d(256, dim, (1, 1))  # 1x1 projection to `dim` channels; bias=True (nn.Conv2d default), which __init__ initialises via `[-1].bias`
        )

    def _merge_mod(self):
        """Return a 1x1 ``nn.Conv2d`` (256 -> 256, no bias) followed by ``nn.BatchNorm2d(256)``."""
        return nn.Sequential(
            nn.Conv2d(256, 256, (1, 1), bias=False),  # bias omitted: the BatchNorm that follows has its own shift
            nn.BatchNorm2d(256)
        )

    # nn.Sequential, unlike nn.ModuleList, implements forward itself and runs its
    # children in order, so each child's output must fit the next child's input.
    def __init__(self):
        """Assemble the 2-stack hourglass backbone and the per-stack corner heads.

        NOTE: sub-modules are constructed in exactly this order; since PyTorch
        layers initialise their weights randomly on construction, reordering
        these statements changes the initial weights under a fixed seed.
        """
        stacks  = 2  # number of stacked hourglasses
        pre     = nn.Sequential(    # stem: stride-2 7x7 convolution to 128 channels, then a stride-2 residual to 256 channels (two stride-2 stages => 4x smaller)
            convolution(7, 3, 128, stride=2),
            residual(128, 256, stride=2)
        )
        hg_mods = nn.ModuleList([
            hg_module(
                5, [256, 256, 384, 384, 384, 512], [2, 2, 2, 2, 2, 4],  # 5 nested levels; dims[i] = channel width at level i (256, 256, 384, 384, 384, 512 going inward)
                make_pool_layer=make_pool_layer,   # modules[i] = residual count at level i; the innermost make_low_layer gets modules[5] = 4 residuals at 512 channels
                make_hg_layer=make_hg_layer  # these two override hg_module's defaults: no max pool, downsample with a stride-2 residual instead
            ) for _ in range(stacks)
        ]) # one hourglass per stack
        cnvs    = nn.ModuleList([convolution(3, 256, 256) for _ in range(stacks)]) # per-stack 3x3 convolution applied to each hourglass output
        inters  = nn.ModuleList([residual(256, 256) for _ in range(stacks - 1)]) # residual applied to the fused features that enter the next stack
        cnvs_   = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)]) # conv-BN applied to the current stack's output when fusing
        inters_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)]) # conv-BN applied to the current stack's input when fusing (separate weights from cnvs_)

        hgs = hg(pre, hg_mods, cnvs, inters, cnvs_, inters_)  # backbone; its forward returns one feature map per stack

        tl_modules = nn.ModuleList([corner_pool(256, TopPool, LeftPool) for _ in range(stacks)])   # top-left corner pooling (TopPool + LeftPool), one per stack
        br_modules = nn.ModuleList([corner_pool(256, BottomPool, RightPool) for _ in range(stacks)])  # bottom-right corner pooling (BottomPool + RightPool)

        tl_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)])  # heatmap heads: 80 output channels, one per class (80 presumably = COCO categories)
        br_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)])
        for tl_heat, br_heat in zip(tl_heats, br_heats):  # set the final 1x1 conv bias to -2.19: sigmoid(-2.19) ~= 0.10, presumably the focal-loss prior init -log((1 - 0.1) / 0.1)
            torch.nn.init.constant_(tl_heat[-1].bias, -2.19)
            torch.nn.init.constant_(br_heat[-1].bias, -2.19)

        tl_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])  # embedding heads: 1 channel, i.e. one scalar embedding per location
        br_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])

        tl_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])  # offset heads: 2 channels, i.e. an (x, y) offset per location
        br_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])
        # nn.ModuleList is only a container: it registers its children but defines no
        # forward, so execution order is decided by the code that uses it (hg / hg_net).
        super(model, self).__init__(     # hand all sub-modules to hg_net, which registers them and defines forward
            hgs, tl_modules, br_modules, tl_heats, br_heats, 
            tl_tags, br_tags, tl_offs, br_offs
        )

        self.loss = CornerNet_Loss(pull_weight=1e-1, push_weight=1e-1)  # loss attached to the model: pull/push weighted 0.1, off_weight keeps its default of 1

core/models/py_utils/modules.py

import torch
import torch.nn as nn

from .utils import residual, upsample, merge, _decode

def _make_layer(inp_dim, out_dim, modules):
    """Stack residual blocks where only the first one changes channel count.

    Default factory for the hourglass up/hg/low branches: the first block
    maps ``inp_dim`` -> ``out_dim`` and the remaining ``modules - 1`` keep
    ``out_dim`` (at least one block is always built).
    """
    stack = [residual(inp_dim, out_dim)]
    stack.extend(residual(out_dim, out_dim) for _ in range(modules - 1))
    return nn.Sequential(*stack)

def _make_layer_revr(inp_dim, out_dim, modules):
    """Mirror of ``_make_layer``: keep ``inp_dim`` and change channels last.

    Builds ``modules - 1`` residuals at ``inp_dim`` (none when
    ``modules <= 1``) followed by a single ``inp_dim`` -> ``out_dim`` residual.
    """
    stack = []
    for _ in range(modules - 1):
        stack.append(residual(inp_dim, inp_dim))
    stack.append(residual(inp_dim, out_dim))
    return nn.Sequential(*stack)

def _make_pool_layer(dim):
    """Default pooling factory: a 2x2 max pool with stride 2 (halves H and W).

    ``dim`` is unused; it is part of the factory signature ``hg_module`` calls.
    """
    return nn.MaxPool2d(2, stride=2)

def _make_unpool_layer(dim):
    """Default unpooling factory: ``upsample(scale_factor=2)``.

    Presumably doubles H and W so the pooled branch matches the skip branch
    again (``upsample`` lives in ``.utils`` — confirm there). ``dim`` is unused.
    """
    factor = 2
    return upsample(scale_factor=factor)

def _make_merge_layer(dim):
    """Default merge factory: returns a ``merge()`` module.

    It combines the skip and upsampled branches of ``hg_module``; expected to
    compute ``x + y`` (``merge`` lives in ``.utils`` — confirm there).
    ``dim`` is unused.
    """
    merger = merge()
    return merger

class hg_module(nn.Module):
    """One level of a recursive hourglass, nested ``n`` levels deep.

    For an input ``x`` with ``dims[0]`` channels, ``forward`` returns
    ``merg(up1(x), up2(low3(low2(low1(max1(x))))))`` where

    * ``up1 = make_up_layer(dims[0], dims[0], modules[0])`` is the skip branch,
    * ``max1 = make_pool_layer(dims[0])``,
    * ``low1 = make_hg_layer(dims[0], dims[1], modules[0])``,
    * ``low2`` is a nested ``hg_module(n - 1, dims[1:], modules[1:], ...)``
      when ``n > 1``, otherwise ``make_low_layer(dims[1], dims[1], modules[1])``,
    * ``low3 = make_hg_layer_revr(dims[1], dims[0], modules[0])``,
    * ``up2 = make_unpool_layer(dims[0])`` and ``merg = make_merge_layer(dims[0])``.

    ``dims`` and ``modules`` must each provide at least ``n + 1`` entries.
    Every ``make_*`` factory is handed unchanged to the nested levels, so
    overriding one (CornerNet overrides ``make_pool_layer`` and
    ``make_hg_layer``) affects the whole hourglass.
    """

    def __init__(
        self, n, dims, modules, make_up_layer=_make_layer,
        make_pool_layer=_make_pool_layer, make_hg_layer=_make_layer,
        make_low_layer=_make_layer, make_hg_layer_revr=_make_layer_revr,
        make_unpool_layer=_make_unpool_layer, make_merge_layer=_make_merge_layer
    ):
        super().__init__()

        outer_mods, inner_mods = modules[0], modules[1]
        outer_dim, inner_dim = dims[0], dims[1]

        # Sub-modules are created (and registered) in forward-pass order.
        self.n = n
        self.up1 = make_up_layer(outer_dim, outer_dim, outer_mods)
        self.max1 = make_pool_layer(outer_dim)
        self.low1 = make_hg_layer(outer_dim, inner_dim, outer_mods)
        if n > 1:
            factories = {
                "make_up_layer": make_up_layer,
                "make_pool_layer": make_pool_layer,
                "make_hg_layer": make_hg_layer,
                "make_low_layer": make_low_layer,
                "make_hg_layer_revr": make_hg_layer_revr,
                "make_unpool_layer": make_unpool_layer,
                "make_merge_layer": make_merge_layer,
            }
            self.low2 = hg_module(n - 1, dims[1:], modules[1:], **factories)
        else:
            # Bottom of the hourglass: a make_low_layer stack at inner_dim.
            self.low2 = make_low_layer(inner_dim, inner_dim, inner_mods)
        self.low3 = make_hg_layer_revr(inner_dim, outer_dim, outer_mods)
        self.up2 = make_unpool_layer(outer_dim)
        self.merg = make_merge_layer(outer_dim)

    def forward(self, x):
        """Run the skip branch and the pooled/recursive branch, then merge them."""
        skip = self.up1(x)
        pooled = self.max1(x)
        deep = self.low3(self.low2(self.low1(pooled)))
        return self.merg(skip, self.up2(deep))

class hg(nn.Module):
    """Stacked-hourglass backbone.

    Args:
        pre: stem applied once to the input.
        hg_modules: one hourglass per stack.
        cnvs: per-stack module applied to that hourglass' output; these
            results are what ``forward`` returns.
        inters: one module per stack transition, applied to the fused
            features before they enter the next stack.
        cnvs_: one module per transition, applied to the current stack's output.
        inters_: one module per transition, applied to the current stack's input.
    """

    def __init__(self, pre, hg_modules, cnvs, inters, cnvs_, inters_):
        super().__init__()
        self.pre = pre
        self.hgs = hg_modules
        self.cnvs = cnvs
        self.inters = inters
        self.inters_ = inters_
        self.cnvs_ = cnvs_

    def forward(self, x):
        """Return a list with one feature map per stack.

        ``hg_net`` uses all of them in training and only the last at test time.
        Between stacks: ``next_input = inters(relu(inters_(input) + cnvs_(output)))``.
        """
        features = self.pre(x)
        last = len(self.hgs) - 1

        outputs = []
        for idx, (hourglass, head) in enumerate(zip(self.hgs, self.cnvs)):
            out = head(hourglass(features))
            outputs.append(out)
            if idx >= last:
                continue
            # Fuse this stack's input and output to feed the next stack.
            fused = self.inters_[idx](features) + self.cnvs_[idx](out)
            fused = nn.functional.relu_(fused)
            features = self.inters[idx](fused)
        return outputs

class hg_net(nn.Module):
    """CornerNet-style detector built on a stacked-hourglass backbone.

    Args:
        hg: backbone returning a list with one feature map per stack.
        tl_modules, br_modules: per-stack top-left / bottom-right corner
            pooling modules applied to the backbone features.
        tl_heats, br_heats: per-stack heatmap heads.
        tl_tags, br_tags: per-stack embedding heads.
        tl_offs, br_offs: per-stack offset heads.

    All head containers are indexed in parallel with the backbone outputs.
    """

    def __init__(
        self, hg, tl_modules, br_modules, tl_heats, br_heats,
        tl_tags, br_tags, tl_offs, br_offs
    ):
        super(hg_net, self).__init__()

        # Decoder used by _test to turn the raw last-stack outputs into detections.
        self._decode = _decode

        self.hg = hg

        self.tl_modules = tl_modules
        self.br_modules = br_modules

        self.tl_heats = tl_heats
        self.br_heats = br_heats

        self.tl_tags = tl_tags
        self.br_tags = br_tags

        self.tl_offs = tl_offs
        self.br_offs = br_offs

    def _train(self, *xs, **kwargs):
        """Evaluate every stack's heads (intermediate supervision).

        Head outputs are not fed back into later stacks; each stack's
        predictions only go to the loss.

        Args:
            *xs: ``xs[0]`` is the image batch; further items are ignored.
            **kwargs: accepted and ignored so that ``forward`` can pass its
                keyword arguments through without raising ``TypeError``.

        Returns:
            ``[tl_heats, br_heats, tl_tags, br_tags, tl_offs, br_offs]``,
            each a list with one tensor per stack.
        """
        image = xs[0]
        cnvs = self.hg(image)

        # Corner pooling: stack i's pool module sees stack i's features.
        tl_modules = [tl_mod_(cnv) for tl_mod_, cnv in zip(self.tl_modules, cnvs)]
        br_modules = [br_mod_(cnv) for br_mod_, cnv in zip(self.br_modules, cnvs)]
        # Heads for heatmaps, embeddings and offsets, all fed by the pooled features.
        tl_heats = [tl_heat_(tl_mod) for tl_heat_, tl_mod in zip(self.tl_heats, tl_modules)]
        br_heats = [br_heat_(br_mod) for br_heat_, br_mod in zip(self.br_heats, br_modules)]
        tl_tags = [tl_tag_(tl_mod) for tl_tag_, tl_mod in zip(self.tl_tags, tl_modules)]
        br_tags = [br_tag_(br_mod) for br_tag_, br_mod in zip(self.br_tags, br_modules)]
        tl_offs = [tl_off_(tl_mod) for tl_off_, tl_mod in zip(self.tl_offs, tl_modules)]
        br_offs = [br_off_(br_mod) for br_off_, br_mod in zip(self.br_offs, br_modules)]
        return [tl_heats, br_heats, tl_tags, br_tags, tl_offs, br_offs]

    def _test(self, *xs, **kwargs):
        """Run only the last stack's heads and decode them.

        Returns:
            ``(self._decode(tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off,
            **kwargs), tl_heat, br_heat, tl_tag, br_tag)`` — the raw
            last-stack outputs are returned alongside the decoded result.
        """
        image = xs[0]
        cnvs = self.hg(image)
        # Only the final backbone output is used at inference time.
        tl_mod = self.tl_modules[-1](cnvs[-1])
        br_mod = self.br_modules[-1](cnvs[-1])

        tl_heat, br_heat = self.tl_heats[-1](tl_mod), self.br_heats[-1](br_mod)
        tl_tag, br_tag = self.tl_tags[-1](tl_mod), self.br_tags[-1](br_mod)
        tl_off, br_off = self.tl_offs[-1](tl_mod), self.br_offs[-1](br_mod)

        outs = [tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off]
        return self._decode(*outs, **kwargs), tl_heat, br_heat, tl_tag, br_tag

    def forward(self, *xs, test=False, **kwargs):
        """Dispatch to ``_test`` when ``test`` is true, else to ``_train``.

        ``xs`` and ``kwargs`` are forwarded unchanged to whichever path runs.
        """
        if not test:
            return self._train(*xs, **kwargs)
        return self._test(*xs, **kwargs)

二、损失函数

core/models/py_utils/losses.py

import torch
import torch.nn as nn

from .utils import _tranpose_and_gather_feat

def _sigmoid(x, eps=1e-4):
    """Sigmoid clamped to ``[eps, 1 - eps]``.

    The clamp keeps ``log(p)`` and ``log(1 - p)`` in the focal loss finite.
    The result is computed out of place, so the caller's tensor (the
    network's raw heatmap output) is left untouched.

    Args:
        x: logits tensor.
        eps: clamp margin (default ``1e-4``).
    """
    return torch.clamp(torch.sigmoid(x), min=eps, max=1 - eps)

def _ae_loss(tag0, tag1, mask):
    """Associative-embedding pull/push losses (CornerNet paper, Sec. 3.3).

    Args:
        tag0: top-left embeddings gathered at the ground-truth corners, with
            ``mask.numel()`` elements (e.g. ``[batch, max_objs, 1]``).
        tag1: bottom-right embeddings, same layout as ``tag0``.
        mask: ``[batch, max_objs]`` validity mask (uint8 or bool); nonzero
            marks a real object.

    Returns:
        ``(pull, push)`` scalar tensors summed over the batch.  Per image,
        with ``e_k`` the mean of object k's two embeddings and N objects:
        pull = sum_k [(tl_k - e_k)^2 + (br_k - e_k)^2] / N and
        push = sum_{j != k} max(0, 1 - |e_k - e_j|) / (N (N - 1)).
        Images with N <= 1 contribute no push term.
    """
    # Normalise to bool: bool indexing is the supported path, and the old
    # ``mask + mask == 2`` pairing silently yielded no pairs for bool masks.
    mask = mask.bool()
    num = mask.sum(dim=1, keepdim=True).float()  # [batch, 1]: objects per image

    # reshape (not squeeze) so a batch of size 1 keeps its batch axis.
    tag0 = tag0.reshape(mask.shape)
    tag1 = tag1.reshape(mask.shape)

    tag_mean = (tag0 + tag1) / 2

    # Pull: draw each object's two corner embeddings toward their mean.
    pull0 = torch.pow(tag0 - tag_mean, 2) / (num + 1e-4)
    pull1 = torch.pow(tag1 - tag_mean, 2) / (num + 1e-4)
    pull = pull0[mask].sum() + pull1[mask].sum()

    # Push: separate the mean embeddings of different objects by a margin of 1.
    pair_mask = mask.unsqueeze(1) & mask.unsqueeze(2)  # [batch, K, K] valid pairs
    not_self = ~torch.eye(mask.size(1), dtype=torch.bool, device=mask.device)
    pair_mask = pair_mask & not_self  # drop j == k explicitly
    num = num.unsqueeze(2)
    num2 = (num - 1) * num  # N (N - 1) ordered pairs per image
    dist = tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2)
    dist = nn.functional.relu(1 - torch.abs(dist))
    dist = dist / (num2 + 1e-4)
    push = dist[pair_mask].sum()
    return pull, push

def _off_loss(off, gt_off, mask):
    """Smooth-L1 corner-offset loss, summed and divided by the object count.

    Args:
        off: predicted offsets gathered at the ground-truth corners; assumed
            ``[batch, max_objs, 2]`` (the mask is expanded along dim 2).
        gt_off: ground-truth offsets, same shape as ``off``.
        mask: ``[batch, max_objs]`` validity mask (uint8 or bool); nonzero
            marks a real object.

    Returns:
        Scalar tensor; ``0`` when there are no valid objects.
    """
    # Normalise to bool: uint8 mask indexing is deprecated in PyTorch.
    mask = mask.bool()
    num = mask.float().sum()  # total number of valid objects in the batch
    mask = mask.unsqueeze(2).expand_as(gt_off)  # broadcast over the (x, y) channel

    off_loss = nn.functional.smooth_l1_loss(off[mask], gt_off[mask], reduction="sum")
    return off_loss / (num + 1e-4)  # eps guards against division by zero

def _focal_loss(preds, gt, alpha=2, beta=4):
    """CornerNet's variant focal loss on corner heatmaps, summed over stacks.

    Args:
        preds: iterable of predicted heatmaps already squashed into (0, 1)
            (e.g. by ``_sigmoid``), one per hourglass stack, each shaped like ``gt``.
        gt: ground-truth heatmap; entries equal to 1 are positive corner
            locations, entries < 1 are negatives whose penalty is scaled by
            ``(1 - gt) ** beta``.
        alpha: focusing exponent (paper value 2).
        beta: exponent of the negative down-weighting (paper value 4).

    Returns:
        Sum over ``preds`` of ``-(pos + neg) / num_pos``, or ``-neg`` when
        there are no positives; ``0`` when ``preds`` is empty.
    """
    pos_inds = gt.eq(1)
    neg_inds = gt.lt(1)

    # Both depend only on gt, so compute them once rather than per stack.
    neg_weights = torch.pow(1 - gt[neg_inds], beta)
    num_pos = pos_inds.float().sum()

    loss = 0
    for pred in preds:  # one heatmap per stack (intermediate supervision)
        pos_pred = pred[pos_inds]
        neg_pred = pred[neg_inds]

        pos_loss = (torch.log(pos_pred) * torch.pow(1 - pos_pred, alpha)).sum()
        neg_loss = (torch.log(1 - neg_pred) * torch.pow(neg_pred, alpha) * neg_weights).sum()

        if pos_pred.nelement() == 0:
            # No positives: normalising by num_pos would divide by zero.
            loss = loss - neg_loss
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
    return loss

class CornerNet_Loss(nn.Module):
    """Total CornerNet training loss.

    For every hourglass stack it adds the focal loss of both corner
    heatmaps, the associative-embedding pull/push losses and the smooth-L1
    offset losses, then divides by the number of stacks.

    Args:
        pull_weight: scale applied to the summed pull loss.
        push_weight: scale applied to the summed push loss.
        off_weight: scale applied to the summed offset loss.
        focal_loss: callable ``(heatmap_list, gt_heatmap) -> loss``.

    All weights default to 1; CornerNet's ``model`` passes 0.1 for pull and push.
    """

    def __init__(self, pull_weight=1, push_weight=1, off_weight=1, focal_loss=_focal_loss):
        super().__init__()

        self.pull_weight = pull_weight
        self.push_weight = push_weight
        self.off_weight = off_weight
        self.focal_loss = focal_loss
        self.ae_loss = _ae_loss
        self.off_loss = _off_loss

    def forward(self, outs, targets):
        """Compute the weighted loss.

        Args:
            outs: ``[tl_heats, br_heats, tl_tags, br_tags, tl_offs, br_offs]``,
                each a list with one raw (pre-sigmoid) tensor per stack.
            targets: ``[gt_tl_heat, gt_br_heat, gt_mask, gt_tl_off, gt_br_off,
                gt_tl_ind, gt_br_ind]``; the ``*_ind`` tensors select the
                ground-truth corner positions used to gather tags and offsets.

        Returns:
            The loss with a leading dimension of size 1 (``unsqueeze(0)``).
        """
        tl_heats, br_heats = outs[0], outs[1]
        tl_tags, br_tags = outs[2], outs[3]
        tl_offs, br_offs = outs[4], outs[5]

        gt_tl_heat, gt_br_heat, gt_mask = targets[0], targets[1], targets[2]
        gt_tl_off, gt_br_off = targets[3], targets[4]
        gt_tl_ind, gt_br_ind = targets[5], targets[6]

        # Focal loss on clamped sigmoid scores of both corner heatmaps.
        tl_scores = [_sigmoid(heat) for heat in tl_heats]
        br_scores = [_sigmoid(heat) for heat in br_heats]
        focal_loss = self.focal_loss(tl_scores, gt_tl_heat) + self.focal_loss(br_scores, gt_br_heat)

        # Embedding losses, using only the values at ground-truth corners.
        tl_picked = [_tranpose_and_gather_feat(tag, gt_tl_ind) for tag in tl_tags]
        br_picked = [_tranpose_and_gather_feat(tag, gt_br_ind) for tag in br_tags]
        pull_terms, push_terms = [], []
        for tl_tag, br_tag in zip(tl_picked, br_picked):
            pull, push = self.ae_loss(tl_tag, br_tag, gt_mask)
            pull_terms.append(pull)
            push_terms.append(push)
        pull_loss = self.pull_weight * sum(pull_terms)
        push_loss = self.push_weight * sum(push_terms)

        # Offset losses at ground-truth corners: top-left then bottom-right per stack.
        tl_off_picked = [_tranpose_and_gather_feat(off, gt_tl_ind) for off in tl_offs]
        br_off_picked = [_tranpose_and_gather_feat(off, gt_br_ind) for off in br_offs]
        off_terms = []
        for tl_off, br_off in zip(tl_off_picked, br_off_picked):
            off_terms.append(self.off_loss(tl_off, gt_tl_off, gt_mask))
            off_terms.append(self.off_loss(br_off, gt_br_off, gt_mask))
        off_loss = self.off_weight * sum(off_terms)

        stacks = max(len(tl_scores), 1)
        loss = (focal_loss + pull_loss + push_loss + off_loss) / stacks
        return loss.unsqueeze(0)