PyTorch:TorchScript 概要

  • Intro
  • Basic Usage
  • 1. torch.jit.trace
  • API Definition
  • Example (tracing a function)
  • Example (tracing an nn.Module)
  • 2. torch.jit.script
  • API Definition
  • Example (scripting a function)
  • Example (scripting an nn.Module)
  • 3. Mixing Scripting and Tracing
  • 4. torch.jit.save & torch.jit.load
  • 个人思考


Intro

  首先要提的一点概念就是 JIT(Just In Time)——这是一种即时编译器机制,其能把某些频繁被调用的代码编译成机器码,并进行各种层次的优化,以提高执行效率。
  而TorchScript 是 PyTorch 用来创建可序列化和可优化模型的即时编译器。任何 TorchScript 程序都可以从 Python 进程中保存,然乎再导出到生产环境,加载到没有 Python 依赖项的进程中。

  TorchScript 的强大,在于它的编译机制:

  1. 其作用于一个模型,既能记录模型的参数,又能保存计算图(Intermediate Representation Graph, IR Graph)。
  2. 其作用于一个函数,既能保存函数的有关参数,又能保存运行逻辑(本质上也还是计算图)。

Basic Usage

  在介绍具体用法之前,先引入几个概念。

  1. Scripted 类:由TorchScript 作用、编译而生成。其可分为 ScriptModuleScriptFunction 两种:
ScriptModule:由 nn.Module 编译而来。
ScriptFunction:由 一般的函数编译而来。
  1. 控制流(Control Flow):PyTorch 官方将 if 语句循环语句等具有选择/判断性质的语句称为控制流。
  2. 常用 API:
torch.jit.trace:跟踪函数或模型的运行过程,并即时编译成可执行对象,但是需要数据作为输入。与 torch.jit.script 类似。
torch.jit.script:将函数或模型的运行过程即时编译化,不需要输入数据。与 torch.jit.trace 类似。可以通过装饰器调用。
torch.jit.export:一般作用于模型中的成员函数,表示该函数也需要被即时编译。配合 torch.jit.script,通过装饰器调用。
torch.jit.ignore:一般作用于模型中的成员函数,表示该函数不需要被即时编译,但是是可访问的。配合 torch.jit.script,通过装饰器调用。
torch.jit.save:保存成 Scripted 类实例。
torch.jit.load:加载出 Scripted 类实例。

1. torch.jit.trace

API Definition

torch.jit.trace( 
	func , example_inputs , optimize=None , check_trace=True , check_inputs=None 
	, check_tolerance=1e-05 , strict=True , _force_outplace=False , _module_class=None 
	, _compilation_unit=<torch.jit.CompilationUnit  object> )

"""
主要参数:
	➡func(callable or torch.nn.Module)– Python 函数或torch.nn.Module。将与example_inputs一起运行。
	注意,func的输入参数和返回值(return)必须是Tensor或包含Tensor的元组。
	➡example_inputs (tuple or torch.Tensor) – 将被传递进func中。
	
其他参数:不作展开。详情可参阅官方文档。

返回:ScriptModule 或 ScriptFunction。
"""

Example (tracing a function)

import torch

def my_example(x):
    useful = 1
    useless = 888
    useless2 = useless + x

    if x.sum() > 0:
        return x * useful
    else:
        return -x * useful


if __name__ is '__main__':
    x1 = torch.tensor([-0.8, -0.7, 0.6])
    traced_example = torch.jit.trace(my_example, x1)
    x2 = torch.tensor([0.8, 0.7, 0.6])
    infer = traced_example(x2)

    print(traced_example.code)  # ScriptFuction 会有一个code属性
    print('--------分隔--------')
    print(traced_example.graph)  # ScriptFuction 会有一个graph属性
    print('--------分隔--------')
    print(infer)

  打印结果:

def my_example(x: Tensor) -> Tensor:
  return torch.mul(torch.neg(x), CONSTANTS.c0)

