如果搜索 PyTorch JIT,找到的将会是「TorchScript」的文档,那么什么是 JIT 呢?JIT 和 TorchScript 又有什么联系?
文章只会关注概念的部分,如果关注细节或实现部分,文章最后有一个完整的 Demo 可供参考。
什么是 JIT?
首先要知道 JIT 是一种概念,全称是 Just In Time Compilation,中文译为「即时编译」,是一种程序优化的方法,一种常见的使用场景是「正则表达式」。例如,在 Python 中使用正则表达式:
prog = re.compile(pattern)
result = prog.match(string)
或
result = re.match(pattern, string)
上面两个例子是直接从 Python 官方文档中摘出来的 ,并且从文档中可知,两种写法从结果上来说是「等价」的。但注意第一种写法种,会先对正则表达式进行 compile,然后再进行使用。如果继续阅读 Python 的文档,可以找到下面这段话:
using re.compile() and saving the resulting regular expression object for reuse is more efficient when the expression will be used several times in a single program.
也就是说,如果多次使用到某一个正则表达式,则建议先对其进行 compile,然后再通过 compile 之后得到的对象来做正则匹配。而这个 compile 的过程,就可以理解为 JIT(即时编译)。
在深度学习中 JIT 的思想更是随处可见,最明显的例子就是 Keras 框架的 model.compile,TensorFlow 中的 Graph 也是一种 JIT,虽然他没有显示调用编译方法。
那 PyTorch 呢?PyTorch 从面世以来一直以「易用性」著称,最贴合原生 Python 的开发方式,这得益于 PyTorch 的「动态图」结构。我们可以在 PyTorch 的模型前向中加任何 Python 的流程控制语句,甚至是下断点单步跟进都不会有任何问题,但是如果是 TensorFlow,则需要使用 tf.cond 等 TensorFlow 自己开发的流程控制,谁更简单一目了然。那么为什么 PyTorch 还需要引入 JIT 呢?
TorchScript
动态图模型通过牺牲一些高级特性来换取易用性,那到底 JIT 有哪些特性,在什么情况下不得不用到 JIT 呢?下面主要通过介绍 TorchScript(PyTorch 的 JIT 实现)来分析 JIT 到底带来了哪些好处。
- 模型部署
PyTorch 的 1.0 版本发布的最核心的两个新特性就是 JIT 和 C++ API,这两个特性一起发布不是没有道理的,JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便得调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等…
2. 性能提升
既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型(torch.nn.Module)转换为 TorchScript Module,再进行推断。
3. 模型可视化
TensorFlow 或 Keras 对模型可视化工具(TensorBoard等)非常友好,因为本身就是静态图的编程模型,在模型定义好后整个模型的结构和正向逻辑就已经清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可视化上一直表现得不好,但 JIT 改善了这一情况。现在可以使用 JIT 的 trace 功能来得到 PyTorch 模型针对某一输入的正向逻辑,通过正向逻辑可以得到模型大致的结构,但如果在 `forward` 方法中有很多条件控制语句,这依然不是一个好的方法,所以 PyTorch JIT 还提供了 Scripting 的方式,这两种方式在下文中将详细介绍。
TorchScript Module 的两种生成方式
1. 编码(Scripting)
可以直接使用 TorchScript Language 来定义一个 PyTorch JIT Module,然后用 torch.jit.script 来将他转换成 TorchScript Module 并保存成文件。而 TorchScript Language 本身也是 Python 代码,所以可以直接写在 Python 文件中。
使用 TorchScript Language 就如同使用 TensorFlow 一样,需要前定义好完整的图。对于 TensorFlow 我们知道不能直接使用 Python 中的 if 等语句来做条件控制,而是需要用 tf.cond,但对于 TorchScript 我们依然能够直接使用 if 和 for 等条件控制语句,所以即使是在静态图上,PyTorch 依然秉承了「易用」的特性。TorchScript Language 是静态类型的 Python 子集,静态类型也是用了 Python 3 的 typing 模块来实现,所以写 TorchScript Language 的体验也跟 Python 一模一样,只是某些 Python 特性无法使用(因为是子集),可以通过 TorchScript Language Reference 来查看和原生 Python 的异同。
理论上,使用 Scripting 的方式定义的 TorchScript Module 对模型可视化工具非常友好,因为已经提前定义了整个图结构。
2. 追踪(Tracing)
使用 TorchScript Module 的更简单的办法是使用 Tracing,Tracing 可以直接将 PyTorch 模型(torch.nn.Module)转换成 TorchScript Module。「追踪」顾名思义,就是需要提供一个「输入」来让模型 forward 一遍,以通过该输入的流转路径,获得图的结构。这种方式对于 forward 逻辑简单的模型来说非常实用,但如果 forward 里面本身夹杂了很多流程控制语句,则可能会有问题,因为同一个输入不可能遍历到所有的逻辑分枝。
此外,还可以混合使用上面两种方式。
一个完整的例子
我简单写了一个简单的 MNIST demo,从使用 Python 训练到用 JIT 将 Python 模型转换为 TorchScript Module,然后用 C++ 加载 TorchScript Module 做推断的完整的过程:
https://github.com/louis-she/torchscript-mnist
https://github.com/louis-she/torchscript-mnist