小伙伴们好呀,TorchScript 解读系列教程更新啦~在上篇文章中,咱们带领咱们开始了解了 TorchScript。

TorchScript 是 PyTorch 提供的模型序列化以及布置计划,能够弥补 PyTorch 难于布置的缺点,也能够轻松完结图优化或后端对接。TorchScript 支撑经过trace来记载数据流的生成方法;也支撑解析 AST 直接生成图的script方法。

今日咱们将介绍 TorchScript 经过trace来记载数据流的生成方法,一起还将共享运用该机制完结的 ONNX 导出进程。接下来,就让咱们进入今日的正题吧~

基本概念

首要来看一下同一个模型的三种不同表述,为了便利展示各种 jit 的组件,这儿会运用script方法创立图:

代码

 def forward(self, x):
    x = x * 2 
    x.add_(0) 
    x = x.view(-1) 
    if x[0] > 1: 
        return x[0] 
    else: 
        return x[-1]

TorchScript Graph

graph(%self : __torch__.TestModel,
      %x.1 : Tensor): 
  %12 : int = prim::Constant[value=-1]() # graph_example.py:12:19 
  %3 : int = prim::Constant[value=2]() # graph_example.py:10:16 
  %6 : int = prim::Constant[value=0]() # graph_example.py:11:15 
  %10 : int = prim::Constant[value=1]() # graph_example.py:12:20 
  %x.3 : Tensor = aten::mul(%x.1, %3) # graph_example.py:10:12 
  %8 : Tensor = aten::add_(%x.3, %6, %10) # graph_example.py:11:8 
  %13 : int[] = prim::ListConstruct(%12) 
  %x.6 : Tensor = aten::view(%x.3, %13) # graph_example.py:12:12 
  %17 : Tensor = aten::select(%x.6, %6, %6) # graph_example.py:13:11 
  %18 : Tensor = aten::gt(%17, %10) # graph_example.py:13:11 
  %20 : bool = aten::Bool(%18) # graph_example.py:13:11 
  %41 : Tensor = prim::If(%20) # graph_example.py:13:8 
    block0(): 
      %23 : Tensor = aten::select(%x.6, %6, %6) # graph_example.py:14:19 
      -> (%23) 
    block1(): 
      %32 : Tensor = aten::select(%x.6, %6, %12) # graph_example.py:16:19 
      -> (%32) 
  return (%41)

上图中心的部分便是 TorchScript 模型的可视化成果,其间包含如下一些元素:

Graph

表格中Graph列全体用来表明一个Graph,它有如下性质

  • Graph 用来表明一个“函数”,一个 Module 中的不同函数(比方 forward 等)会被转化成不同的 Graph。
  • Graph 具有许多的 Node,这些 Node 由一个 Block 管理。所有 Node 安排成双向链表的形式,便利刺进删去,其间返回值节点“Return Node”会作为这个双向链表的“岗兵”。双向链表一般会被拓扑排序,确保履行的正确性。

Node

表格中 Graph 列里 3~14 行,以及 16和19 行表明各个Node,一个 Node 对应一个操作。操作的输入为 Value,少数状况下还会有一些 static attribute。Node 中包含许多信息,包含:

  • kind() 表明 Node 的操作类型,上图中的aten::mulprim::ListConstruct等都是对应 Node 的 kind。留意它仅仅个字符串,因而修正这个字符串也就意味着修正了操作。
  • FunctionSchema 指对这个函数的接口的描绘,格局看起来就相似 ops 函数的声明,别的能够添加一些符号表明某个 Tensor 是否是另一个 Tensor 的 Alias 等等(别名剖析是确保优化成果正确的依据),能够作为 peelhole-optimize 的时分的检索依据。以Tensor.add_函数为例:
// add_是一个inplace运算,因而输出和self同享相同的内存空间 
// FunctionSchema中标注了这种别名联系,确保了输出的正确性 
// netron的可视化好像不会进行alias analysis?因而上面右图的可视化中,add_的部分存在错误 
"add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)" 
  • 常用的函数的 schema 能够在aten/src/ATen/native/native_functions.yaml中检查。

