一.下载libtorch
到官网pytorch官网下载libtorch,选择适合自己版本pytorch官网.下方有相应的下载链接,一个debug版本,一个release版本
目前我是使用的pytorch也是1.4版本,之前网上查阅资料时,有人说下载的libtorch版本要跟pytorch版本一致,不一致的情况我也没试.CUDA我选了None,因为公司业务原因,客户处是不会用到GPU的。
本人环境:
系统 :ubuntu16.4 && win10
pytorch : 1.4.0
torchvision : 0.5.0
libtorch 1.4
c++ IDE: vs2017
opencv : 4.2.0
官方教程libtorch安装测试,不过下载下来后,文件夹build目录下已经有编译好的文件,直接使用.
验证libtorch库时,编译程序总是报错,"std 无法识别的符号"之类的,在github上面逛了很久,找到建议,所有报错std的地方,在std前面加上作用域符号::,所有修改完毕后再次编译后正常.也有另外一个办法,如下图,设置一个选项就可以了。
pytorch保存模型
官方教程:pytorch模型导出 这里说明一下,基本都要使用torch.jit.script方式导出模型,支持控制语句操作.
本来我想尝试在windows c++导入torchvision自带的maskrcnn模型的,奈何一直无法成功.决定先导出个目标检测模型.前段时间学习了CenterNet模型,公司业务可能用到,那就选你了!
CenterNet模型网上资料百度都很多,我就不细说了.我们要做的是看CenterNet源码,把模型单独提取出来,能尽量用简单的语句就用简单的语句(源代码可能涉及太多控制语句或其他内容,我试了,libtorch一直无法导入模型,后来自己提取模型,简化了语句才可以).然后训练好自己的数据.torch.jit.script方式保存好模型。
windows下libtorch导入
终于到了本文重点!
导入模型,代码如下
cout << "load model..." << endl;
torch::jit::script::Module module;
module = torch::jit::load("centernet_script.pt");
module.eval();
能直接运行就是没问题了,之前导入torchvision的maskrcnn模型时,一运行程序就崩溃。。
构建模型输入,torch::from_blob函数可以把opencv的Mat结构转换为torch的tensor。我的图形预处理,转换输入tensor代码如下:
// 图像转换为Tensor
torch::Tensor tensor_image = torch::from_blob(TransImg.data, { TransImg.rows, TransImg.cols,3 }, torch::kByte);
//输入预处理
tensor_image = tensor_image.toType(torch::kFloat);
tensor_image = tensor_image.div(255);
torch::Tensor mean = torch::tensor({ 0.31003028, 0.31786674, 0.31668854 });//均值
torch::Tensor dstd = torch::tensor({ 0.32514998, 0.31744885, 0.3204121 });
tensor_image = (tensor_image - mean) / dstd;
tensor_image = tensor_image.permute({ 2,0,1 });
tensor_image = tensor_image.unsqueeze(0);
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(tensor_image);
可以看到,其实语法都是跟python里面类似的。
运行forward函数,得到CenterNet最终输出。
auto outputs = module.forward(inputs).toTuple();
torch::Tensor heat = outputs->elements()[0].toTensor();
torch::Tensor wh = outputs->elements()[1].toTensor();
torch::Tensor reg = outputs->elements()[2].toTensor();
数据处理部分,我没有在python里面做(怕涉及复杂操作导致模型不能用libtorch导入),后面就是进行最终结果的处理啦。涉及公司业务,更多程序细节就不展示了。
附上几张效果图,折腾了很久,最终实现了还是很有成就感的。
目前对libtorch的数据结构太陌生,都是网上参考各路大神的做法,感觉后续要多了解libtorch数据结构,要不然不会处理提取最终的结果。或者后续在python里面进行处理,这就要保证模型能正常被libtorch导入。
感谢各路大神提供的帮助,参考链接:
libtorch导入pytorch模型:导入pytorch模型 libtorch数据结构介绍:数据结构转换