初识TorchScript

Posted by 刘知安 on 2021-04-08
文章目录
  1. TorchScript初见
    1. 1. Tracing技术
    2. 2. Script技术
    3. 3. References

TorchScript初见

自从PyTorch 1.3 开始,PyTorch也支持部署到移动端了,那么很自然地,就有以下两个必须待解决的问题:

  • 如果我们训练好了一个模型,我们怎么将这个模型序列化/反序列化,从而可以在移动端上导入并运行;
  • 序列化好的模型应该不受Python语言的限制,从而可以在别的语言环境下载入,例如Android中的Java,iOS中的Swift/Objective-C,甚至是服务端中的C++环境中,即语言无关特点。

很自然的一个解决方案就是,类似编译器中的中间代码技术,PyTorch以此为基础,提出了一个叫作TorchScript的技术,即将Python语言定义的PyTorch模型,转化成一个序列化文件,即脚本文件,我想,这也正是TorchScript技术的由来吧!

1. Tracing技术

如何将PyTorch模型转化成中间的序列化脚本呢?很原始的一个想法就是,我们将Python模型,按照给定的输入数据,将模型跑一遍,在跑的同时,我们记住模型的数据流动(也就是常说的计算图),按照一定的规则,我们将计算图序列化保存下来即可。这实际上是一种动态代码生成技术,PyTorch官方给这种方式起名为tracing,直意就是追踪数据在网络中的流动,从而确定计算图。

我们先定义一个简单的模型,如下:

1
2
3
4
5
6
7
class Cell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()

def forward(self, x, h):
new_h = x + h
return new_h

这个模型的功能就是将两个输入数据求和,然后返回结果,代码的功能不是关键。有了这个模型,我们如何tracing呢?

1
2
3
4
5
6
7
8
my_cell = Cell()
x = torch.ones(2, 2)
h = torch.ones(2, 2) * (-2)

# tracing
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.graph)
print(traced_cell.code)

我们只要调用torch.jit.trace()函数,并传入输入数据(x,h)即可,PyTorch会依据数据流返回一个traced object,这个对象是一个torch.jit.ScriptModule类的实例,该类又是nn.Module类的子类。traced_obj有两个重要的属性graphcode,分别表示当前输入对应的计算图底层表示和计算图对应的Python语法代码,如下:

1
2
3
4
5
6
7
8
9
10
11
graph(%self : __torch__.Cell,
%x : Float(2, 2),
%h : Float(2, 2)):
%4 : int = prim::Constant[value=1]() # /Users/liuzhian/PycharmProjects/TorchScriptDemo/demo1.py:9:0
%5 : Float(2, 2) = aten::add(%x, %h, %4) # /Users/liuzhian/PycharmProjects/TorchScriptDemo/demo1.py:9:0
return (%5)

def forward(self,
x: Tensor,
h: Tensor) -> Tensor:
return torch.add(x, h, alpha=1)

计算图稍微有点复杂,无非也就是一些符号标记,注意,这种计算图的底层表示方法显然是语言无关的,而下面的中间代码就类似是整个计算过程对于的Python-Syntax代码了,和我们定义的forward()很类似。

OK,现在我们要做的就是序列化到文件了,这个也不难,一行代码搞定

1
2
# 序列化
traced_cell.save("cell.pt")

序列化好了,下一步是什么?当然是反序列化,然后再给定一个输入数据,看看能否得到理想的输出啦!

1
2
3
4
5
6
7
# 反序列化
traced_cell_loaded = torch.jit.load('cell.pt')
print(traced_cell_loaded(x, h))

# output
tensor([[-1., -1.],
[-1., -1.]])

OK,一摸一样,起码功能是实现了。看到这里,是不是你会有点疑问?我们目前为止干了什么?首先,我们将PyTorch模型序列化到文件,然后又反序列化加载进来,最后又运行了一下,你是不是会想问,兜兜转转绕了一圈,跑了个结果,为啥不直接第一步就运行?哈哈,别忘了,我们的目的是在其他非Python环境下运行,只不过我们这里的目标宿主环境刚好也用的是Python,所以会让人觉得有点绕。实际上,我们是完完全全可以开发一个基于C++/Java的客户端代码,将序列化好的模型用另外一种语言加载,并运行。

讲到这里,似乎关键的技术就讲完了,可是你有没有想过一个这样的问题:如果我们的模型,在定义的时候是有if-else这样的逻辑判断语句的,如果我们给定的当前输入对应的是if的控制流,推断得到的动态计算图必然是对应当前if控制流的;而如果在宿主环境上测试时,如果输入的数据对应的是else控制语句呢?为了说明这个问题,我们考虑下面的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class MyDecisionGate(torch.nn.Module):
def __init__(self):
super(MyDecisionGate, self).__init__()

def forward(self, x):
if x.sum() > 0:
return x
else:
return -x

class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.dg = MyDecisionGate()

def forward(self, x, h):
new_h = self.dg(x + h)
return new_h


my_cell = MyCell()
x = torch.ones(2, 2)
h = torch.ones(2, 2) * (-2)
# tracing时,对应MyDecisionGate类 forward()函数中的else段
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code) # --- 1
print(traced_cell(x, h)) # --- 2

x1 = torch.ones(2, 2)
h1 = torch.ones(2, 2)
print(traced_cell(x1, h1)) # --- 3

在1和2处的输出如下,和我们的预期相符合:

1
2
3
4
5
6
7
8
9
def forward(self,
x: Tensor,
h: Tensor) -> Tensor:
_0 = self.dg
x0 = torch.add(x, h, alpha=1)
return (_0).forward(x0, )

tensor([[1., 1.],
[1., 1.]])

可是在3处的输出会有点出乎意料:

1
2
tensor([[-2., -2.],
[-2., -2.]])

为何?按照我们模型的定义,输出理应是全2的tensor的!噢!明白了,原来我们在tracing的时候,数据流对应的是else段,因此在以后无论给定什么输入数据时,都会走else这个部分,这就是动态计算图构建必然会出现的问题。那,有没有解决方案呢?当然!静态构建就行了呀,开发一个特定的编译器,用来直接将Python定义的Pytorch模型转化成静态计算图,而且是可以有if-else这种逻辑判断的,好在PyTorch开发团队为我们开发了这样的编译器,即torch.jit.script()

2. Script技术

script的使用方法相比tracing而言,无需传入输入数据,只需传入模型对象即可,使用方法如下,

1
2
3
4
5
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.code)
x2 = torch.ones(2, 2)
h2 = torch.ones(2, 2)
print(scripted_cell(x2, h2))

此时对应的代码和输出分别是

1
2
3
4
5
6
7
8
def forward(self,
x: Tensor,
h: Tensor) -> Tensor:
new_h = (self.dg).forward(torch.add(x, h, alpha=1), )
return new_h

tensor([[2., 2.],
[2., 2.]])

可以看到,相比tracing生成的代码,script生成的代码还被优化了,而且输出也是符合模型的定义的。

讲到这里,本文就可以暂告一段落了,至于后续如何在宿主语言环境(如C++)中使用TorchScript,期待下一篇博客吧。

3. References