Block

Block表明一个 Node 的有序列表,代表输入的 Node 的kind=Param,代表输出的 Node 的kind=Return

实践上 Graph 本身隐含一个 root Block 目标,用来管理所有的 Node。部分 Node 或许还会存在 sub Block。比方表中的 Graph 就有3个 Block,一个是 Graph 隐含的 root Block,另两个是prim::IfNode 的 sub Block。

Block 的概念或许源于编译原理中的基本块。所谓基本块便是一系列不包含任何跳转指令的指令序列,因为基本块内的内容能够确保是次序履行的,因而许多的优化都会以基本块作为前提。实践上 PyTorch 中对中心表明(IR)的优化有非常多是 Block 级别的。

Value

Value是 Node 的输入输出,能够是 Tensor 也能够是容器或其他类型,能够经过type()判断。

Value 目标维护了一个 use_list,只需这个 Value 成为某个 Node 的输入,那么这个 Node 就要加入到它的 use_list 中。经过这个 use_list,能够很便利地解决新加入的 Node 与其他 Node 的输入输出联系。

留意:Value 是用来表述 Graph 的结构的,与 Runtime 无关!真正在推理时用到的是 IValue 目标,IValue 中有运转时的真实数据。

Pass

严格地说这不是 Graph 的一部分,pass 是一个来源于编译原理的概念,它会接纳一种中心表明(IR),遍历它并且进行一些改换,生成满意某种条件的新 IR。

TorchScript 中界说了许多 pass 来优化 Graph。比方对于常规编译器很常见的 DeadCodeElimination(DCE),CommonSubgraphElimination(CSE)等等;也有一些针对深度学习的交融优化,比方 FuseConvBN 等;还有针对特殊任务的 pass,ONNX 的导出便是其间一类 pass。

JIT Trace

Jit trace 在 python 侧的接口为torch.jit.trace,输入的参数会经过层层传递,最终会进入torch/jit/frontend/trace.cpp中的trace函数中。这个函数是 Jit trace 的中心,大致履行了下面几个进程:

  1. 创立新的TracingState目标,该目标会维护 trace 的 Graph 以及一些必要的环境参数。
  2. 依据 trace 时的模型输入参数,生成 Graph 的输入节点。
  3. 进行模型推理,一起生成 Graph 中的各个元素。
  4. 生成 Graph 的输出节点。
  5. 进行一些简略的优化。

下面会逐个介绍这些进程的细节:

1.创立TracingState目标

TracingState目标包含了 Graph 的指针、函数名映射、栈帧信息等,trace 的进程便是不断更新 TracingState 的进程。

struct TORCH_API TracingState
    : public std::enable_shared_from_this<TracingState> { 
  // 部分接口,能够协助Graph的构建 
  std::shared_ptr<Graph> graph; 
  void enterFrame(); 
  void leaveFrame(); 
  void setValue(const IValue& v, Value* value); 
  void delValue(const IValue& var); 
  Value* getValue(const IValue& var); 
  Value* getOutput(const IValue& var, size_t i); 
  bool hasValue(const IValue& var) const; 
  Node* createNode(c10::Symbol op_name, size_t num_outputs); 
  void insertNode(Node* node); 
}; 

2.生成 Graph 输入

这个进程会依据输入的 IValue 的类型,在 graph 中刺进新的输入 Value。还记得在基本概念章节中咱们说到的 IValue 与 Value 的区别吗?

for (IValue& input : inputs) {
    // addInput这个函数会unpack一些容器类型的IValue,创立对应的Node 
    input = addInput(state, input, input.type(), state->graph->addInput()); 
} 

3.进行 Tracing

Tracing 的进程便是运用样本数据进行一次推理的进程,可是实践在 github 的源码中,并不能找到关于推理时怎么更新 TracingState 的代码。

那么 PyTorch 到底是怎么做到在推理时更新 TracingState 的呢?咱们首要介绍关于 PyTorch 源码编译的一些小细节。