--------分隔--------
graph(%x : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %5 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::neg(%x) # F:/PYcharm/test.py:39:0
  %6 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # F:/PYcharm/test.py:39:0
  %7 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::mul(%5, %6) # F:/PYcharm/test.py:39:0
  return (%7)

--------分隔--------
tensor([-0.8000, -0.7000, -0.6000])

  观察上面的结果,你会发现:

  • Scripted 类实例(也就是 traced_example )都会有 .code 属性 和 .graph 属性。前者描述了代码运行过程所做的运算,便于阅读;后者是抽象一点的计算图,具体会记录参数、算子(aten 是PyToch 的C++底层运算库)等。
  • torch.jit.trace 会记录与输入做运算对 return 结果产生影响的参数和算子。上面的打印结果你会发现,useful 变量被记录了下来,因为其既与输入做运算,又与最终 return 的内容有关系。而 useless、useless2 变量的语句没有包含在计算图中,因为其与实际要 return 的计算过程无关。
  • torch.jit.trace 不会记录控制流(比如 if-else 语句)的细节,只记录参与运算的部分,这同样可以在上面的打印结果中发现这个现象。这很有可能产生错误的计算图,即不是你想要的,因为“写死”了——使用 traced_example 进行推理,会发现 infer 变量的赋值逻辑不对,逻辑上应该是 tensor([0.8000, 0.7000, 0.6000]),但是结果却是 tensor([-0.8000, -0.7000, -0.6000])。

  上面的特性也是有一定意义的,当你不想让人知道你的控制流是如何设置的,你的其他变量是如何设置的——你不想展示给别人看,那么就可以用 torch.jit.trace,追踪出一条计算路线记录下来即可。需要注意的是,最好不要使用 torch.jit.trace 来作用 RNN,因为 RNN 里有循环,也就是有控制流。

Example (tracing an nn.Module)

import torch

class my_block(torch.nn.Module):
    def forward(self, x):
        return x * x  # 进行点积

def no_trace_func1(x, h):
    return x + h

class my_model(torch.nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.bk = my_block()
        self.linear = torch.nn.Linear(4, 4)

    def no_trace_func2(self, x, h):
        return x - h

    def forward(self, x, h):
        x = no_trace_func1(x, h)
        h = self.no_trace_func2(x, h)
        new_h = torch.tanh(self.bk(self.linear(x)) + h)
        return new_h, new_h


if __name__ is '__main__':
    x = torch.tensor([0.8, 0.7, 0.6, 0.5])
    h = torch.tensor([0.5, 0.5, 0.5, 0.5])
    mmodel = my_model()
    traced_model = torch.jit.trace(mmodel, (x, h))

    print(traced_model.code)
    print('--------分隔--------')
    print(traced_model.bk.code)
    # print(traced_cell.no_trace_func2.code)  # 报错,没有no_trace_func2属性

打印结果:

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.bk
  _1 = self.linear
  x0 = torch.add(x, h, alpha=1)
  h0 = torch.sub(x0, h, alpha=1)
  _2 = torch.add((_0).forward((_1).forward(x0, ), ), h0, alpha=1)
  _3 = torch.tanh(_2)
  return (_3, _3)

--------分隔--------
def forward(self,
    argument_1: Tensor) -> Tensor:
  return torch.mul(argument_1, argument_1)

  观察上面的结果,你会发现:

  • 当 torch.jit.trace 作用于父模型(mmodel)时,只有父模型的 forward 函数被跟踪和记录。而 traced_model.no_trace_func2 属性无法访问到,因为其不会被记录。
  • forward 函数内的所有与输入做运算对 return 结果产生影响的参数和算子过程都被记录了。no_trace_func1 参与了计算,所以 forward 里追踪了 no_trace_func1 的运算细节,即 torch.add(x, h, alpha=1) ;no_trace_func2 参与了计算,所以 forward 里追踪了 no_trace_func2 的运算细节,即 torch.sub(x0, h, alpha=1)。
  • 父模型的 .code 属性不会再嵌套记录子模型的 forward 细节,而是将子模型定义成一个实例直接调用。
  • 子模型 traced_model.bk 是可以访问到的,因为其参与了前向计算、也有 forward。故可以调用 traced_model.bk.code 来查看它的 forward 发生了什么。

2. torch.jit.script

API Definition

torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None)

"""
主要参数:
	➡obj (callable, class, or nn.Module) – Python 函数或torch.nn.Module。
	
其他参数:不作展开。详情可参阅源码注释。

返回:ScriptModule、ScriptFunction、torch._C.ScriptDict 或 torch._C.ScriptList。
"""

Example (scripting a function)

import torch

@torch.jit.script
def my_func(x, y):
    useless = 666
    useless2 = 1 * x
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


if __name__ is '__main__':
    infer = my_func(torch.tensor([-0.1, 0.6]), torch.tensor([0.8, 0.4]))
    infer2 = my_func(torch.tensor([0.1, 0.6]), torch.tensor([-0.8, 0.4]))

    print(type(my_func))
    print('--------分隔--------')
    print(my_func.code)
    print('--------分隔--------')
    print(my_func.graph)
    print('--------分隔--------')
    print(infer)
    print('--------分隔--------')
    print(infer2)

打印结果:

<class 'torch.jit.ScriptFunction'>
--------分隔--------
def my_func(x: Tensor,
    y: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.max(x), torch.max(y)))
  if _0:
    r = x
  else:
    r = y
  return r

