PyTorch:TorchScript 概要
- Intro
- Basic Usage
- 1. torch.jit.trace
- API Definition
- Example (tracing a function)
- Example (tracing an nn.Module)
- 2. torch.jit.script
- API Definition
- Example (scripting a function)
- Example (scripting an nn.Module)
- 3. Mixing Scripting and Tracing
- 4. torch.jit.save & torch.jit.load
- 个人思考
Intro
首先要提的一点概念就是 JIT(Just In Time)——这是一种即时编译器机制,其能把某些频繁被调用的代码编译成机器码,并进行各种层次的优化,以提高执行效率。
而TorchScript 是 PyTorch 用来创建可序列化和可优化模型的即时编译器。任何 TorchScript 程序都可以从 Python 进程中保存,然乎再导出到生产环境,加载到没有 Python 依赖项的进程中。
TorchScript 的强大,在于它的编译机制:
- 其作用于一个模型,既能记录模型的参数,又能保存计算图(Intermediate Representation Graph, IR Graph)。
- 其作用于一个函数,既能保存函数的有关参数,又能保存运行逻辑(本质上也还是计算图)。
Basic Usage
在介绍具体用法之前,先引入几个概念。
- Scripted 类:由TorchScript 作用、编译而生成。其可分为 ScriptModule 和 ScriptFunction 两种:
ScriptModule:由 nn.Module 编译而来。
ScriptFunction:由 一般的函数编译而来。
- 控制流(Control Flow):PyTorch 官方将 if 语句、循环语句等具有选择/判断性质的语句称为控制流。
- 常用 API:
torch.jit.trace:跟踪函数或模型的运行过程,并即时编译成可执行对象,但是需要数据作为输入。与 torch.jit.script 类似。
torch.jit.script:将函数或模型的运行过程即时编译化,不需要输入数据。与 torch.jit.trace 类似。可以通过装饰器调用。
torch.jit.export:一般作用于模型中的成员函数,表示该函数也需要被即时编译。配合 torch.jit.script,通过装饰器调用。
torch.jit.ignore:一般作用于模型中的成员函数,表示该函数不需要被即时编译,但是是可访问的。配合 torch.jit.script,通过装饰器调用。
torch.jit.save:保存成 Scripted 类实例。
torch.jit.load:加载出 Scripted 类实例。
1. torch.jit.trace
API Definition
torch.jit.trace(
func , example_inputs , optimize=None , check_trace=True , check_inputs=None
, check_tolerance=1e-05 , strict=True , _force_outplace=False , _module_class=None
, _compilation_unit=<torch.jit.CompilationUnit object> )
"""
主要参数:
➡func(callable or torch.nn.Module)– Python 函数或torch.nn.Module。将与example_inputs一起运行。
注意,func的输入参数和返回值(return)必须是Tensor或包含Tensor的元组。
➡example_inputs (tuple or torch.Tensor) – 将被传递进func中。
其他参数:不作展开。详情可参阅官方文档。
返回:ScriptModule 或 ScriptFunction。
"""
Example (tracing a function)
import torch
def my_example(x):
useful = 1
useless = 888
useless2 = useless + x
if x.sum() > 0:
return x * useful
else:
return -x * useful
if __name__ is '__main__':
x1 = torch.tensor([-0.8, -0.7, 0.6])
traced_example = torch.jit.trace(my_example, x1)
x2 = torch.tensor([0.8, 0.7, 0.6])
infer = traced_example(x2)
print(traced_example.code) # ScriptFuction 会有一个code属性
print('--------分隔--------')
print(traced_example.graph) # ScriptFuction 会有一个graph属性
print('--------分隔--------')
print(infer)
打印结果:
def my_example(x: Tensor) -> Tensor:
return torch.mul(torch.neg(x), CONSTANTS.c0)
--------分隔--------
graph(%x : Float(3, strides=[1], requires_grad=0, device=cpu)):
%5 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::neg(%x) # F:/PYcharm/test.py:39:0
%6 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # F:/PYcharm/test.py:39:0
%7 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::mul(%5, %6) # F:/PYcharm/test.py:39:0
return (%7)
--------分隔--------
tensor([-0.8000, -0.7000, -0.6000])
观察上面的结果,你会发现:
- Scripted 类实例(也就是 traced_example )都会有 .code 属性 和 .graph 属性。前者描述了代码运行过程所做的运算,便于阅读;后者是抽象一点的计算图,具体会记录参数、算子(aten 是PyToch 的C++底层运算库)等。
- torch.jit.trace 会记录与输入做运算、对 return 结果产生影响的参数和算子。上面的打印结果你会发现,useful 变量被记录了下来,因为其既与输入做运算,又与最终 return 的内容有关系。而 useless、useless2 变量的语句没有包含在计算图中,因为其与实际要 return 的计算过程无关。
- torch.jit.trace 不会记录控制流(比如 if-else 语句)的细节,只记录参与运算的部分,这同样可以在上面的打印结果中发现这个现象。这很有可能产生错误的计算图,即不是你想要的,因为“写死”了——使用 traced_example 进行推理,会发现 infer 变量的赋值逻辑不对,逻辑上应该是 tensor([0.8000, 0.7000, 0.6000]),但是结果却是 tensor([-0.8000, -0.7000, -0.6000])。
上面的特性也是有一定意义的,当你不想让人知道你的控制流是如何设置的,你的其他变量是如何设置的——你不想展示给别人看,那么就可以用 torch.jit.trace,追踪出一条计算路线记录下来即可。需要注意的是,最好不要使用 torch.jit.trace 来作用 RNN,因为 RNN 里有循环,也就是有控制流。
Example (tracing an nn.Module)
import torch
class my_block(torch.nn.Module):
def forward(self, x):
return x * x # 进行点积
def no_trace_func1(x, h):
return x + h
class my_model(torch.nn.Module):
def __init__(self):
super(my_model, self).__init__()
self.bk = my_block()
self.linear = torch.nn.Linear(4, 4)
def no_trace_func2(self, x, h):
return x - h
def forward(self, x, h):
x = no_trace_func1(x, h)
h = self.no_trace_func2(x, h)
new_h = torch.tanh(self.bk(self.linear(x)) + h)
return new_h, new_h
if __name__ is '__main__':
x = torch.tensor([0.8, 0.7, 0.6, 0.5])
h = torch.tensor([0.5, 0.5, 0.5, 0.5])
mmodel = my_model()
traced_model = torch.jit.trace(mmodel, (x, h))
print(traced_model.code)
print('--------分隔--------')
print(traced_model.bk.code)
# print(traced_cell.no_trace_func2.code) # 报错,没有no_trace_func2属性
打印结果:
def forward(self,
x: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = self.bk
_1 = self.linear
x0 = torch.add(x, h, alpha=1)
h0 = torch.sub(x0, h, alpha=1)
_2 = torch.add((_0).forward((_1).forward(x0, ), ), h0, alpha=1)
_3 = torch.tanh(_2)
return (_3, _3)
--------分隔--------
def forward(self,
argument_1: Tensor) -> Tensor:
return torch.mul(argument_1, argument_1)
观察上面的结果,你会发现:
- 当 torch.jit.trace 作用于父模型(mmodel)时,只有父模型的 forward 函数被跟踪和记录。而 traced_model.no_trace_func2 属性无法访问到,因为其不会被记录。
- forward 函数内的所有与输入做运算、对 return 结果产生影响的参数和算子过程都被记录了。no_trace_func1 参与了计算,所以 forward 里追踪了 no_trace_func1 的运算细节,即 torch.add(x, h, alpha=1) ;no_trace_func2 参与了计算,所以 forward 里追踪了 no_trace_func2 的运算细节,即 torch.sub(x0, h, alpha=1)。
- 父模型的 .code 属性不会再嵌套记录子模型的 forward 细节,而是将子模型定义成一个实例直接调用。
- 子模型 traced_model.bk 是可以访问到的,因为其参与了前向计算、也有 forward。故可以调用 traced_model.bk.code 来查看它的 forward 发生了什么。
2. torch.jit.script
API Definition
torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None)
"""
主要参数:
➡obj (callable, class, or nn.Module) – Python 函数或torch.nn.Module。
其他参数:不作展开。详情可参阅源码注释。
返回:ScriptModule、ScriptFunction、torch._C.ScriptDict 或 torch._C.ScriptList。
"""
Example (scripting a function)
import torch
@torch.jit.script
def my_func(x, y):
useless = 666
useless2 = 1 * x
if x.max() > y.max():
r = x
else:
r = y
return r
if __name__ is '__main__':
infer = my_func(torch.tensor([-0.1, 0.6]), torch.tensor([0.8, 0.4]))
infer2 = my_func(torch.tensor([0.1, 0.6]), torch.tensor([-0.8, 0.4]))
print(type(my_func))
print('--------分隔--------')
print(my_func.code)
print('--------分隔--------')
print(my_func.graph)
print('--------分隔--------')
print(infer)
print('--------分隔--------')
print(infer2)
打印结果:
<class 'torch.jit.ScriptFunction'>
--------分隔--------
def my_func(x: Tensor,
y: Tensor) -> Tensor:
_0 = bool(torch.gt(torch.max(x), torch.max(y)))
if _0:
r = x
else:
r = y
return r
--------分隔--------
graph(%x.1 : Tensor,
%y.1 : Tensor):
%7 : Tensor = aten::max(%x.1) # F:/PYcharm/test.py:100:7
%9 : Tensor = aten::max(%y.1) # F:/PYcharm/test.py:100:17
%10 : Tensor = aten::gt(%7, %9) # F:/PYcharm/test.py:100:7
%12 : bool = aten::Bool(%10) # F:/PYcharm/test.py:100:7
%r : Tensor = prim::If(%12) # F:/PYcharm/test.py:100:4
block0():
-> (%x.1)
block1():
-> (%y.1)
return (%r)
--------分隔--------
tensor([0.8000, 0.4000])
--------分隔--------
tensor([0.1000, 0.6000])
观察上面的结果,你会发现:
- 可以通过装饰器来调用 torch.jit.script 。也就是说,当调用 my_func 函数时,会先调用 torch.jit.script。
- 可以发现控制流被完整的记录下来了,即 if-else 语句被完整的记录下来了。于是输入不同的数据,产生的结果就会不同,逻辑正确。这就是 script 和 trace 的区别之一。
- 当然也可以发现,只有与输入做运算、对 return 结果产生影响的参数和算子被记录了下来,useless、useless2 都没有被记录。
- 不需要 example_inputs,就可以形成完整的 Graph。
Example (scripting an nn.Module)
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
@torch.jit.export
def compiled_stored_func(self, input):
return input + 10
@torch.jit.ignore
def observed_func(self, input):
# 该函数不会被 TorchScript 编译,但是你在接下来的 Python 环境中可以继续使用。
return input - 1
def normal_func(self, input):
return input / 10
def forward(self, input):
out = self.observed_func(input)
out = self.compiled_stored_func(out)
return out * 1
if __name__ is '__main__':
x = torch.tensor([1.1, 2.2])
scripted_module = torch.jit.script(MyModule())
print(type(scripted_module.compiled_stored_func))
print('--------分隔--------')
print(scripted_module.compiled_stored_func.code)
print('--------分隔--------')
print(type(scripted_module.observed_func))
print('--------分隔--------')
# print(type(scripted_module.observed_func.code)) # 报错,observed_func没有code这个属性
# print(type(scripted_module.normal_func)) # 报错,没有normal_func这个属性
打印结果:
<class 'torch._C.ScriptMethod'>
--------分隔--------
def compiled_stored_func(self,
input: Tensor) -> Tensor:
return torch.add(input, 10, 1)
--------分隔--------
<class 'method'>
--------分隔--------
观察上面的结果,你会发现:
- torch.jit.export 的作用下,compiled_stored_func 函数也被编译和记录了;torch.jit.ignore 的作用下,observed_func 可以在以后的代码中可作为成员函数调用,但是其没有被编译,最后不会被 ScriptModule 记录;而normal_func 函数即一般的成员函数,没有被编译,也没有被记录,也不可被访问。
3. Mixing Scripting and Tracing
实际情况中,torch.jit.script 和 torch.jit.trace 是可以混合使用的,以利用各自的特点。如下面的例子所示:
import torch
class my_block(torch.nn.Module):
def forward(self, x):
return x * x # 进行点积
class my_model1(torch.nn.Module):
def __init__(self):
super(my_model1, self).__init__()
self.bk = my_block()
self.linear = torch.jit.trace(torch.nn.Linear(4, 4), torch.tensor([0.8, 0.7, 0.6, 0.5]))
def forward(self, x, h):
new_h = torch.tanh(self.bk(self.linear(x)) + h)
return new_h, new_h
class my_model2(torch.nn.Module):
def __init__(self):
super(my_model2, self).__init__()
self.bk = my_block()
self.linear = torch.jit.script(torch.nn.Linear(4, 4))
def forward(self, x, h):
new_h = torch.tanh(self.bk(self.linear(x)) + h)
return new_h, new_h
if __name__ is '__main__':
x = torch.tensor([0.8, 0.7, 0.6, 0.5])
h = torch.tensor([0.5, 0.5, 0.5, 0.5])
mmodel1 = my_model1()
scripted_model1 = torch.jit.script(mmodel1)
mmodel2 = my_model2()
traced_model2 = torch.jit.trace(mmodel2, (x, h))
print(scripted_model1.code)
print('--------分隔--------')
print(traced_model2.code)
打印结果:
def forward(self,
x: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = (self.bk).forward((self.linear).forward(x, ), )
new_h = torch.tanh(torch.add(_0, h, alpha=1))
return (new_h, new_h)
--------分隔--------
def forward(self,
argument_1: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = self.bk
x = (self.linear).forward(argument_1, )
_1 = torch.add((_0).forward(x, ), h, alpha=1)
_2 = torch.tanh(_1)
return (_2, _2)
对于同样的一个模型,仔细观察你会发现,使用 torch.jit.trace 和使用 torch.jit.script 所生成的计算图是有细微差别的,且不同的混搭方法也会有细微差别。
4. torch.jit.save & torch.jit.load
二者会保存和加载代码、参数、属性和调试信息。用例如下:
import torch
if __name__ is '__main__':
traced_L = torch.jit.trace(torch.nn.Linear(4, 4), torch.tensor([0.8, 0.7, 0.6, 0.5]))
scripted_L = torch.jit.script(torch.nn.Linear(4, 4))
traced_L.save('traced_L.pt')
scripted_L.save('scripted_L.pt')
L1 = torch.jit.load('traced_L.pt')
L2 = torch.jit.load('scripted_L.pt')
print(L1.weight.requires_grad)
print('--------分隔--------')
print(L2.weight.requires_grad)
打印结果:
True
--------分隔--------
True
有意思的是,无论是 script 还是 trace,加载回来了还可以继续训练。
个人思考
首先要提的一点就是,一般我们用 torch.save 保存模型,本质上调用的是 pickle 模块来序列化保存模型,只保存模型参数,并没有计算图。计算图通俗来说就是一个流图,描述了数据从模型输入到模型输出该经过什么必要流程、什么必要计算。也就是说,如果你需要使用该模型进行推理时,还需要实现定义模型的类的 .py 代码——也就是定义对应的计算图。否则不得行。
而个人认为,TorchScript 的出现解决了这个问题,其储存 ScriptModule 模型时能同时将计算图(也即模型的运行过程)保存下来。利用 TorchScript 保存下来的模型可以加载到没有 Python 环境的进程中——比如 C++ 环境。
本文章只做了一个 Introduction,更多的功能和用法还请参考官方文档。以上总结如有谬误,还请指正、包涵。