参考UNET-2015

网络结构

最新基于VIT的语义分割模型cvpr_2d


如上图,Unet 网络结构是对称的,形似英文字母 U 所以被称为 Unet。整张图都是由蓝/白色框与各种颜色的箭头组成,其中,蓝/白色框表示 feature map;蓝色箭头表示 3x3 卷积,用于特征提取;灰色箭头表示 skip-connection,用于特征融合;红色箭头表示池化 pooling,用于降低维度;绿色箭头表示上采样 upsample,用于恢复维度;青色箭头表示 1x1 卷积,用于输出结果。

Encoder
Encoder 由卷积操作和下采样操作组成,文中所用的卷积结构统一为 3x3 的卷积核,padding 为 0 ,striding 为 1。没有 padding 所以每次卷积之后 feature map 的 H 和 W 变小了,在 skip-connection 时要注意 feature map 的维度(其实也可以将 padding 设置为 1 避免维度不对应问题),pytorch 代码:

nn.Sequential(nn.Conv2d(in_channels, out_channels, 3),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True))

上述的两次卷积之后是一个 stride 为 2 的 max pooling,输出大小变为 1/2 *(H, W):
pytorch 代码:

nn.MaxPool2d(kernel_size=2, stride=2)

上面的步骤重复 5 次,最后一次没有 max-pooling,直接将得到的 feature map 送入 Decoder。

Decoder
feature map 经过 Decoder 恢复原始分辨率,该过程除了卷积比较关键的步骤就是 upsampling 与 skip-connection。
Upsampling 上采样常用的方式有两种:1. FCN中介绍的反卷积;2. 插值。这里介绍文中使用的插值方式。在插值实现方式中,bilinear 双线性插值的综合表现较好也较为常见 。

双线性插值
双线性插值的计算过程没有需要学习的参数,实际就是套公式,举个例子方便大家理解(例子介绍的是参数 align_corners 为 Fasle 的情况)。

import torch
import numpy as np
import torch.nn as nn

src = torch.Tensor(np.asarray([[[[10, 20], [30, 40]]]]))
print('src: ', src)
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
print('dst: ', up(src))
# 结果
src:  tensor([[[[10., 20.],
          [30., 40.]]]])
dst:  tensor([[[[10.0000, 12.5000, 17.5000, 20.0000],
          [15.0000, 17.5000, 22.5000, 25.0000],
          [25.0000, 27.5000, 32.5000, 35.0000],
          [30.0000, 32.5000, 37.5000, 40.0000]]]])

例子中是将一个 2x2 的矩阵通过插值的方式得到 4x4 的矩阵,那么将 2x2 的矩阵称为源矩阵,4x4 的矩阵称为目标矩阵。双线性插值中,目标点的值是由离他最近的 4 个点的值计算得到的,我们首先介绍如何找到目标点周围的 4 个点,以 P2 为例。

最新基于VIT的语义分割模型cvpr_双线性插值_02


第一个公式,目标矩阵到源矩阵的坐标映射:

最新基于VIT的语义分割模型cvpr_卷积_03


最新基于VIT的语义分割模型cvpr_深度学习_04


为了找到那 4 个点,首先要找到目标点在源矩阵中的相对位置,上面的公式就是用来算这个的。P2 在目标矩阵中的坐标是 (0, 1),对应到源矩阵中的坐标就是 (-0.25, 0.25)。坐标里面居然有小数跟负数,不急我们一个一个来处理。我们知道双线性插值是从坐标周围的 4 个点来计算该坐标的值,(-0.25, 0.25) 这个点周围的 4 个点是(-1, 0), (-1, 1), (0, 0), (0, 1)。为了找到负数坐标点,我们将源矩阵扩展为下面的形式,中间红色的部分为源矩阵。

最新基于VIT的语义分割模型cvpr_2d_05


我们规定 f(i, j) 表示 (i, j)坐标点处的像素值,对于计算出来的对应的坐标,我们统一写成 (i+u, j+v) 的形式。那么这时 i=-1, u=0.75, j=0, v=0.25。把这 4 个点单独画出来,可以看到目标点 P2 对应到源矩阵中的相对位置。

最新基于VIT的语义分割模型cvpr_双线性插值_06


第二个公式,也是最后一个。

最新基于VIT的语义分割模型cvpr_2d_07

目标点的像素值就是周围 4 个点像素值的加权和,明显可以看出离得近的权值比较大例如 (0, 0) 点的权值就是 0.75x0.75,离得远的如 (-1, 1) 权值就比较小,为 0.25x0.25,这也比较符合常理吧。把值带入计算就可以得到 P2 点的值了,结果是 12.5 与代码吻合上了.

CNN 网络要想获得好效果,skip-connection 基本必不可少。Unet 中这一关键步骤融合了底层信息的位置信息与深层特征的语义信息,pytorch 代码:

torch.cat([low_layer_features, deep_layer_features], dim=1)

这里需要注意的是,FCN 中深层信息与浅层信息融合是通过对应像素相加的方式,而 Unet 是通过拼接的方式

那么这两者有什么区别呢,其实 在 ResNet 与 DenseNet 中也有一样的区别Resnet 使用了对应值相加DenseNet 使用了拼接。个人理解在相加的方式下,feature map 的维度没有变化,但每个维度都包含了更多特征,对于普通的分类任务这种不需要从 feature map 复原到原始分辨率的任务来说,这是一个高效的选择;而拼接则保留了更多的维度/位置 信息,这使得后面的 layer 可以在浅层特征与深层特征自由选择,这对语义分割任务来说更有优势

参考代码
基础模块,DoubleConv, Down, Up, OutConv

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)


    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # self.conv.bias.data.fill_(1)

    def forward(self, x):
        return self.conv(x)

组合模块如下:

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)
        self.sigmoid = nn.Softmax(dim=1)
        #self.sigmoid = nn.Sigmoid()

    def forward(self, x):

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        sigmoid = self.sigmoid(logits)
        return logits, sigmoid