from torch import nn
import torch
class UNet(nn.Module):
    """U-Net segmentation network built from unpadded ("valid") 3x3 convolutions.

    Every 3x3 convolution has no padding and shrinks each spatial dimension
    by 2. The network output is therefore smaller than its input: a 572x572
    input gives a 388x388 output, and the smallest usable square input is
    188x188, which gives 4x4.

    For the same reason, each encoder feature map is larger than the
    upsampled decoder map it is merged with. The encoder map is center-cropped
    to the decoder's height/width before the channel-wise concatenation.

    Args:
        in_channels: number of channels in the input image.
        num_classes: number of output channels (per-pixel class scores);
            2 for binary segmentation.
        base_channels: width of the first encoder stage. Each deeper stage
            doubles it. The default of 64 gives the 64/128/256/512/1024
            layout.
    """

    def __init__(self, in_channels=1, num_classes=2, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        # == Encoder ==
        # Each stage extracts features, then a 2x2 max-pool halves the
        # resolution.
        self.conv1 = self._double_conv(in_channels, c1, c1)
        self.subpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = self._double_conv(c1, c2, c2)
        self.subpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = self._double_conv(c2, c3, c3)
        self.subpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = self._double_conv(c3, c4, c4)
        self.subpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck: widens to c5, then narrows back to c4. Bilinear
        # upsampling keeps the channel count, so concatenating with conv4
        # (c4 channels) yields c5 channels for conv6.
        self.conv5 = self._double_conv(c4, c5, c4)

        # == Decoder ==
        # Each conv's input width is (skip channels + upsampled channels).
        self.uppool1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv6 = self._double_conv(c5, c4, c3)  # c4 + c4 = c5 in
        self.uppool2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv7 = self._double_conv(c4, c3, c2)  # c3 + c3 = c4 in
        self.uppool3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv8 = self._double_conv(c3, c2, c1)  # c2 + c2 = c3 in
        self.uppool4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # The final 1x1 conv emits raw class logits. No BatchNorm/ReLU follows
        # it: a ReLU would clamp the logits to >= 0, which breaks
        # softmax / cross-entropy style losses.
        self.conv9 = nn.Sequential(
            *self._double_conv(c2, c1, c1),  # c1 + c1 = c2 in
            nn.Conv2d(c1, num_classes, 1),
        )

    @staticmethod
    def _double_conv(in_ch, mid_ch, out_ch):
        """Build two unpadded (3x3 conv -> BatchNorm -> ReLU) stages."""
        return nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, 3),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, 3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _center_crop(skip, target):
        """Crop the last two dims of ``skip`` around its center to match ``target``."""
        h, w = target.shape[-2:]
        top = (skip.shape[-2] - h) // 2
        left = (skip.shape[-1] - w) // 2
        return skip[..., top:top + h, left:left + w]

    def _merge(self, skip, up):
        """Concatenate a center-cropped encoder map with an upsampled decoder map."""
        return torch.cat([self._center_crop(skip, up), up], dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: 4-D tensor ``(N, in_channels, H, W)``. BatchNorm2d requires
                batched input.

        Returns:
            Tensor ``(N, num_classes, H', W')`` of per-pixel logits, where
            ``H' < H`` and ``W' < W`` because the convolutions are unpadded.
        """
        # === Encoder
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.subpool1(conv1))
        conv3 = self.conv3(self.subpool2(conv2))
        conv4 = self.conv4(self.subpool3(conv3))
        # Bottom of the U: both the encoder output and the decoder input.
        conv5 = self.conv5(self.subpool4(conv4))
        # === Decoder
        conv6 = self.conv6(self._merge(conv4, self.uppool1(conv5)))
        conv7 = self.conv7(self._merge(conv3, self.uppool2(conv6)))
        conv8 = self.conv8(self._merge(conv2, self.uppool3(conv7)))
        conv9 = self.conv9(self._merge(conv1, self.uppool4(conv8)))
        return conv9
if __name__ == '__main__':
    # Smoke test at the canonical 572x572 U-Net resolution. A batch of 64 at
    # this size needs several GB for a single 64-channel activation, so keep
    # the batch small and skip autograd bookkeeping.
    x = torch.rand(2, 1, 572, 572)
    print(x.shape)
    model = UNet(in_channels=x.shape[1])
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # unpadded convs shrink the output: expect (2, 2, 388, 388)