--------分隔--------
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %7 : Tensor = aten::max(%x.1) # F:/PYcharm/test.py:100:7
  %9 : Tensor = aten::max(%y.1) # F:/PYcharm/test.py:100:17
  %10 : Tensor = aten::gt(%7, %9) # F:/PYcharm/test.py:100:7
  %12 : bool = aten::Bool(%10) # F:/PYcharm/test.py:100:7
  %r : Tensor = prim::If(%12) # F:/PYcharm/test.py:100:4
    block0():
      -> (%x.1)
    block1():
      -> (%y.1)
  return (%r)

--------分隔--------
tensor([0.8000, 0.4000])
--------分隔--------
tensor([0.1000, 0.6000])

  观察上面的结果,你会发现:

  • 可以通过装饰器来调用 torch.jit.script 。也就是说,当调用 my_func 函数时,会先调用 torch.jit.script。
  • 可以发现控制流被完整的记录下来了,即 if-else 语句被完整的记录下来了。于是输入不同的数据,产生的结果就会不同,逻辑正确。这就是 script 和 trace 的区别之一。
  • 当然也可以发现,只有与输入做运算对 return 结果产生影响的参数和算子被记录了下来,useless、useless2 都没有被记录。
  • 不需要 example_inputs,就可以形成完整的 Graph。

Example (scripting an nn.Module)

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    @torch.jit.export
    def compiled_stored_func(self, input):
        return input + 10

    @torch.jit.ignore
    def observed_func(self, input):
        # 该函数不会被 TorchScript 编译,但是你在接下来的 Python 环境中可以继续使用。
        return input - 1

    def normal_func(self, input):
        return input / 10

    def forward(self, input):
        out = self.observed_func(input)
        out = self.compiled_stored_func(out)
        return out * 1


if __name__ is '__main__':
    x = torch.tensor([1.1, 2.2])
    scripted_module = torch.jit.script(MyModule())



    print(type(scripted_module.compiled_stored_func))
    print('--------分隔--------')
    print(scripted_module.compiled_stored_func.code)
    print('--------分隔--------')
    print(type(scripted_module.observed_func))
    print('--------分隔--------')
    # print(type(scripted_module.observed_func.code))  # 报错,observed_func没有code这个属性
    # print(type(scripted_module.normal_func))  # 报错,没有normal_func这个属性

打印结果:

<class 'torch._C.ScriptMethod'>
--------分隔--------
def compiled_stored_func(self,
    input: Tensor) -> Tensor:
  return torch.add(input, 10, 1)

