# GCN-Model (PyTorch version)
# Training and validation code logic
import torch.nn as nn
from torchvision import models
import torch
# Single module-level ResNet-152 whose stages GCN.__init__ reuses as its encoder.
# NOTE(review): every GCN instance references these same layer objects, so all
# instances share encoder weights — confirm this is intended.
# NOTE(review): despite the variable name, pretrained=False does not load
# pretrained weights. Newer torchvision releases deprecate `pretrained=` in
# favour of `weights=` — check against the installed version.
resnet152_pretrained = models.resnet152(pretrained=False)
class GCM(nn.Module):
    """Global Convolution Module: sum of two separable large-kernel paths.

    Path ``conv1`` applies a ``1 x k`` convolution followed by a ``k x 1``
    convolution; path ``conv2`` applies ``k x 1`` followed by ``1 x k``.
    Every convolution is bias-free and outputs ``num_class`` channels, and the
    two path outputs are added element-wise.

    Args:
        in_channels: number of channels of the input feature map.
        num_class: number of channels produced by each path and the module.
        k: large-kernel size. With an odd ``k`` the ``(k - 1) // 2`` padding
            keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, num_class, k=15):
        super(GCM, self).__init__()
        half = (k - 1) // 2
        row_kernel, row_pad = (1, k), (0, half)
        col_kernel, col_pad = (k, 1), (half, 0)

        def _conv(cin, kernel, padding):
            # Bias-free convolution projecting `cin` channels to num_class.
            return nn.Conv2d(cin, num_class, kernel_size=kernel, padding=padding, bias=False)

        # Path 1: horizontal (1 x k) then vertical (k x 1).
        self.conv1 = nn.Sequential(
            _conv(in_channels, row_kernel, row_pad),
            _conv(num_class, col_kernel, col_pad),
        )
        # Path 2: vertical (k x 1) then horizontal (1 x k).
        self.conv2 = nn.Sequential(
            _conv(in_channels, col_kernel, col_pad),
            _conv(num_class, row_kernel, row_pad),
        )

    def forward(self, x):
        left = self.conv1(x)
        right = self.conv2(x)
        # Both paths use symmetric padding, so their outputs must line up.
        assert left.shape == right.shape
        return left + right
class BR(nn.Module):
    """Boundary Refinement block computing ``x + conv(relu(conv(x)))``.

    Both convolutions are 3x3 with padding 1, bias-free, and map
    ``num_class`` channels to ``num_class`` channels. The residual branch
    therefore has the same shape as its input and can be added back to it.

    Args:
        num_class: number of channels of the input and output.
    """

    def __init__(self, num_class):
        super(BR, self).__init__()
        layers = [
            nn.Conv2d(num_class, num_class, 3, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(num_class, num_class, 3, padding=1, bias=False),
        ]
        self.shortcut = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.shortcut(x)
        return x + residual
class GCN_BR_BR_Deconv(nn.Module):
    """One decoder stage: GCM -> BR -> [skip sum -> BR] -> 2x transposed conv.

    Args:
        in_channels: channels of the encoder feature map fed to the GCM.
        num_class: channel count used throughout the stage.
        k: large-kernel size passed to the GCM.
    """

    def __init__(self, in_channels, num_class, k=15):
        super(GCN_BR_BR_Deconv, self).__init__()
        self.gcn = GCM(in_channels, num_class, k)
        self.br = BR(num_class)
        # kernel 4, stride 2, padding 1: output size is exactly 2x the input.
        self.deconv = nn.ConvTranspose2d(num_class, num_class, 4, 2, 1, bias=False)

    def forward(self, x1, x2=None):
        """Refine encoder features ``x1``, optionally fuse ``x2``, then upsample.

        Args:
            x1: encoder feature map with ``in_channels`` channels.
            x2: optional output of the previous (deeper) stage. It must match
                the refined ``x1`` in shape so the two can be summed.

        Returns:
            The fused map after the 2x transposed convolution.
        """
        feat = self.br(self.gcn(x1))
        if x2 is not None:
            # NOTE(review): this reuses self.br, so both refinements in this
            # stage share one set of weights — confirm this is intended.
            feat = self.br(feat + x2)
        return self.deconv(feat)
class GCN(nn.Module):
    """Global Convolutional Network segmentation model on a ResNet-152 encoder.

    The encoder reuses the stages of the module-level ``resnet152_pretrained``
    backbone. The decoder has one ``GCN_BR_BR_Deconv`` stage per encoder
    stage. It runs from the deepest stage (``layer4``) to the shallowest
    (``layer1``), and each stage fuses in the upsampled output of the previous
    one. A final BR -> 2x transposed conv -> BR head produces the output.

    Args:
        num_classes: number of output channels (one per class).
        k: large-kernel size used by every GCM in the decoder.
        verbose: keyword-only. If True, ``forward`` prints the size of every
            intermediate tensor (debugging aid). Defaults to False.
    """

    # Output channel counts of ResNet-152 stages layer1..layer4, i.e. the
    # channels each decoder stage receives from the encoder.
    _STAGE_CHANNELS = (256, 512, 1024, 2048)

    def __init__(self, num_classes, k=15, *, verbose=False):
        super(GCN, self).__init__()
        self.num_class = num_classes
        self.k = k
        self.verbose = verbose

        # Encoder stages taken from the shared module-level backbone.
        # NOTE(review): all GCN instances reference these same module objects
        # and therefore share encoder weights — confirm this is intended.
        self.layer0 = nn.Sequential(resnet152_pretrained.conv1, resnet152_pretrained.bn1, resnet152_pretrained.relu)
        self.layer1 = nn.Sequential(resnet152_pretrained.maxpool, resnet152_pretrained.layer1)
        self.layer2 = resnet152_pretrained.layer2
        self.layer3 = resnet152_pretrained.layer3
        self.layer4 = resnet152_pretrained.layer4

        # Output head. The single BR instance is applied twice in forward().
        self.br = BR(self.num_class)
        self.deconv = nn.ConvTranspose2d(self.num_class, self.num_class, 4, 2, 1, bias=False)

        # Decoder stages are built once here, not inside forward(). Assigned as
        # attributes they become registered submodules: the optimizer trains
        # their parameters, .to()/.cuda() moves them, state_dict saves them,
        # and they are not re-randomised on every call.
        c1, c2, c3, c4 = self._STAGE_CHANNELS
        self.branch4 = GCN_BR_BR_Deconv(c4, self.num_class, self.k)
        self.branch3 = GCN_BR_BR_Deconv(c3, self.num_class, self.k)
        self.branch2 = GCN_BR_BR_Deconv(c2, self.num_class, self.k)
        self.branch1 = GCN_BR_BR_Deconv(c1, self.num_class, self.k)

    def _log(self, name, tensor):
        """Print ``name`` and ``tensor.size()`` when verbose mode is enabled."""
        if self.verbose:
            print(name, tensor.size())

    def forward(self, input):
        """Run the encoder-decoder and return the segmentation map.

        Args:
            input: image batch. Presumably (N, 3, H, W) with H and W chosen so
                the decoder's skip sums line up (e.g. multiples of 32) —
                TODO confirm.

        Returns:
            A tensor with ``num_classes`` channels.
        """
        self._log('input:', input)
        x0 = self.layer0(input)
        self._log('x0:', x0)
        x1 = self.layer1(x0)
        self._log('x1:', x1)
        x2 = self.layer2(x1)
        self._log('x2:', x2)
        x3 = self.layer3(x2)
        self._log('x3:', x3)
        x4 = self.layer4(x3)
        self._log('x4:', x4)

        # Decoder: deepest stage first; each stage fuses the previous output.
        branch4 = self.branch4(x4)
        self._log('branch4:', branch4)
        branch3 = self.branch3(x3, branch4)
        self._log('branch3:', branch3)
        branch2 = self.branch2(x2, branch3)
        self._log('branch2:', branch2)
        branch1 = self.branch1(x1, branch2)
        self._log('branch1:', branch1)

        x = self.br(branch1)
        self._log('x:', x)
        x = self.deconv(x)
        self._log('x:', x)
        x = self.br(x)
        self._log('x:', x)
        return x
if __name__ == "__main__":
    # Smoke test: push one random RGB batch through the network and print the
    # output shape. Guarded so importing this module does not build the model
    # and run a full ResNet-152 forward pass.
    rgb = torch.randn(1, 3, 512, 512)  # random input data
    net = GCN(8)  # network with 8 output classes
    # Inference mode: BatchNorm uses running stats instead of updating them.
    net.eval()
    # Only the shape is inspected, so skip autograd bookkeeping (saves memory).
    with torch.no_grad():
        out = net(rgb)
    print('-----'*5)
    print(out.shape)
    print('-----'*5)