TorchDynamo初探:Python ByteCode的动态修改

作者|strint

1

布景

深度学习结构编译优化时,需求先根据核算逻辑构成一个逻辑核算图,然后再改写核算图,最终履行改写后的核算图。其间生成逻辑核算图方式有两种。

一种核算图生成是根据 trace tensor 的,跟踪 tensor 的履行途径。tensor 履行时,根据函数重载,能够落到支撑 tensor 核算的结构自定义函数,该函数一般是 c++ 层的。c++ 层的自定义函数中,功用是用于生成一个 Operation 的符号表达。比方一个关于加法运算,trace 便是记载一个符号化的加法算子。如此一连串的运算就被转换了符号化的核算图。

别的一种核算图生成是根据 AST(抽象语法树) 解析的。在代码履行前,直接根据 Python 文本代码得到 Python AST,然后根据 AST 来翻译成核算图(也叫做中间代码 IR)。

Python(特指 CPython)解释器履行,第一阶段会先把 Python 源码解析成 AST,第二阶段根据 AST 生成和优化 ByteCode(字节码),第三阶段在虚拟机中履行 ByteCode。

根据 AST 解析的核算图生成,发生在这儿的第一阶段;根据 trace tensor 的核算图生成,发生在第三阶段之后。

TorchDynamo 特别的当地在于其作业在第二阶段,动态修正 Python ByteCode,这样第三阶段履行的已经是修正后的 ByteCode了。

2

TorchDynamo 概述

TorchDynamo 是 PyTorch 新实验的 JIT 编译接口,支撑运用 Python 在运行时修正动态履行逻辑,修正的时机是 CPython 的 ByteCode 履行前。这个思想相似 DynamoRIO(dynamorio.org) 项目,DynamoRIO 能够动态的修正 x86 机器码。

CPython 的每次函数调用会生成一个 Frame(或许叫 Stack),Frame 中带有的代码部分便是 ByteCode。CPython 运行时支撑根据现有的 Frame 去设置一个自定义的 Frame,然后后边履行的便是自定义的 Frame。

TorchDynamo 的作业原理便是在运行时设置一个自定义的 Frame,该 Frame 中的 ByteCode 支撑 CallBack 到 Python 层去修正。其供给的典型的修正接口是 FX Graph,也便是说 TorchDynamo 会剖析 ByteCode,生成对应的 FX Graph,然后供给 FX Graph 的接口供用户自定义核算图。这种做法有如下长处:

  • 能够支撑一切的 Python 语法,因为假如在自定义 Frame 过程中的任何一点发现不支撑,都能够挑选不修正 Frame 而回退到原 Frame;

  • 开支少,劫持发生在 Python 履行比较早的阶段(ByteCode 生成和优化阶段),而非 Python ByteCode 履行后的阶段,有时能够削减 Python ByteCode 的履行开支(猜想假如很屡次 ByteCode 层面的函数调用被交融层成一次函数调用,确实能够缩减开支);

  • 能够做到不添加编译带来的延迟(之前的根据 tensor trace 或许 ast 解析的做法,一般都有先编译履行所以编译开支无法掩盖,可是改写 ByteCode 这个做法,猜想是能够在识别出热点代码后,单独开一个线程去做编译,而不影响主线程作业。Python ByteCode 改写的 API 中有这种延迟编译的样例,peps.python.org/pep-052 )。

之前核算图生成机制(根据 trace tensor、根据 AST 解析的)中的几个问题,得到了缓解:

  • 存在无法静态化的操作,之前一般需求显式的移除静态化作用域,现在总是答应不做编译,直接履行原 Python 代码,这样使得静态化标示变得简略;

  • 翻开静态图编译优化,之前编译时一般无法掩盖,现在有方法部分掩盖;

  • 动态 shape 问题,因为有了编译时和运行时的掩盖,也能够得到缓解。

这种尽量优化、动态优化的设计,最大程度了照顾了代码开发的体验,让编译优化上手变得更简略了。这是 TorchDynamo 带来的最首要的好处。这种做法非常契合 PyTorch 的 Python First、Eager First、User Experience First的偏好。可是这个设计关于寻求最好的性能、最方便的静态化部署这两个目标并没有改进。