--------分隔--------
<class 'method'>
--------分隔--------

  观察上面的结果,你会发现:

  • torch.jit.export 的作用下,compiled_stored_func 函数也被编译和记录了;torch.jit.ignore 的作用下,observed_func 可以在以后的代码中可作为成员函数调用,但是其没有被编译,最后不会被 ScriptModule 记录;而normal_func 函数即一般的成员函数,没有被编译,也没有被记录,也不可被访问。

3. Mixing Scripting and Tracing

  实际情况中,torch.jit.script 和 torch.jit.trace 是可以混合使用的,以利用各自的特点。如下面的例子所示:

import torch

class my_block(torch.nn.Module):
    def forward(self, x):
        return x * x  # 进行点积

class my_model1(torch.nn.Module):
    def __init__(self):
        super(my_model1, self).__init__()
        self.bk = my_block()
        self.linear = torch.jit.trace(torch.nn.Linear(4, 4), torch.tensor([0.8, 0.7, 0.6, 0.5]))

    def forward(self, x, h):
        new_h = torch.tanh(self.bk(self.linear(x)) + h)
        return new_h, new_h

class my_model2(torch.nn.Module):
    def __init__(self):
        super(my_model2, self).__init__()
        self.bk = my_block()
        self.linear = torch.jit.script(torch.nn.Linear(4, 4))

    def forward(self, x, h):
        new_h = torch.tanh(self.bk(self.linear(x)) + h)
        return new_h, new_h


if __name__ is '__main__':
    x = torch.tensor([0.8, 0.7, 0.6, 0.5])
    h = torch.tensor([0.5, 0.5, 0.5, 0.5])

    mmodel1 = my_model1()
    scripted_model1 = torch.jit.script(mmodel1)

    mmodel2 = my_model2()
    traced_model2 = torch.jit.trace(mmodel2, (x, h))

    print(scripted_model1.code)
    print('--------分隔--------')
    print(traced_model2.code)

  打印结果:

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.bk).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

--------分隔--------
def forward(self,
    argument_1: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.bk
  x = (self.linear).forward(argument_1, )
  _1 = torch.add((_0).forward(x, ), h, alpha=1)
  _2 = torch.tanh(_1)
  return (_2, _2)

  对于同样的一个模型,仔细观察你会发现,使用 torch.jit.trace 和使用 torch.jit.script 所生成的计算图是有细微差别的,且不同的混搭方法也会有细微差别。

4. torch.jit.save & torch.jit.load

  二者会保存和加载代码、参数、属性和调试信息。用例如下:

import torch


if __name__ is '__main__':
    traced_L = torch.jit.trace(torch.nn.Linear(4, 4), torch.tensor([0.8, 0.7, 0.6, 0.5]))
    scripted_L = torch.jit.script(torch.nn.Linear(4, 4))

    traced_L.save('traced_L.pt')
    scripted_L.save('scripted_L.pt')

    L1 = torch.jit.load('traced_L.pt')
    L2 = torch.jit.load('scripted_L.pt')

    print(L1.weight.requires_grad)
    print('--------分隔--------')
    print(L2.weight.requires_grad)

  打印结果:

True
--------分隔--------
True

  有意思的是,无论是 script 还是 trace,加载回来了还可以继续训练。

个人思考

  首先要提的一点就是,一般我们用 torch.save 保存模型,本质上调用的是 pickle 模块来序列化保存模型,只保存模型参数,并没有计算图。计算图通俗来说就是一个流图,描述了数据从模型输入到模型输出该经过什么必要流程、什么必要计算。也就是说,如果你需要使用该模型进行推理时,还需要实现定义模型的类的 .py 代码——也就是定义对应的计算图。否则不得行。

  而个人认为,TorchScript 的出现解决了这个问题,其储存 ScriptModule 模型时能同时将计算图(也即模型的运行过程)保存下来。利用 TorchScript 保存下来的模型可以加载到没有 Python 环境的进程中——比如 C++ 环境。

  本文章只做了一个 Introduction,更多的功能和用法还请参考官方文档。以上总结如有谬误,还请指正、包涵。