目录

一、论文

二、模型介绍

三、模型预训练


一、论文

先来看看VGG这篇论文《Very Deep Convolutional Networks for Large-Scale Image Recognition》论文下载地址

论文中几个模型主要以几下几种方案A、B、C、D、E。目前主要还是采用VGG16和VGG19也就是下图中的分别红框和绿框部分。 

 

pytorch模型训练pt pytorchvgg16模型训练_Image

二、模型介绍

其实通过上面的表格就已经大致知道模型的框架组成部分了。其实VGG16与VGG19的区别就是前者在三、四、五部分少了一层卷积。这里先附基于pytorch的一些预训练模型预训练模型下载地址

pytorch模型训练pt pytorchvgg16模型训练_加载_02

上图可以看出VGG分有无BatchNormalization。这里先介绍一下VGG16_bn的一些内部层结构。

VGG16_bn

序号

层结构

层数

权重

0

conv1-1

1

64x3x3

1

batchnorm

 

2

relu1-1

 

3

conv1-2

2

64x3x3

4

batchnorm

 

5

relu1-2

 

6

pool1

 

7

conv2-1

3

128x3x3

8

batchnorm

 

9

relu2-1

 

10

conv2-2

4

128x3x3

11

batchnorm

 

12

relu2-2

 

13

pool2

 

14

conv3-1

5

256x3x3

15

batchnorm

 

16

relu3-1

 

17

conv3-2

6

256x3x3

18

batchnorm

 

19

relu3-2

 

20

conv3-3

7

256x3x3

21

batchnorm

 

22

relu3-3

 

23

pool3

 

24

conv4-1

8

512x3x3

25

batchnorm

 

26

relu4-1

 

27

conv4-2

9

512x3x3

28

batchnorm

 

29

relu4-2

 

30

conv4-3

10

512x3x3

31

batchnorm

 

32

relu4-3

 

33

pool4

 

34

conv5-1

11

512x3x3

35

batchnorm

 

36

relu5-1

 

37

conv5-2

12

512x3x3

38

batchnorm

 

39

relu5-2

 

40

conv5-3

13

512x3x3

41

batchnorm

 

42

relu5-3

512x3x3

43

pool5

 

44

fc6(4096)

14

 

45

relu6

 

 

46

fc7(4096)

15

 

47

relu7

 

 

48

fc8(1000)

16

 

49

prob(softmax)

 

 

层是指卷积层和全连接层)VGG16则仅仅去掉红色部分的Batch_normalization部分。这里可以看到VGG16_bn的modules共有44个(这里不算全连接层),如果是VGG16则有31个(不算全连接层)。

下图是通过导入VGG16_bn模型在调试过程中的结果,可见与上面是一致。

pytorch模型训练pt pytorchvgg16模型训练_全连接_03

pytorch模型训练pt pytorchvgg16模型训练_加载_04

三、模型预训练

3.1加载整个模型

基于pytorch模型预训练,首先都要导入加载模型。有两种方式,下面一一介绍。

1.采用在线下载,这种一般受网络原因比较慢,不建议。

2.自己先下好预训练模型,从本地加载,这里介绍一下加载预训练模型后后自己提供一张图片进行分类识别。

import torch
import numpy
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import torchvision.models as models

vgg = models.vgg16_bn()
pre=torch.load('./vgg16_bn-6c64b313.pth')
vgg.load_state_dict(pre)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],#这是imagenet數據集的均值
                                 std=[0.229, 0.224, 0.225])

tran=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
])


im='./1.jpg'
im=Image.open(im)
im=tran(im)
im.unsqueeze_(dim=0)
print(im.shape)
# input()
out=vgg(im)
outnp=out.data[0]
ind=int(numpy.argmax(outnp))
print(ind)

from cls import d
print(d[ind])


print(out.shape)


# im.show()

3 主要有几个注意的地方。由于是加载VGG模型的,并提供自己一张图像进行预测,输入就必须符合VGG的格式。

  • VGG模型的图像读入方式采用PIL库所以就得使用PIL库进行读入图片
  • 输入图像的尺寸得必须和VGG保持一致224x224的三通道。(因为全连接层用的是VGG的)
  • 上面采用的normalization归一化方式 几个固定的参数是因为VGG数据的分布,其均值和方差
  • VGG的最终分类的类别是1000类,最终out=vgg(img)是一个1000元素的张量。

 4 查看加載的參數

pre = torch.load('./pretrain/vgg16_bn-6c64b313.pth')
    for key, v in pre.items():
        print(key, v.size())

加載得到的是VGG網絡參數,可以將其輸出查看,這裏只顯示其size