3

CPython 的规范履行流程

上文提到了 CPython 的履行从 Python 文本代码,到 AST,到 ByteCode。这儿用一个示例打开看一下。Python 的规范组件非常易用,能够在 Python 层用 ast 组件来检查 AST,能够用 compile 内置函数来编译 ByteCode,能够用 exec 体系函数来履行 ByteCode。咱们先在代码最初导入相关组件:

import ast
import dis
import sys

然后咱们构造一个 python 代码,能够看到 src_code 便是普通的字符串。其间包括了一段普通的 python 内置的乘法,一段深度学习的 tensor scalar 加法,最终一段是当时Python Frame 中的 ByteCode 相关目标的打印(用于一个查验,后边会提到)。

print("=== source code ===")
src_code = """
# normal python operation
x = 1
x = x * 2
# tensor operation
y = dl_framework.ones((1, 2))
z = x + y
print(z)
# print python frame
f = sys._getframe()
# print the code object
print(f.f_code)
"""
print(src_code)

然后运用 ast 组件来生成这段代码的 AST。

print("=== source code to ast ===")
# 把源代码解析成 AST
ast_obj = ast.parse(src_code)
# 打印 AST
print(ast.dump(ast_obj))

能够得到 AST,这儿展示的成果额外做了格式化,别的删减掉了和核算逻辑无关的打印 frame 的部分,代码和其 AST 的对应联系拜见注释。AST解析是纯文本层面的,dl_framework 还没有被 import 进来,AST解析仍然能够正常作业。AST 基本是一个多叉树的结构,每个节点对应一个表达式,节点子节点代表子表达式。以 x = x + 2 为例,Assign 是一个节点,是赋值运算,被赋值的是 x,赋值的值是一个二元乘法运算。

Module(body=[
  # x = 1
  Assign(targets=[Name(id='x', ctx=Store())],
         value=Constant(value=1, kind=None),
         type_comment=None),
  # x = x * 2
  Assign(targets=[Name(id='x', ctx=Store())],
         value=BinOp(left=Name(id='x', ctx=Load()), op=Mult(), right=Constant(value=2, kind=None)), type_comment=None),
  # y = dl_framework.ones((1, 2))
  Assign(targets=[Name(id='y', ctx=Store())],
         # dl_framework.ones((1, 2))
         value=Call(func=Attribute(value=Name(id='dl_framework', ctx=Load()),
                    attr='ones', ctx=Load()),
                    args=[Tuple(elts=[Constant(value=1, kind=None),
                    Constant(value=2, kind=None)], ctx=Load())], keywords=[]), type_comment=None),
  # z = x + y
  Assign(targets=[Name(id='z', ctx=Store())],
         # x + y
         value=BinOp(left=Name(id='x', ctx=Load()),
                    op=Add(),
                    right=Name(id='y', ctx=Load())), type_comment=None),
  # print(z)
  Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='z', ctx=Load())], keywords=[])),
  # 省掉了打印 frame 的代码
],
type_ignores=[]
)

Python AST 生成后,能够利用体系函数 compile 把它转成 ByteCode 字节码。解释器履行也存在编译的环节,只不过是编译成字节码。

print("=== ast to bytecode ===")
# 编译成 ByteCode
code_obj = compile(ast_obj, filename="", mode="exec")
print(code_obj)
# 展示 ByteCode 的语法糖
byte_obj = dis.Bytecode(code_obj)
print(byte_obj.dis())

print(code_obj)的成果是 <code object <module> at 0x7ff79bb5c660, file "", line 3>,这儿能够看到生成的 code object 目标的指针是 0x7ff79bb5c660,后边咱们在履行字节码时,会再次看到这个指针。

