文章目录
- 前言
- 1、网络结构
- 2、代码解读
- resnet50
- 总结
前言
整理下特征提取网络resnet的网络结构
1、网络结构
有5个输出层C1,C2,C3,C4,C5,其中常用的是C2,C3,C4,C5层。没有单独的层进行下采样,直接在残差的时候进行下采样。
2、代码解读
resnet50
整个resnet50的forward代码如下(示例):
def forward(self, x):
"""Forward function."""
if self.deep_stem: # teem层
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name) #获取相应名字的layer层:layer0,layer1...
x = res_layer(x) # 进行操作
if i in self.out_indices: # 输出索引,指定输出的层数,用于后续的FPN操作。
outs.append(x)
return tuple(outs)
- stem层:用三个3X3卷积(步长为2,padding=1)代替一个7X7卷积(步长为2,padding=3),保持输出特征分辨率不变。
- conv1层:和stem层一样,图片3维输入,经过一个conv1(7X7卷积后)256维输出,再经过BN和Relu(一般不用stem而是选择用一个7x7卷积),用7x7卷积在大感受野的情况下,保持输出特征分辨率不变。
- maxpool层:一个3X3的最大池化,步长为2。
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- 残差块:主干由三个卷积组成。一个1X1卷积(步长stride为1或2)+BN+Relu+一个3X3卷积+BN+Relu+一个1X1卷积+BN+(残差连接:一个1X1卷积(步长stride为1或2)+Relu。首先通过1X1卷积压缩通道(4倍),然后3X3卷积,1x1卷积还原到原来的维度,最后还有个1X1卷积从输入直接连接到输出,实现残差相加。如果卷积过程中进行下采样(第一个步长!=1)或者通道数要发生变化(输入!=输出)时,残差的1x1卷积的步长变成能够与输出分辨率匹配的大小(如stride=2)。
- Bottleneck:resnet50有4个block,[3,4,6,3]。由网络结构图可知一个block,是由几个残差块堆叠而成的。并且进行下采样时,pytorch格式的都是在第一个残差块的第一个3X3卷积进行(stride=2,padding=1),有些是在第一个残差块的1X1卷积上进行下采样(stride=2,padding=1)。
对应结构图:
- conv2层:3个残差块组成,输入通道数由64变成64*4=256,图像分辨率在池化下采样后没变。
- conv3层:4个残差块组成,输入通道数由256变成128,最后变成12*84=512,图像分辨率减小一半,说明conv3第一个残差块的3x3卷积步长为2,残差连接的步长也应为2。
- conv4层:6个残差块组成,输入通道数由512变成256,最后变成256*4=1024,图像分辨率减小一半,说明conv4第一个残差块的3x3卷积步长为2,残差连接的步长也应为2。
- conv5层:3个残差块组成,输入通道数由1024变成512,最后变成512*4=2048,图像分辨率减小一半,说明conv4第一个残差块的3x3卷积步长为2,残差连接的步长也应为2。
conv2,3,4,5层的前向操作
def forward(self, x):
"""Forward function."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv1_plugin_names)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv2_plugin_names)
out = self.conv3(out)
out = self.norm3(out)
if self.with_plugins:
out = self.forward_plugin(out, self.after_conv3_plugin_names)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
总结
Resnet的残差模块,使得神经网络能够有效的减轻梯度因为网络层数的逐渐加深而导致的梯度消失的问题。是一个十分经典的特征提取网络模块,后面还有基于resnet的res2net,resnest和resnext的改进。