神经网络可视化——基于torchviz绘制模型的计算图

  • 第一步、安装 graphviz 和 torchviz 库
  • 第二步、编写代码生成计算图
  • 第三步、安装graphviz软件



  在深入理解深度学习模型时,可视化网络结构是一个非常有用的手段。今天介绍如何使用 torchviz 和 graphviz 来生成网络计算图。这个方法特别适合那些希望深入探究网络内部细节的读者。需要注意的是,生成的网络结构图是基于反向传播过程生成的,因此展示的是一个倒序的网络结构。

第一步、安装 graphviz 和 torchviz 库

  首先,我们需要安装两个Python库:torchviz 和 graphviz。这两个库互相配合,能够帮助我们生成直观的网络结构图。安装这些库的方法非常简单,只需运行以下代码即可。

pip install graphviz torchviz

第二步、编写代码生成计算图

  接下来,我们需要编写一段代码来生成计算图。这段代码会利用前面安装的库来创建和展示网络结构。

import torch
import torch.nn as nn
from torchviz import make_dot

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleModel()
input_tensor = torch.randn(1, 10)  # 假设模型输入是一个 10 维的向量


output = model(input_tensor)
graph = make_dot(output, params=dict(model.named_parameters()))
graph.render('model_graph', format='png')  # 保存为 PNG 图像

  运行上述代码后,可能会遇到一个报错:

神经网络可视化——基于torchviz绘制模型的计算图_torchviz

  报错提示是关于 graphviz 中的一些可执行命令未被添加到系统路径中。意思是虽然我们安装了相关的Python库,但并没有安装 graphviz 软件本身。解决这个问题的办法是下载并安装 graphviz 软件,并确保它的可执行文件被添加到系统环境变量中。可以从以下链接下载 graphviz 软件:https://www2.graphviz.org/Packages/stable/windows/10/cmake/Release/x64/

第三步、安装graphviz软件

  下载并运行安装程序,安装过程中,请确保选择添加 graphviz 到系统环境变量的选项。这一步是非常重要的,因为它允许你的代码正确地调用 graphviz 的功能。

神经网络可视化——基于torchviz绘制模型的计算图_深度学习_02


  安装的时候,记得添加系统环境变量:

神经网络可视化——基于torchviz绘制模型的计算图_计算图_03

  安装完成后,如果代码仍然无法运行,尝试重启电脑。这样做通常可以解决环境变量更新后的相关问题。安装并配置好所有必需的组件后,你应该能够成功生成并查看计算图。最终生成的计算图示例如下:

神经网络可视化——基于torchviz绘制模型的计算图_深度学习_04


  计算图是一个有向图,主要描述网络中的操作和变量之间的关系。

  在图的顶部,可以看到网络层的权重 fc1.weight 和偏置 fc1.bias,它们代表网络中第一个全连接层(通常用fc表示)的参数。权重矩阵的维度是 (5, 10),而偏置向量的维度是 (5)。这意味着该层将接受一个10维的输入并输出一个5维的向量。

  中间的节点表示反向传播过程中的不同操作,例如 TBackwardAddmmBackward,这些都是PyTorch中的自动梯度计算操作,用于在训练期间更新网络权重。

  fc2.weightfc2.bias 表示第二个全连接层的参数,其权重矩阵维度为 (2, 5),偏置向量维度为 (2)。这表明第二个层将5维输入转换为2维输出。

  最底部的 (1, 2) 表示的是网络的最终输出,它是一个2维的向量。