print(byte_obj.dis()) 的成果如下,每一行对应一条字节码,也即一条指令, 经过字面意义基本能够看出是在做什么:

 # x = 1
  3           0 LOAD_CONST               0 (1)
              2 STORE_NAME               0 (x)
             # x = x * 2
  4           4 LOAD_NAME                0 (x)
              6 LOAD_CONST               1 (2)
              8 BINARY_MULTIPLY
             10 STORE_NAME               0 (x)
             # y = dl_framework.ones((1, 2))
  7          12 LOAD_NAME                1 (dl_framework)
             14 LOAD_METHOD              2 (ones)
             16 LOAD_CONST               2 ((1, 2))
             18 CALL_METHOD              1
             20 STORE_NAME               3 (y)
             # x = x + y
  8          22 LOAD_NAME                0 (x)
             24 LOAD_NAME                3 (y)
             26 BINARY_ADD
             28 STORE_NAME               4 (z)
             # print(z)
  9          30 LOAD_NAME                5 (print)
             32 LOAD_NAME                4 (z)
             34 CALL_FUNCTION            1
             36 POP_TOP
             # 省掉了打印 frame 的代码

得到 ByteCode 之后,就能够传递给 Python VM 履行了。在真实履行前,先做了一下 ByteCode 中指令的打印,实践 Python VM 履行时,也基本是这样遍历每一行指令,然后履行指令。能够幻想,假如这些指令被修正,就能够让 Python VM 履行自定义的指令了。

print("=== execute bytecode ===")
# print instruction
for instr in byte_obj:
    print(instr.opname, instr.opcode)
# You can also do `import torch as dl_framework``
import oneflow as dl_framework
# execute bytecode
exec(code_obj)

字节码的履行成果如下。只需求在真实履行前,把 dl_framework导入就好,然后能够看到 tensor 核算的成果,是契合预期的。

frame(或许叫 stack)是运行时的目标,对应一个函数调用的栈,在履行时被创立。frame 中要履行的指令便是之前创立的 ByteCode。

在运行时之前,像咱们之前看到的,存在一个编译时进行 AST 和 ByteCode 的编译,之前编译时生成的 code object 目标的指针是 0x7ff79bb5c660

在运行时,能够获取当时的 frame,然后经过 frame.f_code拿到当时 frame 里面包括的 ByteCode(即 code object),能够发现它的指针便是之前编译时生成的那个。

# print(z) 的成果
tensor([[3., 3.]], dtype=oneflow.float32)
# 运行时获取当时 frame ,然后打印 frame 中的 ByteCode 目标的成果
# f = sys._getframe()
# print(f.f_code)
<code object <module> at 0x7f5cea7f1660, file "", line 3>

到此,窥见了一下 Python 源码到 AST, AST 到 ByteCode,ByteCode 到 Frame 履行这个默许的 Python 履行流程。TorchDynamo 用下图做了简略的介绍:

TorchDynamo初探:Python ByteCode的动态修改

其间 foo 对应一个 Python 函数,即上文介绍的 Python Source Code。PyCodeObject 是上文介绍的 code object (ByteCode)在 C 代码层面对应的类。PyFrameObject 是上文介绍的 Frame 在 C 代码层面对应的类,它包括了代码段 PyCodeObject。_PyEval_EvalFrameDefault 对应上文介绍的 exec,它履行一个 Frame,即运行 Frame 带有的 PyCodeObject

现在咱们看一下 CPython 在 C 层面的履行 Frame 的完成,对应 _PyEval_EvalFrameDefault(github.com/python/cpyt… )。 它的主逻辑便是取 ByteCode 指令和履行指令(github.com/python/cpyt… ):

co = f->f_code; // 从 PyFrameObject* f 中取出 PyCodeObject* ,放到 co 中
    names = co->co_names;
    consts = co->co_consts;
    fastlocals = f->f_localsplus;
    freevars = f->f_localsplus + co->co_nlocals;
    // 从 co 中取出第一条指令
    first_instr = (_Py_CODEUNIT *) PyBytes_AS_STRING(co->co_code);
    next_instr = first_instr;
#define NEXTOPARG()  do { \
        _Py_CODEUNIT word = *next_instr; \
        opcode = _Py_OPCODE(word); \
        oparg = _Py_OPARG(word); \
        // 指向下一条指令
        next_instr++; \
    } while (0)
    // 循环履行指令
    for (;;) {
        // 从当时的指令 next_instr 中获取 opcode
        NEXTOPARG();
        switch (opcode) {
            // 履行 op code,拜见下个部分
        }       
    }

