GCN-Model(pytorch版本)


训练、验证代码逻辑




All.ipynb


import torch.nn as nn
from torchvision import models
import torch
# Module-level ResNet-152 backbone whose layers are reused by the GCN class below.
# Despite the variable name, it is built with pretrained=False, i.e. no pretrained
# weights are requested here.
# NOTE(review): newer torchvision releases appear to deprecate `pretrained=` in
# favour of `weights=` — confirm against the installed torchvision version.
resnet152_pretrained = models.resnet152(pretrained=False)
class GCM(nn.Module):
    """Sum of two separable large-kernel convolution branches.

    ``conv1`` applies a ``1 x k`` convolution followed by a ``k x 1`` one;
    ``conv2`` applies the same two kernels in the opposite order. Both
    branches map ``in_channels`` to ``num_class`` channels and their outputs
    are added element-wise.

    Args:
        in_channels: channel count of the incoming feature map.
        num_class: channel count produced by every convolution in the module.
        k: length of the 1-D kernels; the convolved axis is padded by
            ``(k - 1) // 2`` on each side.
    """

    def __init__(self, in_channels, num_class, k=15):
        super(GCM, self).__init__()

        half = (k - 1) // 2
        # Shared keyword sets for the horizontal (1 x k) and vertical (k x 1) kernels.
        horizontal = dict(kernel_size=(1, k), padding=(0, half), bias=False)
        vertical = dict(kernel_size=(k, 1), padding=(half, 0), bias=False)

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, num_class, **horizontal),
            nn.Conv2d(num_class, num_class, **vertical),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels, num_class, **vertical),
            nn.Conv2d(num_class, num_class, **horizontal),
        )

    def forward(self, x):
        """Run both branches on ``x`` and return their element-wise sum."""
        row_first = self.conv1(x)
        col_first = self.conv2(x)

        # Internal invariant: both branches must agree on shape before summing.
        assert row_first.shape == col_first.shape

        return row_first + col_first


class BR(nn.Module):
    """Residual refinement block: ``x + conv3x3(relu(conv3x3(x)))``.

    Both convolutions keep ``num_class`` channels, use padding 1 and have no
    bias, so the output has the same shape as the input.

    Args:
        num_class: channel count of the input and output feature maps.
    """

    def __init__(self, num_class):
        super(BR, self).__init__()

        conv_in = nn.Conv2d(num_class, num_class, 3, padding=1, bias=False)
        conv_out = nn.Conv2d(num_class, num_class, 3, padding=1, bias=False)
        # Held under the name `shortcut`, although it is the residual branch
        # that gets added back onto the input.
        self.shortcut = nn.Sequential(conv_in, nn.ReLU(), conv_out)

    def forward(self, x):
        """Return ``x`` plus the conv-ReLU-conv branch evaluated on ``x``."""
        residual = self.shortcut(x)
        return x + residual


class GCN_BR_BR_Deconv(nn.Module):
    """Decoder stage: GCM -> BR -> (optional fuse + BR) -> transposed conv.

    Args:
        in_channels: channel count of the feature map fed to the GCM.
        num_class: channel count used by every layer after the GCM.
        k: kernel length forwarded to the GCM.
    """

    def __init__(self, in_channels, num_class, k=15):
        super(GCN_BR_BR_Deconv, self).__init__()

        self.gcn = GCM(in_channels, num_class, k)
        self.br = BR(num_class)

        # kernel 4, stride 2, padding 1, no bias
        self.deconv = nn.ConvTranspose2d(
            num_class, num_class, kernel_size=4, stride=2, padding=1, bias=False
        )

    def forward(self, x1, x2=None):
        """Refine ``x1``, optionally fuse it with ``x2``, then upsample.

        Args:
            x1: feature map passed through the GCM and BR.
            x2: optional tensor added to the refined ``x1``; when given, the
                sum goes through the same BR instance a second time.

        Returns:
            The output of the transposed convolution.
        """
        refined = self.br(self.gcn(x1))

        if x2 is not None:
            # Same `self.br` module (shared weights) as the first refinement.
            refined = self.br(refined + x2)

        return self.deconv(refined)