PyTorch 要适配各种硬件以及环境,为所有这些状况定制代码工作量大得可怕,也不便利后续的维护更新。因而 PyTorch 中许多代码是依据 build 时的参数生成出来,更新 TracingState 的代码便是其间之一。生成 Tracing 代码的脚本如下:

python -m tools.autograd.gen_autograd \
    aten/src/ATen/native/native_functions.yaml \ 
    ${OUTPUT_DIR} \ 
    tools/autograd 
# derivatives.yaml和native_functions.yaml中包含 
# 许多FunctionSchema以及生成代码需求的信息 

咱们能够跑一下看看都生成了些什么。生成的代码中TraceTypeEverything.cpp包含了许多关于更新 TracingState 的内容,咱们还是以add算子举例如下:

yaml

- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
  structured_delegate: scatter_add.out 
  variants: function, method 
- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) 
  structured_delegate: scatter_add.out 
  variants: method 
- func: scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) 
  structured: True 
  variants: function 
  dispatch: 
    CPU, CUDA: scatter_add 
 # func的内容是一个FunctionSchema,界说了函数的输入输出、别名信息等。

cpp

at::Tensor scatter_add(c10::DispatchKeySet ks, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) {
  torch::jit::Node* node = nullptr; 
  std::shared_ptr<jit::tracer::TracingState> tracer_state; 
  if (jit::tracer::isTracing()) { 
  // 进程1: 假如tracing时,运用TracingState创立ops对应的Node并刺进Graph 
    tracer_state = jit::tracer::getTracingState(); 
    at::Symbol op_name; 
    op_name = c10::Symbol::fromQualString("aten::scatter_add"); 
    node = tracer_state->createNode(op_name, /*num_outputs=*/0); 
    jit::tracer::recordSourceLocation(node); 
    jit::tracer::addInputs(node, "self", self); 
    jit::tracer::addInputs(node, "dim", dim); 
    jit::tracer::addInputs(node, "index", index); 
    jit::tracer::addInputs(node, "src", src); 
    tracer_state->insertNode(node); 
    jit::tracer::setTracingState(nullptr); 
  } 
  // 进程2: ops核算,不论是否进行Tracing都会履行 
  auto result =at::_ops::scatter_add::redispatch(ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer), self, dim, index, src); 
  if (tracer_state) { 
  // 进程3: 在TracingState中设置ops输出 
    jit::tracer::setTracingState(std::move(tracer_state)); 
    jit::tracer::addOutput(node, result); 
  } 
  return result; 
}

以上上方 是FunctionSchema,下方为生成的代码。代码会依据是否isTracing来选择是否记载 Graph 的结构信息。

实践在 Tracing 时,每经过一个 ops,都会调用一个相似上面生成的函数,履行如下进程:

  1. 在推理前依据解析的 FunctionSchema 生成 Node 以及各个输入 Value;
  2. 然后进行 ops 的正常核算;
  3. 最后依据 ops 的输出生成 Node 的输出 Value。

4.注册 Graph 输出

这部分没有太多值得说的,便是挨个把推理的输出注册成 Graph 的输出 Value。因为输出在一个栈中,因而输出的编号要逆序。

    size_t i = 0;
    for (auto& output : out_stack) { 
      // NB: The stack is in "reverse" order, so when we pass the diagnostic 
      // number we need to flip it based on size. 
      state->graph->registerOutput( 
          state->getOutput(output, out_stack.size() - i)); 
      i++; 
    } 

5.Graph 优化

完结 Tracing 后,会对 Graph 进行一些简略的优化,包含如下数个 passes:

  • Inline(Optional):网络界说经常会包含许多嵌套结构,比方Resnet会由许多BottleNeck组成。这就会涉及到对 sub module 的调用,这种调用会生成prim::CallMethod等 Node。Inline 优化会将 sub module 的 Graph 内联到当前的 Graph 中,消除 CallMethod、CallFunction 等节点。
  • FixupTraceScopeBlock:处理一些与 scope 相关的 node,比方将比方prim::TracedAttr[scope="__module.f.param"]()这样的 Node 拆成数个prim::GetAttr的组合。
  • NormalizeOps:有些不同名 Node 或许有相同的功用,比方aten::absoluteaten::abs,N ormalizeOps 会把这些 Node 的类型姓名一致(一般为较短的那个)。

