pytorch使用TensorBoard可视化网络模型结构
原创
©著作权归作者所有:来自51CTO博客作者Lineage_的原创作品,请联系作者获取转载授权,否则将追究法律责任
在训练神经网络时,我们希望能够直观地训练情况,例如损失函数的曲线、输入的图像、模型精度等信息,这些信息可以帮助我们更好地监督网络的训练过程,并为参数优化提供方向和依据。
本文提供一个更为专业的操作,它是一个常用的可视化工具:TensorBoard,下面将利用TensorBoard来实现可视化网络模型结构操作。
PyTorch已经内置了TensorBoard的相关接口,用户在安装后便可调用相关接口进行数据可视化。
![在这里插入图片描述 pytorch使用TensorBoard可视化网络模型结构_可视化工具](https://s2.51cto.com/images/blog/202301/16143002_63c4eeea3aa9872028.png?x-oss-process=image/watermark,size_16,text_QDUxQ1RP5Y2a5a6i,color_FFFFFF,t_30,g_se,x_10,y_10,shadow_20,type_ZmFuZ3poZW5naGVpdGk=/resize,m_fixed,w_1184)
data_transform = T.Compose([
T.ToTensor(),
T.RandomResizedCrop(32),
T.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST('./', train=True, transform=data_transform)
val_dataset = torchvision.datasets.MNIST('./', train=False, transform=data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, 16)
val_loader = torch.utils.data.DataLoader(val_dataset, 16)
images, labels = next(iter(train_loader))
logger = SummaryWriter(log_dir='./log')
logger.add_graph(model, images)
我们只需要传入模型,以及模型的输入信息即可,其实上述代码不一定要传入images,其实可以随意创建一个张量即可,它的意义就是要告诉模型的输入形状,我们可以使用下面代码代替:
logger = SummaryWriter(log_dir='./log')
logger.add_graph(model, torch.randn(32, 1, 32, 32))