features.0.weight torch.Size([64, 3, 3, 3])
features.0.bias torch.Size([64])
features.1.weight torch.Size([64])
features.1.bias torch.Size([64])
features.1.running_mean torch.Size([64])
features.1.running_var torch.Size([64])
features.3.weight torch.Size([64, 64, 3, 3])
features.3.bias torch.Size([64])
features.4.weight torch.Size([64])
features.4.bias torch.Size([64])
features.4.running_mean torch.Size([64])
features.4.running_var torch.Size([64])
features.7.weight torch.Size([128, 64, 3, 3])
features.7.bias torch.Size([128])
features.8.weight torch.Size([128])
features.8.bias torch.Size([128])
features.8.running_mean torch.Size([128])
features.8.running_var torch.Size([128])
features.10.weight torch.Size([128, 128, 3, 3])
features.10.bias torch.Size([128])
features.11.weight torch.Size([128])
features.11.bias torch.Size([128])
features.11.running_mean torch.Size([128])
features.11.running_var torch.Size([128])
features.14.weight torch.Size([256, 128, 3, 3])
features.14.bias torch.Size([256])
features.15.weight torch.Size([256])
features.15.bias torch.Size([256])
features.15.running_mean torch.Size([256])
features.15.running_var torch.Size([256])
features.17.weight torch.Size([256, 256, 3, 3])
features.17.bias torch.Size([256])
features.18.weight torch.Size([256])
features.18.bias torch.Size([256])
features.18.running_mean torch.Size([256])
features.18.running_var torch.Size([256])
features.20.weight torch.Size([256, 256, 3, 3])
features.20.bias torch.Size([256])
features.21.weight torch.Size([256])
features.21.bias torch.Size([256])
features.21.running_mean torch.Size([256])
features.21.running_var torch.Size([256])
features.24.weight torch.Size([512, 256, 3, 3])
features.24.bias torch.Size([512])
features.25.weight torch.Size([512])
features.25.bias torch.Size([512])
features.25.running_mean torch.Size([512])
features.25.running_var torch.Size([512])
features.27.weight torch.Size([512, 512, 3, 3])
features.27.bias torch.Size([512])
features.28.weight torch.Size([512])
features.28.bias torch.Size([512])
features.28.running_mean torch.Size([512])
features.28.running_var torch.Size([512])
features.30.weight torch.Size([512, 512, 3, 3])
features.30.bias torch.Size([512])
features.31.weight torch.Size([512])
features.31.bias torch.Size([512])
features.31.running_mean torch.Size([512])
features.31.running_var torch.Size([512])
features.34.weight torch.Size([512, 512, 3, 3])
features.34.bias torch.Size([512])
features.35.weight torch.Size([512])
features.35.bias torch.Size([512])
features.35.running_mean torch.Size([512])
features.35.running_var torch.Size([512])
features.37.weight torch.Size([512, 512, 3, 3])
features.37.bias torch.Size([512])
features.38.weight torch.Size([512])
features.38.bias torch.Size([512])
features.38.running_mean torch.Size([512])
features.38.running_var torch.Size([512])
features.40.weight torch.Size([512, 512, 3, 3])
features.40.bias torch.Size([512])
features.41.weight torch.Size([512])
features.41.bias torch.Size([512])
features.41.running_mean torch.Size([512])
features.41.running_var torch.Size([512])
classifier.0.weight torch.Size([4096, 25088])
classifier.0.bias torch.Size([4096])
classifier.3.weight torch.Size([4096, 4096])
classifier.3.bias torch.Size([4096])
classifier.6.weight torch.Size([1000, 4096])
classifier.6.bias torch.Size([1000])

上面的feature最多是42個,不是44個,因爲relu和pool沒有顯示出來,其分別是features.42  feature.43.因爲加載的參數pre裏面包含的內容是參數,而relu操作和池化操作是不需要參數的,也就是模型保存時並沒有保存下來。

3.2加载部分模型

class VGG(nn.Module):
    def __init__(self, weights=False):
        super(VGG, self).__init__()
        if weights is False:
            model = models.vgg19_bn(pretrained=True)

        model = models.vgg19_bn(pretrained=False)
        pre = torch.load(weights)
        model.load_state_dict(pre)
        self.vgg19 = model.features
        
        for param in self.vgg19.parameters():
            param.requires_grad = False

初始化有个参数权重,当为false时,默认网上下载VGG模型,通常网上下载的比较慢不建议,所以直接本地下载好之后再load即可。这里选择了vgg的features部分,全连接部分没有选择,当然也可以索引或者切片选择任何层的 features。