安装graphviz和torchviz

pip install graphviz
pip install git+https://github.com/szagoruyko/pytorchviz

Windows用户请注意,需要到graphviz官网下载windows版本并安装,然后添加系统环境变量(右键开始菜单-> 系统 -> 右侧高级系统设置 -> 环境变量 -> 系统变量 -> 双击Path -> 新建):

pytorch可视化使用graphviz+torchviz查看计算图_环境变量

然后重启pycharm!在Terminal(或cmd)中输入dot -version显示版本信息证明安装成功。

示例

import torch
from torchvision.models import AlexNet
from torchviz import make_dot

# 以AlexNet为例,前向传播
x = torch.rand(8, 3, 256, 512)
model = AlexNet()
y = model(x)

# 构造图对象,3种方式
g = make_dot(y)
# g = make_dot(y, params=dict(model.named_parameters()))
# g = make_dot(y, params=dict(list(model.named_parameters()) + [('x', x)]))

# 保存图像
# g.view()  # 生成 Digraph.gv.pdf,并自动打开
g.render(filename='graph', view=False)  # 保存为 graph.pdf,参数view表示是否打开pdf

PDF效果:

pytorch可视化使用graphviz+torchviz查看计算图_深度学习_02