【PyTorch基础】将pytorch模型转换为script模型
原创
©著作权归作者所有:来自51CTO博客作者mb62c788fd198da的原创作品,请联系作者获取转载授权,否则将追究法律责任
操作步骤:
1. 将PyTorch模型转换为Torch脚本;
1)通过torch.jit.trace
转换为torch脚本;
2)通过torch.jit.script转换为torch脚本;
2. 将脚本模型序列化为文件;
要想序列化模型文件,只需在模块上调用save函数即可;
3. 在c++中加载脚本模块;
安装使用LibTorch;
使用torch::jit::load()
函数对该模块进行反序列化,得到一个torch::jit::script::Module
对象。
4. 在c++中执行脚本模块;
注意,生成序列化和调用反序列化模型的输入必须要保持一致;
code
# -*- coding: utf-8 -*-
# @Time : 2021.07.27 16:00
# @Author: xxx
# @Email :
# @File : torch2script.py
"""
Transform torch model to Script module.
"""
import torch
from unet import UNet
from config import UNetConfig
cfg = UNetConfig()
model_path = './checkpoints/epoch_500.pth'
# model
model = UNet(cfg)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
# an example input.
example = torch.rand(5, 3, 625, 620) # NCHW.
# Trace to Torch script.
# Use torch.jit.trace to generate a troch.jit.scriptmodule via tracing.
# 将 PyTorch 模型通过跟踪转换为 Torch 脚本,必须将模型的实例以及示例输入传递给torch.jit.trace函数。
# 这将产生一个torch.jit.ScriptModule对象,并将模型评估的轨迹嵌入到模块的forward方法中.
traced_script_module = torch.jit.trace(model, example)
output =output1= model(example)'./unet_trace_module.pt')
# print('output: ', output)
# print('output1: ', output1)
print('traced_script_module graph: \n', traced_script_module.graph)
print('traced_script_module code : \n', traced_script_module.code )
# ERROR!!!!!
# # Script module
# model_script = UNet(cfg)
# sm = torch.jit.script(model_script)
# output2 = sm(example)
#
# # Serialize model.
# sm.save('./unet_script_module.pt')
注意,执行脚本模型文件进行测试的输入大小必须和生成脚本模型的输入大小一致,否则执行的时候会出错;
error
/home/xxx/lib/python3.8/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/xxx/lib/python3.8/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/xxx/lib/python3.8/site-packages/torch/jit/__init__.py(1109): trace_module
/home/xxx/lib/python3.8/site-packages/torch/jit/__init__.py(953): trace
torch2script.py(25): <module>
RuntimeError: Sizes of tensors must match except in dimension 1. Got 78 and 79 in dimension 3 (The offending index is 1)
Aborted (core dumped)
5. CUDA相关函数
"torch::cuda::is_available():" << torch::cuda::is_available() << std::endl;
std::cout <<"torch::cuda::cudnn_is_available():" << torch::cuda::cudnn_is_available() << std::endl;
std::cout <<"torch::cuda::device_count():"
6. GPU/CPU模式
torch::DeviceType device_type = at::kCPU; // 定义设备类型
if (torch::cuda::is_available())
device_type = at::kCUDA;
model.to(device_type);
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({ 1, 3, 224, 224
device
torch::DeviceType device_type;
device_type = torch::kCUDA;
torch::Device device(device_type);
torch::jit::script::Module module = torch::jit::load(model_path, device);
7. 注意,需要对ScriptModule的结果
进行验证和评估,使其与常规 PyTorch 模块的推断结果相同;
注意,使用no_grad()进行验证评估;
8. 在c++中加载torchscript模型的时候,发现输入尺寸不必和torchscript模型的尺寸一致;
参考
1. 在 C++ 中加载 TorchScript 模型;
2. 基于C++的PyTorch模型部署;
3. torch.jit.trace;
4. torch.jit.script;
5. 使用C++调用并部署pytorch模型;
6. libtorch c++部署-使用GPU;
完