Introduction
When learning PyTorch, it is much easier to understand a network when a picture of the model structure is paired with per-layer dimension information. In Keras, model.summary and plot_model are very handy for exactly this. For PyTorch, after a fair amount of searching, the following three approaches helped me understand my models, so I am noting them here. Note that summary requires the model's input shape, which you can work out from the source code or from the hints in error messages.
import torch
from torchviz import make_dot
from torch.autograd import Variable
from torchsummary import summary

# `model` is assumed to be an already-built GAN-style model exposing netG/netD submodules
print(model.netG)                   # print the module structure directly
summary(model.netG, (3, 256, 256))  # per-layer output shapes
xtmp = Variable(torch.randn(1, 3, 256, 256))
ytmp = model.netG(xtmp)
make_dot(ytmp, params=dict(model.netG.named_parameters())).render('tmp', view=True)  # render saves the graph to a file

print(model.netD)
summary(model.netD, (6, 256, 256))
xtmp2 = Variable(torch.randn(1, 6, 256, 256))
ytmp2 = model.netD(xtmp2)
make_dot(ytmp2, params=dict(model.netD.named_parameters())).render('tmp2', view=True)
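For reference, the same workflow on a self-contained model: Variable has been a no-op wrapper since PyTorch 0.4, so plain tensors work, and resnet18 below is just an arbitrary stand-in (a minimal sketch, assuming torchvision, torchsummary and torchviz are installed).

import torch
import torchvision.models as models
from torchsummary import summary
from torchviz import make_dot

net = models.resnet18()
summary(net, (3, 224, 224), device="cpu")  # per-layer output shapes, run on CPU
x = torch.randn(1, 3, 224, 224)            # no Variable wrapper needed
y = net(x)
make_dot(y, params=dict(net.named_parameters())).render('resnet18', view=True)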
A similar use of make_dot, which seems a bit more concise, from 言有三:
import torch
from torch.autograd import Variable
from visualize import make_dot  # the author's local visualize.py providing make_dot

x = Variable(torch.randn(1, 3, 48, 48))
model = simpleconv3()           # simpleconv3 is a small custom CNN defined elsewhere
y = model(x)
g = make_dot(y)
g.view()                        # open the rendered graph
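Since make_dot returns a standard graphviz.Digraph, the graph can also be saved to disk instead of only opened:

g.render('simpleconv3', view=False)  # writes the DOT source plus simpleconv3.pdf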
Update 2020-04-21:
Two more visualization options: HiddenLayer and tensorwatch.
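A minimal HiddenLayer sketch, assuming the hiddenlayer package is installed and again using resnet18 as a stand-in model (check the project README for the exact API):

import torch
import torchvision.models as models
import hiddenlayer as hl

net = models.resnet18()
# trace the model with a dummy input and build a graph of its ops
graph = hl.build_graph(net, torch.zeros([1, 3, 224, 224]))
graph.save("resnet18_hl", format="png")  # writes resnet18_hl.png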
Update 2020-12-03:
https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/22
https://pypi.org/project/pytorch-model-summary/
# count only the trainable parameters of a model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
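A quick sanity check on a layer whose parameter count is easy to verify by hand:

import torch.nn as nn
print(count_parameters(nn.Linear(10, 2)))  # 10*2 weights + 2 biases = 22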
# per-parameter breakdown of trainable parameters, printed as a table
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params += param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
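This version drops in the same way; each named parameter tensor gets its own row (resnet18 again as an arbitrary example):

import torchvision.models as models
count_parameters(models.resnet18())  # one row per weight/bias tensor, then the total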
The pytorch-model-summary package linked above offers a richer summary; its signature:

summary(model, *inputs, batch_size=-1, show_input=False, show_hierarchical=False,
        print_summary=False, max_depth=1, show_parent_layers=False)
# pytorch-model-summary example
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_model_summary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# show input shape
# print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True))

# show output shape
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False))

# show output shape and hierarchical view of net
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False, show_hierarchical=True))
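The remaining flags from the signature above work the same way; for instance, the print_summary parameter suggests the call can print the table itself (a sketch based only on the parameter list above, so treat the exact behavior as an assumption):

# print the summary directly instead of returning it for an outer print()
summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True, print_summary=True)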