对 pass 更详细的剖析会在后续的共享中介绍。

经过上述进程,就能够得到经过 trace 的成果。

ONNX Export

Onnx 模型的导出相同要用到 jit trace 的进程,大致的进程如下:

  1. 加载 ops 的 symbolic 函数,主要是 torch 中预界说的 symbolic。
  2. 设置环境,包含 opset_version,是否折叠常量等等。
  3. 运用 jit trace 生成 Graph。
  4. 将 Graph 中的 Node 映射成 ONNX 的 Node,并进行必要的优化。
  5. 将模型导出成 ONNX 的序列化格局。

接下来,咱们将依照次序介绍以上几个进程:

1.加载Symbolic

严格地说这一步在 export 之前就现已完结。在symbolic_registry.py中,会维护一个_symbolic_versions目标,在导入这个模块时会运用 importlib 将预先界说的 symbolic(torch.onnx.symbolic_opset)加载到其间。

_symbolic_versions: Dict[Union[int, str], Any] = {}
from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset 
for opset_version in _onnx_stable_opsets + [_onnx_main_opset]: 
    module = importlib.import_module("torch.onnx.symbolic_opset{}".format(opset_version)) 
    _symbolic_versions[opset_version] = module 

_symbolic_versions中 key 为 opset_version,value 为对应的 symbolic 集合。symbolic 是一种映射函数,用来把对应的 aten/prim Node 映射成 onnx 的 Node。能够阅览torch/onnx/symbolic_opset.py了解更多细节。

2.设置环境

依据 export 的输入参数调整环境信息,比方 opset 的版本、是否将 init 导出成 Input、是否进行常量折叠等等。后续的优化会依据这些环境运转特定的 passes。

3.Graph Tracing

这一步实践履行的便是上面介绍过的 Jit Tracing 进程,假如忘记的话能够再复习一下哦。

4.ToONNX

Graph 在实践运用之前会经过许多的 pass,每个 pass 都会对 Graph 进行一些改换,能够在torch/csrc/jit/passes中检查完结细节。这些 pass 许多功用与常见的编译器中的相似,篇幅联系就不在这儿展开介绍了。对于 torchscript->ONNX 而言,最重要的 pass 当属ToONNX

ToONNX 的 python 接口为torch._C._jit_pass_onnx,对应的完结为onnx.cpp。它会遍历 Graph 中所有的 Node,生成对应的 ONNX Node,刺进新的 Graph 中:

  auto k = old_node->kind();    // 取得Node的ops类型 
  if (k.is_caffe2()) { 
    // ToONNX之前的会有一些对caffe2算子的pass 
    // 因而这儿只需直接clone到新的graph中即可 
    cloneNode(old_node); 
  } else if (k == prim::PythonOp) { 
    // 假如是Python自界说的函数,比方继承自torch.autograd.Function的函数 
    // 就会查找并调用对应的symbolic函数进行转化 
    callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node)); 
  } else { 
    // 假如是其他状况(一般是aten的算子)调用进程1加载的symbolic进行转化 
    callPySymbolicFunction(old_node); 
  } 

cloneNode 的功用就和姓名一样,便是简略的拷贝 old_node,然后塞进新的 Graph 中。

callPySymbolicFunction

当 Node 的类型为 PyTorch 的内置类型时,会调用这个函数来处理。

该函数会调用 python 侧的torch.onnx.utils._run_symbolic_function函数,将 Node 进行转化,并刺进新的 Graph,咱们能够测验如下 python 代码:

graph = torch._C.Graph()  # 创立Graph 
[graph.addInput() for _ in range(2)]  # 刺进两个输入 
node = graph.create('aten::add', list(graph.inputs()))  # 创立节点 
node = graph.insertNode(node)  # 刺进节点 
graph.registerOutput(node.output())  # 注册输出 
print(f'old graph:\n {graph}') 
new_graph = torch._C.Graph()  # 创立新的Graph用于ONNX 
[new_graph.addInput() for _ in range(2)]  # 刺进两个输入 
_run_symbolic_function( 
    new_graph, node, inputs=list(new_graph.inputs()), 
    env={})  # 将aten Node转化为onnx Node, 刺进新的Graph 
