医学图像分割主要有两种框架,一个是基于​​CNN​​​的,另一个就是基于​​FCN​​的。

基于CNN 的框架

这个想法也很简单,就是对图像的每一个像素点进行分类,在每一个像素点上取一个​​patch​​,当做一幅图像,输入神经网络进行训练,举个例子:

U-net网络_2d

这是一篇发表在​​NIPS​​上的论文Ciresan D, Giusti A, Gambardella L M, et al. Deep neural networks segment neuronal membranes in electron microscopy images[C]//Advances in neural information processing systems. 2012: 2843-2851.

这是一个二分类问题,把图像中所有​​label​​​为0的点作为负样本,所有​​label​​为1的点作为正样本。

这种网络显然有两个缺点:


  • 冗余太大,由于每个像素点都需要取一个​​patch​​​,那么相邻的两个像素点的​​patch​​相似度是非常高的,这就导致了非常多的冗余,导致网络训练很慢。
  • 感受野和定位精度不可兼得,当感受野选取比较大的时候,后面对应的​​pooling​​层的降维倍数就会增大,这样就会导致定位精度降低,但是如果感受野比较小,那么分类精度就会降低。

基于FCN框架

大名鼎鼎的FCN就不多做介绍了,这里有一篇很好的博文。

不过还是建议把论文读一下,这样才能加深理解。

U-net网络

在医学图像处理领域,有一个应用很广泛的网络结构----​​U-net​​ ,网络结构如下:

U-net网络_ide_02

网络结构如图所示, 蓝色代表卷积和激活函数, 灰色代表复制和裁剪, 红色代表下采样, 绿色代表上采样然后在卷积, ​​conv 1X1​​​代表核为​​1X1​​的卷积操作, 可以看出这个网络没有全连接,只有卷积和下采样. 这也是一个端到端的图像, 即输入是一幅图像, 输出也是一副图像。

可以看出来,就是一个全卷积神经网络,输入和输出都是图像,没有全连接层。较浅的高分辨率层用来解决像素定位的问题,较深的层用来解决像素分类的问题。

在卷积过程中,可以通过设置不同的卷积核以及​​padding​​,​​stride​​的大小,控制同一层图像的尺寸不变,同样下采样的尺寸也是可以控制的。

具体的公式如下:

n e w s i z e = f − k + 2 p s + 1 newsize = \frac{f-k+2p}{s}+1 newsize=sf−k+2p​+1

其中​​k​​是卷积核尺寸,​​p​​是​​padding​​的值,​​s​​是​​stride​​的值

根据上述公式,令​​k=3​​, ​​p=1​​, ​​s=1​​,有

n e w s i z e = f − 3 + 2 1 + 1 = f newsize = \frac{f-3+2}{1}+1=f newsize=1f−3+2​+1=f

可以使得图像在卷积过程中的尺寸不发生变化

这是一个经常使用的​​trick​

下采样,常用的​​maxpooling​​,通常是使得图像尺寸减小​​1/2​​。

U-net网络_2d_03

这会儿再看网络结构,原论文中​​k=3​​,​​p=0​​,​​s=1​​,所以原始的尺寸是572

根据公式有 572 − 3 + 2 ∗ 0 1 + 1 = 570 \frac{572-3+2*0}{1}+1=570 1572−3+2∗0​+1=570 。所以从​​channel1-channel64​​,图像尺寸从572减小到了570。

以此类推,蓝色小箭头表示卷积操作,每次图像宽和高减小2,红色小箭头表示​​maxpooling​​,图像尺寸减半,

如图中第一个红箭头处,从568减小到了284。

U-net网络_卷积_04

绿色小箭头代表上采样,与红色相对应,这里使图像尺寸增加2倍,如下图中1024到512。

U-net网络_卷积_05

网络的创新之处就是​​concatenate​​这个操作了,也就是图中的灰色箭头这个位置。

U-net网络_2d_06

可以看到,从1024个​​channel​​上采样过来的512个​​channel​​,和上一个对应有512个​​channel​​的同层​​channel​​被叠加起来了,成为了一个新的1024个​​channel​​,如图红圈部分。同理类似256+256构成了新的512。

与​​FCN​​不同的是,​​FCN​​在这一步中是直接与之前的进行对应像素相加操作,都是很巧妙的操作。

到此,整个​​Unet​​的结构细节全部解析完毕

跟着推导一遍尺寸,会受益匪浅,入门其他的网络也会相对容易。给出​​pytorch​​实现的​​Unet​​模型,用到了我所说的同层尺寸不变的​​trick​​。

import torch.nn as nn


class conv_block(nn.Module):
def __init__(self, ch_in, ch_out):
super(conv_block, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True),
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)

def forward(self, x):
x = self.conv(x)
return x

class up_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(up_conv, self).__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)

def forward(self, x):
x = self.up(x)
return x


class U_Net(nn.Module):
def __init__(self, img_ch=3, output_ch=1):
super(U_Net, self).__init__()

self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
self.Conv2 = conv_block(ch_in=64, ch_out=128)
self.Conv3 = conv_block(ch_in=128, ch_out=256)
self.Conv4 = conv_block(ch_in=256, ch_out=512)
self.Conv5 = conv_block(ch_in=512, ch_out=1024)

self.Up5 = up_conv(ch_in=1024, ch_out=512)
self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

self.Up4 = up_conv(ch_in=512, ch_out=256)
self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

self.Up3 = up_conv(ch_in=256, ch_out=128)
self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

self.Up2 = up_conv(ch_in=128, ch_out=64)
self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

def forward(self, x):
# encoding path
x1 = self.Conv1(x)

x2 = self.Maxpool(x1)
x2 = self.Conv2(x2)

x3 = self.Maxpool(x2)
x3 = self.Conv3(x3)

x4 = self.Maxpool(x3)
x4 = self.Conv4(x4)

x5 = self.Maxpool(x4)
x5 = self.Conv5(x5)

# decoding + concat path
d5 = self.Up5(x5)
d5 = torch.cat((x4, d5), dim=1)

d5 = self.Up_conv5(d5)

d4 = self.Up4(d5)
d4 = torch.cat((x3, d4), dim=1)
d4 = self.Up_conv4(d4)

d3 = self.Up3(d4)
d3 = torch.cat((x2, d3), dim=1)
d3 = self.Up_conv3(d3)

d2 = self.Up2(d3)
d2 = torch.cat((x1, d2), dim=1)
d2 = self.Up_conv2(d2)

d1 = self.Conv_1x1(d2)

return d1

​​