每个指令类型对应一个 opcode,它是一个数值,履行 opcode(github.com/python/cpyt… ),这儿的 opcode 能够清晰的看到和之前咱们打印的 ByteCode 的类型对应联系:

#define TARGET(opcode) \
    case opcode:
    switch (opcode) {
        // TARGET 便是一个 case
        // load
        TARGET(LOAD_FAST) {
            PyObject *value = GETLOCAL(oparg);
            if (value == NULL) {
                format_exc_check_arg(PyExc_UnboundLocalError,
                                     UNBOUNDLOCAL_ERROR_MSG,
                                     PyTuple_GetItem(co->co_varnames, oparg));
                goto error;
            }
            Py_INCREF(value);
            PUSH(value);
            FAST_DISPATCH();
        }
        // store
        TARGET(STORE_FAST) {
            PyObject *value = POP();
            SETLOCAL(oparg, value);
            FAST_DISPATCH();
        }
        // 二元加法
        TARGET(BINARY_ADD) {
            PyObject *right = POP();
            PyObject *left = TOP();
            PyObject *sum;
            if (PyUnicode_CheckExact(left) &&
                     PyUnicode_CheckExact(right)) {
                sum = unicode_concatenate(left, right, f, next_instr);
                /* unicode_concatenate consumed the ref to left */
            }
            else {
                sum = PyNumber_Add(left, right);
                Py_DECREF(left);
            }
            Py_DECREF(right);
            SET_TOP(sum);
            if (sum == NULL)
                goto error;
            DISPATCH();
        }
        // 函数调用
        TARGET(CALL_FUNCTION) {
            PyObject **sp, *res;
            PCALL(PCALL_ALL);
            sp = stack_pointer;
            res = call_function(&sp, oparg, NULL);
            stack_pointer = sp;
            PUSH(res);
            if (res == NULL) {
                goto error;
            }
            DISPATCH();
        }
    }

以上总结了 Python的默许履行流程。

4

TorchDynamo 的作业流程

TorchDynamo 在规范的 Python 履行流程中做的首要改动便是支撑修正 Frame 履行前的 ByteCode。咱们暂时不重视 AST 生成,看 Python 的履行流程,是 Python Source Code -> ByteCode -> Evaluate. TorchDynamo 支撑 Python Source Code -> ByteCode -> [ByteCode rewrite] -> Evaluate。

ByteCode rewrite 的作业方式是把一段 ByteCode 转成 FX Graph,然后调用用户自定义的 FX Graph 改写履行逻辑,生成一个能够经过编译的履行函数。然后把该段 ByteCode 替换成函数调用 ByteCode,而调用的函数便是经过编译的履行函数。然后完成编译优化的功用。