# 假如是torch>1.8,那么或许还要传入block 
print(f'new graph:\n {new_graph}') 

然后看一下可视化的成果:

Old graph

 graph(%0 : Tensor,
      %1 : Tensor): 
  %2 : Tensor = aten::add(%0, %1) 
  return (%2)

New graph

 graph(%0 : Tensor,
      %1 : Tensor): 
  %2 : Tensor = onnx::Add(%0, %1) 
  return ()

能够看见,本来的aten::add节点现已被替换为了onnx::Add。那么这个映射是怎么完结的呢?还记得第一步记载的_symbolic_versions吗?_run_symbolic_function会调用torch.onnx.symbolic_registry中的_find_symbolic_in_registry函数,查找_symbolic_versions中是否存在满意条件的映射,假如存在,就会进行如上图中的转化。

留意:转化的新 Graph 中没有输出 Value,这是因为这部分是在 ToONNX 的 c++ 代码中完结,_run_symbolic_function仅担任 Node 的映射。

callPySymbolicMethod

一些非 pytorch 原生的核算会被符号为 PythonOp。碰到这种 Node 时,会有三种或许的处理方法:

  1. 假如这个 PythonOp 带有名为 symbolic 的特点,那么就会测验运用这个 symbolic当作映射函数,生成 ONNX 节点
  2. 假如没有 symbolic 特点,可是在进程 1 的时分注册了 prim::PythonOp 的 symbolic 函数,那么就会运用这个函数生成节点。
  3. 假如都没有,则直接 clone PythonOp 节点到新的 Graph。

symbolic 函数的写法很简略,基本上便是调用 python bind 的 Graph接口创立新节点,比方:

class CustomAdd(torch.autograd.Function):
    @staticmethod 
    def forward(ctx, x, val): 
        return x + val 
    @staticmethod 
    def symbolic(g, x, val): 
        # g.op 能够创立新的Node 
        # Node的姓名 为 <domain>::<node_name>,假如domain为onnx,能够只写node_name 
        # Node能够有许多特点,这些特点名必须有_<type>后缀,比方val假如为float类型,则必须有_f后缀 
        return g.op("custom_domain::add", x, val_f=val) 

实践在运用上面的函数时,就会生成custom_domain::add这个 Node。当然,能否被用于推理这就要看推理引擎的支撑状况了。

经过callPySymbolicFunctioncallPySymbolicMethod,就能够生成一个由 ONNX(或自界说的 domain 下的 Node)组成的新 Graph。这之后还会履行一些优化 ONNX Graph 的 pass,这儿不详细展开了。

5.序列化

到这儿为止建图算是完结了,可是要给其他后端运用的话,需求将这个 Grap 序列化并导出。序列化的进程比较简略,基本上仅仅调用 ONNX 的 proto 接口,将 Graph 中的各个元素映射到 ONNX 的 GraphProto 上。没有太多值得展开的内容,能够阅览export.cpp中的EncodeGraphEncodeBlockEncodeNode函数了解更多细节。

之后只需依据具体的 export_type,将序列化后的 proto 写入文件即可。

至此,ONNX export 完结,能够开始享用各种推理引擎带来的速度提升了。

经过上面的内容共享,咱们应该对怎么运用 trace 方法生成 jit 模型,以及 trace 模型怎么影响 ONNX 导出有了一个开始的知道。为了让模型更好地为布置服务,咱们能够考虑对模型进行优化,后续的共享中将介绍一种常用的优化范式,敬请期待哦。

MMDeploy 已添加对 torchscript 模型的支撑,其间也选用 trace 的方法构建 jit 模型,欢迎咱们访问MMDeploy GitHub主页体验~

TorchScript 系列解读(二):Torch jit tracer 实现解析