VGG网络结构的代码搭建以及代码的详细解析(基于PyTorch)
import torch.nn as nn
import torch
from torchvision import transforms
import cv2
import math
class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()#python3中等效于super().__init__()
self.features = features#前面的卷积结构各层函数,在下面的函数forward()被调用
self.classifier = nn.Sequential(#从全连接到分类的结构层函数的一个顺序容器
nn.Linear(512 * 7 * 7, 4096),#第一个全连接层,输入为[batch_size,512*7*7],输出大小为[batch_size,4096],bias默认为True
nn.ReLU(True),#ReLU激活函数
nn.Dropout(),#Dropout函数
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
#最后一层不需要添加激活函数
nn.Linear(4096, num_classes),#最后一层全连接,输出为num_size,类别的个数
)#通过Squential将网络层和激活函数和Dropout函数结合起来,输出激活后的网络节点。
if init_weights:
self._initialize_weights()
#备注:nn.Linear(in_features,out_feartures,bias=True)用于设置网络中的全连接层的,需要注意的是全连接层的输入与输出都是二维张量,一般形状为[batch_size, size]
#参数:in_features,out_feartures,bias
#in_features指的是输入的二维张量的大小,即输入的[batch_size, size]中的size。
#out_features指的是输出的二维张量的大小,即输出的二维张量的形状为[batch_size,output_size],当然,它也代表了该全连接层的神经元个数。
def forward(self, x):#x为输入的图片张量
x = self.features(x)#卷积层顺序容器,输入x,输出经过所有卷积层后的特征层
print(x,'-->',type(x),'-->',x.shape,'-->',x.dtype)#x=[batch_size,C,H,W]
x = x.view(x.size(0), -1)#把x维度进行调整,保持batch数一致,[batch_size,一个展开后的特征层]
x = self.classifier(x)#全连接层的顺序容器,最后输出num_classes,用于最后的类别判断
return x
def _initialize_weights(self):#权重初始化
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
def make_layers(cfg, batch_norm=False):#卷积层的实现函数
layers = []
in_channels = 3 #初始通道数为3,RGB
for v in cfg:
if v == 'M':#判断是否为池化层
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]#经过最大池化,W,H变为原来的1/2,
else:#卷积层+/BN+ReLU
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)#卷积层(输入通道数,输出通道数,卷积核的大小,边缘补数)
if batch_norm:#批处理规范化
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:#不进行批处理
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v#更改通道数
return nn.Sequential(*layers)#返回卷积部分的顺序容器
cfg = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}#网络的卷积层池化层结构类型
def vgg11(**kwargs):
model = VGG(make_layers(cfg['A']), **kwargs)#实例化类VGG,此处执行类的构造函数__init__
#这里把VGG结构的第一个结构vgg11的卷积层部分,make_layers(cfg['A'])作为feature传入VGG类中
return model #返回卷积层部分的顺序容器和全连接层部分的顺序容器
def vgg11_bn(**kwargs):#是在vgg11上添加batch_noorm
model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
return model
def vgg13(**kwargs):
model = VGG(make_layers(cfg['B']), **kwargs)
return model
def vgg13_bn(**kwargs):
model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
return model
def vgg16(**kwargs):
model = VGG(make_layers(cfg['D']), **kwargs)
return model
def vgg16_bn(**kwargs):
model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
return model
def vgg19(**kwargs):
model = VGG(make_layers(cfg['E']), **kwargs)
return model
def vgg19_bn(**kwargs):
model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
return model
img = cv2.imread("img/1.jpg") # 读取输入网络的图片
print(img,'-->',type(img),'-->',img.shape,'-->',img.dtype)
trans = transforms.Compose(
[
transforms.ToTensor(),#Convert a PIL Image or numpy.ndarray to tensor.
#transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
img = trans(img)#把读入的图片数据转化为torch.tensor类型的数据,[C,H,W]
print(img,'-->',type(img),'-->',img.shape)
img = img.unsqueeze(0) # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]
print(img,'-->',type(img),'-->',img.shape,'-->',img.dtype)
if __name__ == '__main__':#执行本文件时,if条件满足,执行下面的语句
net19 = vgg19()#调用vgg19,vgg19()实例化类VGG,执行类的构造方法,返回卷积层和全连接层的顺序容器
print(net19)
x=net19.forward(img)#调用类的函数forward(),把img送入函数,执行真个网络,得到1000类的输出
print(x,'-->',type(x),'-->',x.shape,'-->',x.dtype)