安装graphviz和torchviz
pip install graphviz
pip install git+https://github.com/szagoruyko/pytorchviz
Windows用户请注意,需要到graphviz官网下载windows版本并安装,然后添加系统环境变量(右键开始菜单-> 系统 -> 右侧高级系统设置 -> 环境变量 -> 系统变量 -> 双击Path -> 新建):
然后重启pycharm!在Terminal(或cmd)中输入dot -version
显示版本信息证明安装成功。
示例
import torch
from torchvision.models import AlexNet
from torchviz import make_dot
# 以AlexNet为例,前向传播
x = torch.rand(8, 3, 256, 512)
model = AlexNet()
y = model(x)
# 构造图对象,3种方式
g = make_dot(y)
# g = make_dot(y, params=dict(model.named_parameters()))
# g = make_dot(y, params=dict(list(model.named_parameters()) + [('x', x)]))
# 保存图像
# g.view() # 生成 Digraph.gv.pdf,并自动打开
g.render(filename='graph', view=False) # 保存为 graph.pdf,参数view表示是否打开pdf
PDF效果: