FusionNet-Model (PyTorch 版本)
训练、验证代码逻辑
All.ipynb
import torch.nn as nn
import torch
def conv_block(in_dim, out_dim, act_fn, stride=1):
    """Build a 3x3 Conv2d -> BatchNorm2d -> activation stack.

    Args:
        in_dim: Input channel count for the convolution.
        out_dim: Output channel count (also the BatchNorm2d feature count).
        act_fn: Activation module appended as-is (the same instance is reused,
            not copied).
        stride: Convolution stride; padding is fixed at 1, so stride 1 keeps
            the spatial size and stride 2 roughly halves it.

    Returns:
        nn.Sequential whose children are, in order, the Conv2d, the
        BatchNorm2d and ``act_fn``.
    """
    layers = [
        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    ]
    return nn.Sequential(*layers)
def conv_trans_block(in_dim, out_dim, act_fn):
    """Build a 2x upsampling ConvTranspose2d -> BatchNorm2d -> activation stack.

    With kernel 3, stride 2, padding 1 and output_padding 1, the transposed
    convolution maps an H x W input to exactly 2H x 2W.

    Args:
        in_dim: Input channel count.
        out_dim: Output channel count (also the BatchNorm2d feature count).
        act_fn: Activation module appended as-is (shared, not copied).

    Returns:
        nn.Sequential of (ConvTranspose2d, BatchNorm2d, act_fn).
    """
    upsample = nn.ConvTranspose2d(
        in_dim,
        out_dim,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    norm = nn.BatchNorm2d(out_dim)
    return nn.Sequential(upsample, norm, act_fn)
def conv_block_3(in_dim, out_dim, act_fn):
    """Build two conv_blocks followed by a bare 3x3 Conv2d + BatchNorm2d.

    The trailing Conv2d/BatchNorm2d pair has no activation after it. Every
    convolution uses stride 1 and padding 1, so the spatial size is kept.
    Channels go in_dim -> out_dim in the first conv_block and stay at out_dim
    afterwards.

    Args:
        in_dim: Input channel count.
        out_dim: Channel count of every layer after the first convolution.
        act_fn: Activation module shared by both conv_blocks.

    Returns:
        nn.Sequential of (conv_block, conv_block, Conv2d, BatchNorm2d).
    """
    first = conv_block(in_dim, out_dim, act_fn)
    second = conv_block(out_dim, out_dim, act_fn)
    last_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
    last_norm = nn.BatchNorm2d(out_dim)
    return nn.Sequential(first, second, last_conv, last_norm)
class Conv_residual_conv(nn.Module):
    """Residual unit: conv_block -> (conv_block_3 + skip) -> conv_block.

    The first conv_block maps ``in_dim`` to ``out_dim`` channels. Its output is
    passed through conv_block_3 and added back to itself (identity skip). A
    final conv_block then processes the sum. All convolutions use stride 1 and
    padding 1, so the spatial size is unchanged.

    Args:
        in_dim: Input channel count.
        out_dim: Output channel count.
        act_fn: Activation module. The same instance is shared by every
            conv_block built here, including the two inside conv_block_3.
    """

    def __init__(self, in_dim, out_dim, act_fn):
        super(Conv_residual_conv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # Construction order is kept as-is: it fixes parameter initialisation
        # order under a given seed and the state_dict key layout.
        self.conv_1 = conv_block(self.in_dim, self.out_dim, act_fn)
        self.conv_2 = conv_block_3(self.out_dim, self.out_dim, act_fn)
        self.conv_3 = conv_block(self.out_dim, self.out_dim, act_fn)

    def forward(self, input):
        """Apply the residual unit; output has ``out_dim`` channels, same H/W."""
        conv_1 = self.conv_1(input)
        conv_2 = self.conv_2(conv_1)
        # Out-of-place add: conv_1 stays intact even if act_fn is in-place.
        res = conv_1 + conv_2
        conv_3 = self.conv_3(res)
        return conv_3
class Fusionnet(nn.Module):
    """FusionNet-style encoder/decoder with residual blocks and additive skips.

    Structure:
        - Four encoder stages. Each is a Conv_residual_conv followed by a
          stride-2 conv_block that halves even H and W. Encoder activations
          are ReLU.
        - A residual bridge.
        - Four decoder stages. Each is a stride-2 transposed conv that doubles
          H and W, averaged element-wise with the matching encoder feature
          map, then a Conv_residual_conv. Decoder activations are in-place ELU.
        - A final 3x3 Conv2d that maps to ``output_nc`` channels at the input
          resolution.

    Each skip adds two tensors element-wise, so input H and W must both be
    divisible by 16. Otherwise an encoder/decoder pair disagrees in size and
    the addition raises a shape-mismatch error.

    Args:
        input_nc: Number of input channels.
        output_nc: Number of output channels.
        ngf: Base channel width. Encoder stages use ngf, 2*ngf, 4*ngf and
            8*ngf channels, and the bridge uses 16*ngf.
        out_clamp: Optional ``(min, max)`` pair. When given, the final output
            is clamped to that range with ``torch.clamp``. ``None`` (the
            default) leaves the output unclamped.
        verbose: When True (the default, which keeps the original behavior),
            print the size of every intermediate tensor on each forward pass.
            Pass ``False`` in training/validation loops.
    """

    def __init__(self, input_nc, output_nc, ngf, out_clamp=None, *, verbose=True):
        super(Fusionnet, self).__init__()
        self.out_clamp = out_clamp
        self.verbose = verbose
        self.in_dim = input_nc
        self.out_dim = ngf
        self.final_out_dim = output_nc
        act_fn = nn.ReLU()
        act_fn_2 = nn.ELU(inplace=True)
        # NOTE: submodule creation order below is unchanged on purpose; it
        # determines initialisation order under a seed and the state_dict keys.
        # encoder
        self.down_1 = Conv_residual_conv(self.in_dim, self.out_dim, act_fn)
        self.pool_1 = conv_block(self.out_dim, self.out_dim, act_fn, 2)
        self.down_2 = Conv_residual_conv(self.out_dim, self.out_dim * 2, act_fn)
        self.pool_2 = conv_block(self.out_dim * 2, self.out_dim * 2, act_fn, 2)
        self.down_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 4, act_fn)
        self.pool_3 = conv_block(self.out_dim * 4, self.out_dim * 4, act_fn, 2)
        self.down_4 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 8, act_fn)
        self.pool_4 = conv_block(self.out_dim * 8, self.out_dim * 8, act_fn, 2)
        # bridge
        self.bridge = Conv_residual_conv(self.out_dim * 8, self.out_dim * 16, act_fn)
        # decoder
        self.deconv_1 = conv_trans_block(self.out_dim * 16, self.out_dim * 8, act_fn_2)
        self.up_1 = Conv_residual_conv(self.out_dim * 8, self.out_dim * 8, act_fn_2)
        self.deconv_2 = conv_trans_block(self.out_dim * 8, self.out_dim * 4, act_fn_2)
        self.up_2 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 4, act_fn_2)
        self.deconv_3 = conv_trans_block(self.out_dim * 4, self.out_dim * 2, act_fn_2)
        self.up_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 2, act_fn_2)
        self.deconv_4 = conv_trans_block(self.out_dim * 2, self.out_dim, act_fn_2)
        self.up_4 = Conv_residual_conv(self.out_dim, self.out_dim, act_fn_2)
        # output
        self.out = nn.Conv2d(self.out_dim, self.final_out_dim, kernel_size=3, stride=1, padding=1)

    def _trace(self, name, tensor):
        """Print ``name: size`` when verbose, then return ``tensor`` unchanged."""
        # getattr default: whole-model pickles saved from the original class
        # have no ``verbose`` attribute; fall back to the original printing.
        if getattr(self, 'verbose', True):
            print(name + ':', tensor.size())
        return tensor

    def forward(self, input):
        """Run the network and return the (optionally clamped) prediction.

        Args:
            input: 4-D tensor (N, input_nc, H, W), as required by BatchNorm2d,
                with H and W divisible by 16.

        Returns:
            Tensor of shape (N, output_nc, H, W). It is clamped to
            ``out_clamp`` when that is set.
        """
        trace = self._trace
        trace('input', input)

        # Encoder: keep each residual output for the matching decoder skip.
        down_1 = trace('down_1', self.down_1(input))
        pool_1 = trace('pool_1', self.pool_1(down_1))
        down_2 = trace('down_2', self.down_2(pool_1))
        pool_2 = trace('pool_2', self.pool_2(down_2))
        down_3 = trace('down_3', self.down_3(pool_2))
        pool_3 = trace('pool_3', self.pool_3(down_3))
        down_4 = trace('down_4', self.down_4(pool_3))
        pool_4 = trace('pool_4', self.pool_4(down_4))

        bridge = trace('bridge', self.bridge(pool_4))

        # Decoder: upsample 2x, average with the encoder map, then refine.
        deconv_1 = trace('deconv_1', self.deconv_1(bridge))
        skip_1 = trace('skip_1', (deconv_1 + down_4) / 2)
        up_1 = trace('up_1', self.up_1(skip_1))
        deconv_2 = trace('deconv_2', self.deconv_2(up_1))
        skip_2 = trace('skip_2', (deconv_2 + down_3) / 2)
        up_2 = trace('up_2', self.up_2(skip_2))
        deconv_3 = trace('deconv_3', self.deconv_3(up_2))
        skip_3 = trace('skip_3', (deconv_3 + down_2) / 2)
        up_3 = trace('up_3', self.up_3(skip_3))
        deconv_4 = trace('deconv_4', self.deconv_4(up_3))
        skip_4 = trace('skip_4', (deconv_4 + down_1) / 2)
        up_4 = trace('up_4', self.up_4(skip_4))

        out = trace('out', self.out(up_4))
        # out_clamp was previously stored but never applied; honour it here.
        if self.out_clamp is not None:
            out = torch.clamp(out, min=self.out_clamp[0], max=self.out_clamp[1])
        return out
# Shape smoke test. It is guarded so that importing this file (e.g. to reuse
# Fusionnet) does not run a full forward pass; in a notebook, __name__ is
# "__main__", so it still runs there.
if __name__ == "__main__":
    # Randomly generate input data: 1 image, 3 channels, 352x480
    # (both sides divisible by 16, as the skip connections require).
    rgb = torch.randn(1, 3, 352, 480)
    # Define the network: 3 input channels, 12 output channels, base width 64.
    net = Fusionnet(3, 12, 64)
    # Forward pass. Only the output shape is inspected, so skip building the
    # autograd graph.
    with torch.no_grad():
        out = net(rgb)
    # Print the output size.
    print('-----'*5)
    print(out.shape)
    print('-----'*5)