class GCN(nn.Module):
    """Segmentation network: ResNet-152 stages feeding four decoder heads.

    The encoder stages are taken from the module-level ``resnet152_pretrained``
    instance. Each of the four deepest feature maps goes through a
    ``GCN_BR_BR_Deconv`` head; heads are chained from deepest to shallowest,
    each receiving the previous head's output as its second input. The result
    is refined by ``br``, upsampled by ``deconv`` and refined by ``br`` again.

    The heads are created once in ``__init__`` so that they are registered
    submodules: their weights persist between calls, appear in
    ``parameters()`` / ``state_dict()`` and follow ``.to(device)``.

    Args:
        num_classes: number of output channels.
        k: kernel length forwarded to every ``GCN_BR_BR_Deconv`` head.
        verbose: when true (the default), print the size of every
            intermediate tensor during ``forward``.
    """

    def __init__(self, num_classes, k=15, verbose=True):
        super(GCN, self).__init__()
        self.num_class = num_classes
        self.k = k
        self.verbose = verbose

        backbone = resnet152_pretrained
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Decoder heads are built here rather than inside forward(), so their
        # weights are created once instead of being re-drawn on every call.
        self.branch4 = GCN_BR_BR_Deconv(
            self._stage_channels(backbone.layer4), self.num_class, self.k)
        self.branch3 = GCN_BR_BR_Deconv(
            self._stage_channels(backbone.layer3), self.num_class, self.k)
        self.branch2 = GCN_BR_BR_Deconv(
            self._stage_channels(backbone.layer2), self.num_class, self.k)
        self.branch1 = GCN_BR_BR_Deconv(
            self._stage_channels(backbone.layer1), self.num_class, self.k)

        self.br = BR(self.num_class)
        self.deconv = nn.ConvTranspose2d(self.num_class, self.num_class, 4, 2, 1, bias=False)

    @staticmethod
    def _stage_channels(stage):
        """Return the channel count emitted by the last block of ``stage``."""
        last_block = stage[-1]
        # Use bn3 when the block has one, otherwise fall back to bn2.
        final_norm = last_block.bn3 if hasattr(last_block, 'bn3') else last_block.bn2
        return final_norm.num_features

    def _trace(self, label, tensor):
        """Print ``label`` and the tensor's size when verbose output is on."""
        if self.verbose:
            print(label, tensor.size())

    def forward(self, input):
        """Compute the ``num_classes``-channel output map for ``input``.

        Args:
            input: image batch fed to the first backbone convolution.
                # NOTE(review): the skip additions presumably need a height and
                # width divisible by 32 to line up — confirm for other sizes.

        Returns:
            Tensor with ``num_classes`` channels.
        """
        self._trace('input:', input)
        x0 = self.layer0(input)
        self._trace('x0:', x0)
        x1 = self.layer1(x0)
        self._trace('x1:', x1)
        x2 = self.layer2(x1)
        self._trace('x2:', x2)
        x3 = self.layer3(x2)
        self._trace('x3:', x3)
        x4 = self.layer4(x3)
        self._trace('x4:', x4)

        # Decode from the deepest stage upwards; each head fuses the previous
        # head's output with its own stage's features.
        branch4 = self.branch4(x4)
        self._trace('branch4:', branch4)
        branch3 = self.branch3(x3, branch4)
        self._trace('branch3:', branch3)
        branch2 = self.branch2(x2, branch3)
        self._trace('branch2:', branch2)
        branch1 = self.branch1(x1, branch2)
        self._trace('branch1:', branch1)

        # The same `self.br` instance is applied before and after the final
        # upsampling, as in the original code.
        x = self.br(branch1)
        self._trace('x:', x)
        x = self.deconv(x)
        self._trace('x:', x)
        x = self.br(x)
        self._trace('x:', x)

        return x
# Draw a random input batch: one 3-channel 512x512 image.
rgb = torch.randn(1, 3, 512, 512)
# Build the network with 8 output classes.
net = GCN(num_classes=8)
# Run one forward pass.
out = net(rgb)
# Print the output size between two separator lines.
print('-----' * 5)
print(out.size())
print('-----' * 5)

GCN-Model(pytorch版本)_2d