TorchScript初见
自从PyTorch 1.3 开始,PyTorch也支持部署到移动端了,那么很自然地,就有以下两个必须待解决的问题:
- 如果我们训练好了一个模型,我们怎么将这个模型序列化/反序列化,从而可以在移动端上导入并运行;
- 序列化好的模型应该不受Python语言的限制,从而可以在别的语言环境下载入,例如Android中的Java,iOS中的Swift/Objective-C,甚至是服务端中的C++环境中,即语言无关特点。
很自然的一个解决方案就是,类似编译器中的中间代码技术,PyTorch以此为基础,提出了一个叫作TorchScript
的技术,即将Python语言定义的PyTorch模型,转化成一个序列化文件,即脚本文件,我想,这也正是TorchScript
技术的由来吧!
1. Tracing技术
如何将PyTorch模型转化成中间的序列化脚本呢?很原始的一个想法就是,我们将Python模型,按照给定的输入数据,将模型跑一遍,在跑的同时,我们记住模型的数据流动(也就是常说的计算图
),按照一定的规则,我们将计算图序列化保存下来即可。这实际上是一种动态代码生成技术,PyTorch官方给这种方式起名为tracing
,直意就是追踪数据在网络中的流动,从而确定计算图。
我们先定义一个简单的模型,如下:
1 | class Cell(torch.nn.Module): |
这个模型的功能就是将两个输入数据求和,然后返回结果,代码的功能不是关键。有了这个模型,我们如何tracing呢?
1 | my_cell = Cell() |
我们只要调用torch.jit.trace()
函数,并传入输入数据(x,h)即可,PyTorch会依据数据流返回一个traced object,这个对象是一个torch.jit.ScriptModule
类的实例,该类又是nn.Module
类的子类。traced_obj有两个重要的属性graph
和code
,分别表示当前输入对应的计算图底层表示和计算图对应的Python语法代码,如下:
1 | graph(%self : __torch__.Cell, |
计算图稍微有点复杂,无非也就是一些符号标记,注意,这种计算图的底层表示方法显然是语言无关的,而下面的中间代码就类似是整个计算过程对于的Python-Syntax代码了,和我们定义的forward()
很类似。
OK,现在我们要做的就是序列化到文件了,这个也不难,一行代码搞定
1 | # 序列化 |
序列化好了,下一步是什么?当然是反序列化,然后再给定一个输入数据,看看能否得到理想的输出啦!
1 | # 反序列化 |
OK,一摸一样,起码功能是实现了。看到这里,是不是你会有点疑问?我们目前为止干了什么?首先,我们将PyTorch模型序列化到文件,然后又反序列化加载进来,最后又运行了一下,你是不是会想问,兜兜转转绕了一圈,跑了个结果,为啥不直接第一步就运行?哈哈,别忘了,我们的目的是在其他非Python环境下运行,只不过我们这里的目标宿主环境刚好也用的是Python,所以会让人觉得有点绕。实际上,我们是完完全全可以开发一个基于C++/Java的客户端代码,将序列化好的模型用另外一种语言加载,并运行。
讲到这里,似乎关键的技术就讲完了,可是你有没有想过一个这样的问题:如果我们的模型,在定义的时候是有if-else这样的逻辑判断语句的,如果我们给定的当前输入对应的是if的控制流,推断得到的动态计算图必然是对应当前if控制流的;而如果在宿主环境上测试时,如果输入的数据对应的是else控制语句呢?为了说明这个问题,我们考虑下面的例子:
1 | class MyDecisionGate(torch.nn.Module): |
在1和2处的输出如下,和我们的预期相符合:
1 | def forward(self, |
可是在3处的输出会有点出乎意料:
1 | tensor([[-2., -2.], |
为何?按照我们模型的定义,输出理应是全2的tensor的!噢!明白了,原来我们在tracing的时候,数据流对应的是else段,因此在以后无论给定什么输入数据时,都会走else这个部分,这就是动态计算图构建必然会出现的问题。那,有没有解决方案呢?当然!静态构建就行了呀,开发一个特定的编译器,用来直接将Python定义的Pytorch模型转化成静态计算图,而且是可以有if-else这种逻辑判断的,好在PyTorch开发团队为我们开发了这样的编译器,即torch.jit.script()
,
2. Script技术
script的使用方法相比tracing而言,无需传入输入数据,只需传入模型对象即可,使用方法如下,
1 | scripted_cell = torch.jit.script(my_cell) |
此时对应的代码和输出分别是
1 | def forward(self, |
可以看到,相比tracing生成的代码,script生成的代码还被优化了,而且输出也是符合模型的定义的。
讲到这里,本文就可以暂告一段落了,至于后续如何在宿主语言环境(如C++)中使用TorchScript,期待下一篇博客吧。
3. References
- Introduction to TorchScript, PyTorch Official Doc.
- 直观认识torch.jit模块