FX Graph 支撑了在 Python 层做代码改写,提高了写编译 Pass 的便利性,这儿不做深化,能够参阅资料1(
pytorch.org/docs/stable…) 和2(zhuanlan.zhihu.com/p/416165157…

ByteCode rewrite 发生在 ByteCode 履行前。同样的 Source Code,每次履行都会走到这个过程,都能够挑选是否进行 ByteCode rewrite,或许挑选进行什么样的 rewrite,还能够支撑 rewrite 成果的缓存和复用。这表现了 Dynamo 的动态性。

下面看一个 TorchDynamo 下 fn() 函数编译的的比如:

# 一个普通的函数
def fn(a, b):
    x = a + b
    x = x / 2.0
    if x.sum() < 0:
        return x * -1.0
    return x
# torchdynamo 函数接口
with torchdynamo.optimize(custom_compiler):   
   fn(torch.randn(10), torch.randn(10))

fn() 函数对应的原始的 python ByteCode,和代码对应的联系拜见其间的注释:

# x = a + b
 0  LOAD_FAST 0 (a)
 2  LOAD_FAST 1 (b)
 4  BINARY_ADD
 6  STORE_FAST 2 (x)
 # x = x / 2.0
 8  LOAD_FAST 2 (x)
 10 LOAD_CONST 1 (2.0)
 12 BINARY_TRUE_DIVIDE
 14 STORE_FAST 2 (x)
 # if x.sum() < 0:
 16 LOAD_FAST 2 (x)
 18 LOAD_METHOD 0 (sum)
 20 CALL_METHOD 0
 22 LOAD_CONST 2 (0)
 24 COMPARE_OP 0 (<)
 26 POP_JUMP_IF_FALSE 36
 # return x * -1.0
 28 LOAD_FAST 2 (x)
 30 LOAD_CONST 3 (-1.0)
 32 BINARY_MULTIPLY
 34 RETURN_VALUE
 # return x
 36 LOAD_FAST 2 (x)
 38 RETURN_VALUE

经过 TorchDynamo 动态改写后的 ByteCode:

# x = a + b
 # x = x / 2.0
 # x.sum() < 0
 # 上面两行被转换成了 __compiled_fn_0
 # __compiled_fn_0 会回来 x 和 x.sum() < 0 组成的 tuple
 0  LOAD_GLOBAL 1 (__compiled_fn_0)
 2  LOAD_FAST 0 (a)
 4  LOAD_FAST 1 (b)
 6  CALL_FUNCTION 2
 8  UNPACK_SEQUENCE 2
 10 STORE_FAST 2 (x)
 12 POP_JUMP_IF_FALSE 22
 # x * -1.0 被转换成了 __compiled_fn_1 
 14 LOAD_GLOBAL 2 (__compiled_fn_1)
 16 LOAD_FAST 2 (x)
 18 CALL_FUNCTION 1
 20 RETURN_VALUE
 # return x
 22 LOAD_FAST 2 (x)
 24 RETURN_VALUE

能够看到新增了两个函数调用, __compiled_fn_0__compiled_fn_1 ,这两个函数对应的代码逻辑拜见 bytecode 中的注释。这两个函数对应的 fx graph 如下:

__compiled_fn_0:
opcode         name     target                       args              kwargs
-------------  -------  ---------------------------  ----------------  --------
placeholder    a_0      a_0                          ()                {}
placeholder    b_1      b_1                          ()                {}
call_function  add      <built-in function add>      (a_0, b_1)        {}
call_function  truediv  <built-in function truediv>  (add, 2.0)        {}
call_method    sum_1    sum                          (truediv,)        {}
call_function  lt       <built-in function lt>       (sum_1, 0)        {}
output         output   output                       ((truediv, lt),)  {}
__compiled_fn_1:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    x_4     x_4                      ()           {}
call_function  mul     <built-in function mul>  (x_4, -1.0)  {}
output         output  output                   (mul,)       {}

在 ByteCode rewrite 的最终,TorchDynamo 为这一段代码的输入创立两个 Guard:

  • 局部参数 a 必须是一个 Tensor

  • 局部参数 b 必须是一个 Tensor

该 fn 函数被再次调用时,假如契合这两个条件,则能够命中缓存的 TrochDynamo 处理成果;不然下次 fn 履行时,会触发新的 ByteCode 剖析和变换。

别的,关于和 tensor 无关的、比较特别的 python 代码,其 ByteCode 会保持原状。这样就达到了不需求用户标示区域、自动寻觅优化时机的设计目标。

现在看下 TorchDynamo 履行的流程总结:

TorchDynamo初探:Python ByteCode的动态修改

能够看到它把本来的 PyFrameObject 替换成了 Patched PyFrameObject,这个是 CPython 支撑的特性。这个 Patched PyFrameObject 中最首要的改动便是 Frame 中的 ByteCode (即 PyCodeObject)被修正了,本来的 PyCodeObject 变成了 Transformed PyCodeObject。而这个被改写的 PyCodeObject 如上文和上图所示,首要是部分 ByteCode 被替换成了调用被编译过函数。这个被编译过的函数,支撑自定义编译逻辑,当时默许的编译接口是 FX Graph。

这部分基本参阅了Dynamo的官方介绍(dev-discuss.pytorch.org/t/torchdyna… )。

5

TorchDynamo 修正 Python ByteCode 的完成

Python ByteCode 修正首要依赖 PEP 523(peps.python.org/pep-0523/) 供给的履行自定义 Frame Evaluation API。默许的 Eval Frame 逻辑入口函数是 _PyEval_EvalFrame,默许状况,它会直接调用 _PyEval_EvalFrameDefault() 来处理没被修正的 frame,可是假如发现存在一个自定义的 Eval Frame 函数,就会履行自动线的函数。

CPython _PyEval_EvalFrame 函数完成(github.com/python/cpyt… ),所以只要在 ByteCode 履行前,设置一个自定义的 eval frame 函数即可:

static inline PyObject*
_PyEval_EvalFrame(PyThreadState *tstate, struct _PyInterpreterFrame *frame, int throwflag)
{
    EVAL_CALL_STAT_INC(EVAL_CALL_TOTAL);
    if (tstate->interp->eval_frame == NULL) {
        // 这是默许的 eval frame
        return _PyEval_EvalFrameDefault(tstate, frame, throwflag);
    }
    // 假如存在 eval_frame 就会被履行
    return tstate->interp->eval_frame(tstate, frame, throwflag);
}

能够看到 TorchDynamo 正是这么做的。第一步,在 Python 层根据 ContextManger 在进入 Dynamo 作用域时,就触发 eval_frame 的设置,完成(github.com/pytorch/pyt… ):

# torch._dynamo.optimize(...) 对应的 context manager.
class _TorchDynamoContext:
    def __init__(
        self,
        callback: DynamoCallback,
    ):
        super().__init__()
        assert callable(callback) or callback is False or callback is None
        self.callback: DynamoCallback = callback
        self.prior: Union[Unset, DynamoCallback] = unset
    def __enter__(self):
       # 设置 eval_frame,记载之前的 eval frame
        self.prior = set_eval_frame(self.callback)
    def __exit__(self, exc_type, exc_val, exc_tb):
        assert self.prior is not unset
       # 恢复之前的 eval frame
        set_eval_frame(self.prior)

这儿先大致以为设置的 DynamoCallback 对应一个自定义的 eval frame 所需的参数,通常是自定义的 eval frame 中所需的编译逻辑。

看下 set_eval_frame ,C 代码层面的完成(github.com/pytorch/pyt… ),它有点绕但最终走到了这儿(github.com/pytorch/pyt… ),也是设置 tstate->interp->eval_frame ,把 eval_frame 设置成自定义的 custom_eval_frame_shim:

// custom_eval_frame_shim 是自定义的 frame
inline static void enable_eval_frame_shim(PyThreadState* tstate) {
  if (tstate->interp->eval_frame != &custom_eval_frame_shim) {
    // First call
    // 设置自定义的 eval frame
    tstate->interp->eval_frame = &custom_eval_frame_shim;
  }
}

现在回头看一下 PEP 523 供给的 Python JIT 编译器的自定义 frame 履行的样例,它供给了一个比较规范的模版(注意笔者对比如做了微调,原文有剩余和不合理的当地)。在自定义 eval frame 之前,一般还需求自定义一个寄存自定义 ByteCode 的数据结构,能够以为是自定义编译成果,比方样例中自定义编译成果包括3个字段:

  • exec_count, 代表改 frame 被履行的次数;

  • jit_failed, 代表之前 jit 编译是否失利过;

  • jit_code,代表 jit 编译过后的自定义 ByteCode;

据此,来看下自定义 eval frame 的样例:

# 输入原始的 frame
def eval_frame(frame, throw_flag):
    # 获取 frame 中的 code object 中的寄存自定义编译成果的字段
    pyjion_code = frame.code.co_extra
    if not pyjion_code:
        # 不如不存在,就设置一个空的默许值
        frame.code.co_extra = PyjionJittedCode()
    elif not pyjion_code.jit_failed:
        # 假如之前 jit 履行成功
        if pyjion_code.jit_code:
            # 假如存在 jit 生成的 bytecode,就履行它
            return pyjion_code.eval(pyjion_code.jit_code, frame)
        elif pyjion_code.exec_count > 20000:
            # 没有 jit 编译过,且 frame 被履行超过 20000 次,就尝试进行 jit 编译
            # 假如不存在 jit 生成的 bytecode,就 jit 编译生成它
            if jit_compile(frame):
                # 假如 jit 编译成功,就履行 jit 编译的 bytecode
                return pyjion_code.eval(pyjion_code.jit_code, frame)
            else:
                # 假如 jit 编译失利,就记载下,后边不再编译
                pyjion_code.jit_failed = True
    # 添加 frame 履行次数计数
    pyjion_code.exec_count += 1
    # 履行默许的 frame
    return _PyEval_EvalFrameDefault(frame, throw_flag)

下面接着看 TorchDynamo 自定义 evale frame 的完成。在了解具体的自定义 frame 履行逻辑前,有个前置知识是 PyFrameObject 中的 PyCodeObject 为了履行自定义 frame 添加了一个 co_extra 字段,用来让用户放置自定义的数据,一般是寄存自定义编译成果(
peps.python.org/pep-0523/#e…

typedef struct {
   ...
   void *co_extra;  /* 自定义的 frame 需求的自定义数据 */
} PyCodeObject;

TorchDynamo 在自定义编译成果的类型是 CacheEntry,其间最重要的字段是 code,是被编译器修正后的 ByteCode:

typedef struct cache_entry {
  // check the guards: lambda: <locals of user function>: bool
  PyObject* check_fn;
  // modified user bytecode (protected by check_fn's guards)
  PyCodeObject* code;
  // on a cache miss, linked list of next thing to try
  struct cache_entry* next;
} CacheEntry;

现在看下自定义的 eval frame 逻辑 custom_eval_frame_shim(github.com/pytorch/pyt…):

static PyObject* _custom_eval_frame(PyThreadState* tstate, PyFrameObject* frame, int throw_flag, PyObject* callback) {
  // 获取当时 frame 的 PyCodeObject 的 extra 字段用于后边设置
  // 该字段用于放置自定义的编译成果
  CacheEntry* extra = get_extra(frame->f_code);
  // callback 即上文说的自定义编译器
  // 运用 callback 进行 bytecode 的修正,即编译
  // 编译成果写在了 frame->f_code中的 extra 中
  PyObject* result =
      call_callback(callback, (PyObject*)frame, cache_size(extra));
  if (result != Py_None) {
    // 缓存编译成果
    extra = create_cache_entry(extra, result);
    Py_DECREF(result);
    // 履行自定义的 frame
    // eval_custom_code 最终会调用 CPython 接口 _PyEval_EvalFrameDefault 来履行核算
    // 其间 extra->code 中寄存的就自定义编译器生成的 ByteCode
    // 所以最终 _PyEval_EvalFrameDefault 履行的是编译器生成的 ByteCode
    return eval_custom_code(tstate, frame, extra->code, throw_flag);
  }
}
inline static PyObject* eval_custom_code(PyThreadState* tstate, PyFrameObject* frame, PyCodeObject* custom_code, int throw_flag) {
    // 运用 custom_code 创立一个自定义的 frame
    PyFrameObject* shadow_frame = PyFrame_New(tstate, custom_code, frame->f_globals, NULL);
    // 调用 Python 的 frame 履行自定义 frame
    return _PyEval_EvalFrameDefault(tstate, shadow_frame, throw_flag);
}

到这儿,已经清楚了修正 Python ByteCode 履行的主线逻辑。

6

小结

这儿对 Python 的履行和 TorchDynamo 的首要原理做了初探,首要是自定义 Eval Frame 的完成技巧。其它相关的 Python ByteCode 规范,ByteCode 到 FX Graph 的转换,ByteCode 的改写等内容还没涉及。

参阅资料

  • tenthousandmeters.com/b (tenthousandmeters.com/blog/python…)

  • peps.python.org/pep-052 (peps.python.org/pep-0523/)

  • dev-discuss.pytorch.org (dev-discuss.pytorch.org/t/torchdyna…)

(原文:zhuanlan.zhihu.com/p/589115427…

欢迎 Star、试用 OneFlow 最新版别:
github.com/Oneflow-Inc…