【深度学习基础】PyTorch实现ResNeXt亲身实践

  • 1 论文关键信息
  • 1.1 ResNeXt基本block 结构
  • 1.1.1 原始结构
  • 1.1.2 优化实现方式
  • 1.2 ResNeXt网络结构
  • 2 ResNeXt的pytorch实现
  • 2.0 BN_CONV_RELU
  • 2.1 Block结构
  • 2.2 ResNeXt网络整体结构
  • 2.3 搭建网络并测试


1 论文关键信息

这篇论文是Kaiming He团队收Inception结构的启发,基于ResNet结构提出的一种网络。
论文地址Aggregated Residual Transformations for Deep Neural Networks

主要思想有以下几点:
1、总结了Inception系列的基本结构,并将其形象地称为split-transform-merge三个阶段:通过1x1的卷积将输入降维;通过一定制化的3x3或5x5的卷积进行变换;然后将结果进行拼接。

2、论文肯定了inception结构是有效的,但是深度定制对于不同数据集的训练来说,修改网络结构太麻烦了。(看过我上一篇博客的同学们,应该也注意到,Inception的结构都是经过精心设计的,当需要调整输入大小的时候,不好修改)。所以,ResNeXt采用split-transform-merge的思想,但是沿用VGG/ResNet的构造重复的卷积层的策略,使网络在具有性能的基础上更加优雅简洁。

3、论文基于split-transform-merge的结构,提出了“cardinality”的概念,ResNeXt通过它控制transform部分的网络宽度+深度,下一节的内容会详细解释cardinality的含义和作用。

4、论文通过实验发现,增加cardinality对精度的提升作用比增加网络的宽度和深度来得更加有效,尤其是当网络的深度和宽度即将达到梯度消失的界限时。

1.1 ResNeXt基本block 结构

1.1.1 原始结构

下图是ResNeXt的一个基本block,左图其基本结构,来自于ResNet的BottleNeck(有兴趣的同学可以看我之前的博客实现)。受Inception启发论文将Residual部分分成若干个支路,这个支路的数量就是cardinality的含义。 右图是ResNeXt的一个32x4d的基本结构,32指的是cardinality是32,即利用1x1卷积降维,并分成32条支路; 4d指的是每个支路中transform的3x3卷积的滤波器数量为4。

resin rest resin restoration_resin rest

1.1.2 优化实现方式

下图给出了ResNeXt基本block的实现的优化推演过程,论文中提出这三种结构是(近似)等价的。(a)是最原始的block结构。(b)与(a)的区别是,其先将3x3的transform结果进行merge,然后再用1x1的卷积调整维度,这个过程也减少了一些参数。(c)与(b)是原理等价的,只不过其采用分组卷积(group convolution)的形式,在实现中更利于GPU运算。所以,我在搭建ResNeXt基本结构的形式是,采用(c)这种形式去复现。

resin rest resin restoration_ide_02

1.2 ResNeXt网络结构

下图是ResNet-50和ResNeXt-50(32x4d)的对比,可以发现二者网络整体结构一致,ResNeXt替换了基本的block。文章中提到的几种结构有ResNeXt29_8X64d, ResNeXt29_16x64d, ResNeXt50_32x3d, ResNeXt101_32x4d, ResNeXt101_64x4d。我们需要知道,比如8x64d指的是——cardinality为8,3x3的transform滤波器数量为64。

resin rest resin restoration_resin rest_03

2 ResNeXt的pytorch实现

2.0 BN_CONV_RELU

再次给出BN_Conv2d的代码:

class BN_Conv2d(nn.Module):
    """
    BN_CONV_RELU
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bias=False):
        super(BN_Conv2d, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=groups, bias=bias),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return F.relu(self.seq(x))

2.1 Block结构

在实现block的时候,需要传入的控制参数有输入维度,cardinality,transform滤波器的数量

class ResNeXt_Block(nn.Module):
    """
    ResNeXt block with group convolutions
    """

    def __init__(self, in_chnls, cardinality, group_depth, stride):
        super(ResNeXt_Block, self).__init__()
        self.group_chnls = cardinality * group_depth
        self.conv1 = BN_Conv2d(in_chnls, self.group_chnls, 1, stride=1, padding=0)
        self.conv2 = BN_Conv2d(self.group_chnls, self.group_chnls, 3, stride=stride, padding=1, groups=cardinality)
        self.conv3 = nn.Conv2d(self.group_chnls, self.group_chnls*2, 1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(self.group_chnls*2)
        self.short_cut = nn.Sequential(
            nn.Conv2d(in_chnls, self.group_chnls*2, 1, stride, 0, bias=False),
            nn.BatchNorm2d(self.group_chnls*2)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.bn(self.conv3(out))
        out += self.short_cut(x)
        return F.relu(out)

2.2 ResNeXt网络整体结构

class ResNeXt(nn.Module):
    """
    ResNeXt builder
    """

    def __init__(self, layers: object, cardinality, group_depth, num_classes) -> object:
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.channels = 64
        self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)
        d1 = group_depth
        self.conv2 = self.___make_layers(d1, layers[0], stride=1)
        d2 = d1 * 2
        self.conv3 = self.___make_layers(d2, layers[1], stride=2)
        d3 = d2 * 2
        self.conv4 = self.___make_layers(d3, layers[2], stride=2)
        d4 = d3 * 2
        self.conv5 = self.___make_layers(d4, layers[3], stride=2)
        self.fc = nn.Linear(self.channels, num_classes)   # 224x224 input size

    def ___make_layers(self, d, blocks, stride):
        strides = [stride] + [1] * (blocks-1)
        layers = []
        for stride in strides:
            layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride))
            self.channels = self.cardinality*d*2
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.max_pool2d(out, 3, 2, 1)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = F.avg_pool2d(out, 7)
        out = out.view(out.size(0), -1)
        out = F.softmax(self.fc(out))
        return out

2.3 搭建网络并测试

def resNeXt50_32x4d(num_classes=1000):
    return ResNeXt([3, 4, 6, 3], 32, 4, num_classes)


def resNeXt101_32x4d(num_classes=1000):
    return ResNeXt([3, 4, 23, 3], 32, 4, num_classes)


def resNeXt101_64x4d(num_classes=1000):
    return ResNeXt([3, 4, 23, 3], 64, 4, num_classes)


def test():
    # net = resNeXt50_32x4d()
    net = resNeXt101_32x4d()
    # net = resNeXt101_64x4d()
    summary(net, (3, 224, 224))


test()

ResNeXt101_64x4d的测试输出结果:

resin rest resin restoration_resin rest_04