TorchScript Internals (Part 1) - Saving a Model | PyTorch

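Last updated: September 11, 2021

Note: this article builds on 《PyTorch 源码解读之即时编译篇》 (PyTorch Source Code Walkthrough: JIT), with some corrections and additions.

The code discussed here comes from PyTorch 1.9.0 (commit_id = d69c22d). Formatting has been adjusted slightly for readability. TorchScript can script several kinds of objects; to keep the underlying principles easy to follow, this article walks through the simplest case, a scripted function.

The example used throughout this article is: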

import torch

def my_func(a: int) -> int:
    a = a + 2
    return a

scripted_func: torch.jit.ScriptFunction = torch.jit.script(my_func)
scripted_func.save("my_func.pt")

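The Python side

All C++ code discussed in this article lives in the torch::jit namespace, which is omitted below; it is spelled out only in the few places that could be confused with Python.

The overall call flow is:

  • script
    • get_jit_def: obtain the torch abstract syntax tree (AST)
      • ast.parse: obtain the Python AST
      • build_def: convert the Python AST into the torch AST
        • build_stmts: build the torch AST
    • torch._C._jit_script_compile: call into C++ to turn the torch AST into the intermediate representation (IR), i.e. a torch::jit::CompilationUnit object
    • save: serialize the torch::jit::CompilationUnit object and write it to disk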

script

# Code unrelated to scripting a function has been removed
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    # Return the object unchanged if it has already been scripted
    if isinstance(obj, ScriptFunction):
        return obj

    # this is a decorated fn, and we need to get the underlying fn and its rcb
    if hasattr(obj, "__script_if_tracing_wrapper"):
        obj = obj.__original_fn
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # Reject overloaded functions, because a function is looked up by its name
    _check_directly_compile_overloaded(obj)

    # Return the cached result if this function was compiled before
    maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
    if maybe_already_compiled_fn:
        return maybe_already_compiled_fn

    # Build the abstract syntax tree
    ast = get_jit_def(obj, obj.__name__)
    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # Call into C++ to produce IR from the AST
    # (qualified_name is computed in a part of the function elided from this excerpt)
    fn = torch._C._jit_script_compile(
        qualified_name, ast, _rcb, get_default_args(obj)
    )
    # Forward docstrings
    fn.__doc__ = obj.__doc__

    # Cache the compiled result
    _set_jit_function_cache(obj, fn)
    return fn

get_jit_def

def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is "_forward",
            but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
    """
    
    # dedent_src is a string containing the source of the function to be scripted
    sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
    sourcelines = normalize_source_lines(sourcelines)
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    
    # Parse the string into a Python AST with Python's ast package
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
    
    # type_line holds the `# type: ...` comment line if the function is annotated with a type comment, otherwise None
    type_line = torch.jit.annotations.get_type_line(source)
    
    # SourceContext is a wrapper around SourceRangeFactory and carries all the metadata needed for compilation
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
    
    # A single standalone function is being compiled, so take the first element
    fn_def = py_ast.body[0]
		
    # If MonkeyType is in use, the type data it recorded can be reused (MonkeyType traces types at runtime and automatically adds type annotations to Python 3 code)
    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace:
        qualname = get_qualified_name(fn)
        pdt_arg_types = type_trace_db.get_args_types(qualname)
		
    # build_def converts the Python AST into the AST format used by PyTorch
    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
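
The dedent-and-parse step of get_jit_def can be tried in isolation with the standard library alone. The following is a minimal sketch, not the PyTorch implementation: the source is given as a literal string (instead of being fetched with get_source_lines_and_file), and it is indented on purpose to mimic a function defined inside a class or another function.

```python
import ast
from textwrap import dedent

# Source text as it would look for a nested definition: every line is indented.
source = (
    "    def my_func(a: int) -> int:\n"
    "        a = a + 2\n"
    "        return a\n"
)

# ast.parse rejects the indented text, so the common indentation is removed first.
dedent_src = dedent(source)
py_ast = ast.parse(dedent_src)

# Same sanity check as in get_jit_def: exactly one top-level function.
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
    raise RuntimeError("Expected a single top-level function")

# The amount of stripped indentation is kept so that column offsets in the
# dedented AST can be mapped back to the original file.
leading_whitespace_len = len(source.split("\n", 1)[0]) - len(dedent_src.split("\n", 1)[0])
fn_def = py_ast.body[0]
print(fn_def.name, leading_whitespace_len)
```

This is why get_jit_def passes leading_whitespace_len to SourceContext: the AST positions refer to the dedented text, while error messages should point into the original source.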

py_ast

To inspect the contents of py_ast, it can be printed with the following line.

import yaml, jsonpickle; print(yaml.dump(yaml.safe_load(jsonpickle.encode(py_ast)), indent=2))

The result is:

- py/type: _ast.Module
- body:
  - py/reduce:
    - py/type: _ast.FunctionDef
    - args:
        py/reduce:
        - py/type: _ast.arguments
        - args:
          - py/reduce:
            - py/type: _ast.arg
            - annotation:
                py/reduce:
                - py/type: _ast.Name
                - col_offset: 15
                  ctx:
                    py/reduce:
                    - py/type: _ast.Load
                    - {}
                  id: int
                  lineno: 1
              arg: a
              col_offset: 12
              lineno: 1
          defaults: []
          kw_defaults: []
          kwarg: null
          kwonlyargs: []
          vararg: null
      body:
      - py/reduce:
        - py/type: _ast.Assign
        - col_offset: 4
          lineno: 2
          targets:
          - py/reduce:
            - py/type: _ast.Name
            - col_offset: 4
              ctx:
                py/reduce:
                - py/type: _ast.Store
                - {}
              id: a
              lineno: 2
          value:
            py/reduce:
            - py/type: _ast.BinOp
            - col_offset: 8
              left:
                py/reduce:
                - py/type: _ast.Name
                - col_offset: 8
                  ctx:
                    py/id: 12
                  id: a
                  lineno: 2
              lineno: 2
              op:
                py/reduce:
                - py/type: _ast.Add
                - {}
              right:
                py/reduce:
                - py/type: _ast.Num
                - col_offset: 12
                  lineno: 2
                  n: 2
      - py/reduce:
        - py/type: _ast.Return
        - col_offset: 4
          lineno: 3
          value:
            py/reduce:
            - py/type: _ast.Name
            - col_offset: 11
              ctx:
                py/id: 12
              id: a
              lineno: 3
      col_offset: 0
      decorator_list: []
      lineno: 1
      name: my_func
      returns:
        py/reduce:
        - py/type: _ast.Name
        - col_offset: 23
          ctx:
            py/id: 12
          id: int
          lineno: 1

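Simplified, the structure is:

  • py_ast.body[0]
    • _ast.Assign
      • value: _ast.BinOp
        • op: _ast.Add
        • left: _ast.Name
          • id: a
        • right: _ast.Num
          • n: 2
    • _ast.Return
      • value: _ast.Name
        • id: a

py_ast.body is a list whose length equals the number of top-level statements in the parsed string (here, a single function definition). The body member of its first element has two elements, of types _ast.Assign and _ast.Return. The value of the former is a BinOp with three members: op = Add, left = Name, right = Num. This BinOp is the parsed right-hand side of a = a + 2.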
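
The same structure can be checked without jsonpickle or yaml by walking the tree with the standard ast module. This is a small sketch for illustration. One caveat: on Python 3.8 and later the literal 2 is parsed as ast.Constant (with a value attribute) rather than the _ast.Num node shown in the dump above, so the sketch reads value instead of n.

```python
import ast

source = "def my_func(a: int) -> int:\n    a = a + 2\n    return a\n"

module = ast.parse(source)
fn_def = module.body[0]      # the only top-level statement: the FunctionDef
assign, ret = fn_def.body    # the two statements inside the function

print(type(assign).__name__, type(ret).__name__)

binop = assign.value         # right-hand side of `a = a + 2`
print(type(binop.op).__name__, binop.left.id, binop.right.value)

# ast.dump gives a compact one-line view of any subtree.
print(ast.dump(assign))
```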

build_def

def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
    body = py_def.body

    # The code below builds metadata from the input arguments
    r = ctx.make_range(py_def.lineno + len(py_def.decorator_list),
                       py_def.col_offset,
                       py_def.col_offset + len("def"))

    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
    return_type = None
    if getattr(py_def, 'returns', None) is not None:
        return_type = build_expr(ctx, py_def.returns)

    decl = Decl(r, param_list, return_type)
    is_method = self_name is not None
    if type_line is not None:
        type_comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
		
    # build_stmts uses the ctx and body passed in: ctx carries the source-code metadata, and body is the abstract syntax tree
    return Def(Ident(r, def_name),
               decl,
               build_stmts(ctx, body))

build_stmts

# All of the following names are imported from C++
from torch._C._jit_tree_views import (
    ClassDef, Ident, Stmt, Decl, Def, Var,
    EmptyTypeAnnotation, Param, ExprStmt, Assign,
    Delete, Return, Raise, Assert, AugAssign, While,
    For, If, Pass, Break, Continue, Apply, Dots, Select,
    TrueLiteral, FalseLiteral, NoneLiteral, Starred,
    ListLiteral, TupleLiteral, DictLiteral, Const,
    StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
    SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
    DictComp,
)
# Recurse over the syntax tree; stmt is short for statement
def build_stmts(ctx, stmts):
    stmts = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, stmts))
# build_stmt calls StmtBuilder.__call__
build_stmt = StmtBuilder()
build_expr = ExprBuilder()


class Builder(object):
    # StmtBuilder and ExprBuilder inherit from Builder without overriding __call__, so Builder.__call__ is what ultimately runs
    def __call__(self, ctx, node):
        # Assemble the method name from the type of the syntax-tree node and dispatch to it
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
            raise UnsupportedNodeError(ctx, node)
        return method(ctx, node)


# Only the build functions relevant to the example are kept here
class StmtBuilder(Builder):
    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        if value.__class__.__name__ == 'Str':
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        rhs = build_expr(ctx, stmt.value)  # stmt.value = _ast.BinOp
        lhs = [build_expr(ctx, x) for x in stmt.targets]  # x = _ast.Name
        return Assign(lhs, rhs)

    @staticmethod
    def build_Return(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return"))
        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))


class ExprBuilder(Builder):
    binop_map = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        ast.Pow: '**',
        ast.Mod: '%',
        ast.FloorDiv: '//',
        ast.BitAnd: '&',
        ast.BitXor: '^',
        ast.BitOr: '|',
        ast.LShift: '<<',
        ast.RShift: '>>',
    }

    binop_map[ast.MatMult] = '@'

    @staticmethod
    def build_Name(ctx, expr):
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
        if expr.id.startswith(_reserved_prefix):
            raise NotSupportedError(r, "names of variables used in JIT-ed functions "
                                       "can't start with " + _reserved_prefix)
        if expr.id == "True":
            return TrueLiteral(r)
        elif expr.id == "False":
            return FalseLiteral(r)
        elif expr.id == "None":
            return NoneLiteral(r)
        elif expr.id == "Ellipsis":
            return Dots(r)
        return Var(Ident(r, expr.id))

    @staticmethod
    def build_Num(ctx, expr):
        value = str(expr.n)
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
        return Const(r, value)

    @staticmethod
    def build_BinOp(ctx, expr):
        lhs = build_expr(ctx, expr.left)  # _ast.Name
        rhs = build_expr(ctx, expr.right) # _ast.Num
        op = type(expr.op)                # _ast.Add

        if op == ast.Div and not ctx.uses_true_division:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise FrontendError(err_range, 'Division of ints in TorchScript uses Python 3 true '
                                'division semantics. Please put `from __future__ '
                                'import division` at the top of your file')

        # Map the operator type to its agreed string token
        op_token = ExprBuilder.binop_map.get(op)  # op_token = "+"
        if op_token is None:
            err_range = ctx.make_raw_range(lhs.range().end, rhs.range().start)
            raise NotSupportedError(err_range, "unsupported binary operator: " + op.__name__)
        return BinOp(op_token, lhs, rhs)

The final PyTorch AST can be printed directly:

(def
  (ident my_func)
  (decl
    (list
      (param
        (ident a)
        (option (variable (ident int)))
        (option)
        (False)))
    (option (variable (ident int))))
  (list
    (assign
      (list (variable (ident a)))
      (option
        (+
          (variable (ident a))
          (const 2)))
      (option))
    (return (variable (ident a)))))
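
The name-based dispatch in Builder.__call__ is easy to reproduce outside PyTorch. The sketch below is a toy stand-in, not PyTorch code: SexprBuilder is a hypothetical class that emits plain strings loosely shaped like the printout above, whereas the real builders produce C++ tree-view objects and also record source ranges, parameters, and type annotations.

```python
import ast


class SexprBuilder:
    """Toy builder: picks a build_<NodeClass> method from the node's class name."""

    binop_map = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*"}

    def __call__(self, node):
        method = getattr(self, "build_" + node.__class__.__name__, None)
        if method is None:
            raise NotImplementedError("unsupported node: " + node.__class__.__name__)
        return method(node)

    def build_FunctionDef(self, node):
        body = " ".join(self(stmt) for stmt in node.body)
        return f"(def (ident {node.name}) (list {body}))"

    def build_Assign(self, node):
        targets = " ".join(self(target) for target in node.targets)
        return f"(assign (list {targets}) {self(node.value)})"

    def build_Return(self, node):
        # Only `return <expr>` is handled in this sketch.
        return f"(return {self(node.value)})"

    def build_Name(self, node):
        return f"(variable (ident {node.id}))"

    def build_Constant(self, node):
        # Python 3.8+ parses numeric literals as ast.Constant.
        return f"(const {node.value})"

    def build_BinOp(self, node):
        op_token = self.binop_map[type(node.op)]
        return f"({op_token} {self(node.left)} {self(node.right)})"


source = "def my_func(a: int) -> int:\n    a = a + 2\n    return a\n"
sexpr = SexprBuilder()(ast.parse(source).body[0])
print(sexpr)
```

Adding support for a new node type only requires adding a build_<NodeClass> method; nodes without one fail loudly, which mirrors how the real Builder raises UnsupportedNodeError.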

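The C++ side

Key classes

This section introduces several classes that appear frequently in the JIT, ordered from low-level to high-level.

Graph

A Graph is the root of the IR that defines the implementation of a TorchScript function. For readers familiar with LLVM, it is analogous to the llvm::Function class. A Graph is mainly composed of Nodes, Blocks, and Values. A Node is an instruction (for example, a matrix multiplication). A Block is an ordered sequence of Nodes that execute one after another. Each Node takes a list of Values as input and produces a list of Values as output.

Taking the Python code analyzed earlier as an example, we can print its Graph: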

def my_func(a: int) -> int:
    a = a + 2
    return a

scripted_func: torch.jit.ScriptFunction = torch.jit.script(my_func)
print(scripted_func.graph)

"""
graph(%a.1 : int):
  %2 : int = prim::Constant[value=2]()
  %a.5 : int = aten::add(%a.1, %2)
  return (%a.5)
"""

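The printed text is the standard textual form of PyTorch IR, and the elements mentioned above are easy to spot in it:

  • graph is the Graph
  • %a.5 : int is a Value with a type annotation
  • %a.5 : int = aten::add(%a.1, %2) is a Node representing the aten::add operator; it takes %a.1 and %2 as inputs and produces %a.5 as its output

A Graph is in static single assignment (SSA) form: every variable is assigned exactly once, and each assignment defines a new variable.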
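
To make the SSA property concrete, here is a toy renaming pass over straight-line Python assignments. It is only a sketch of the idea: to_ssa is a hypothetical helper, it handles nothing but `x = y <op> z` statements, and its numbering scheme (a counter per variable) is an assumption that does not match the names TorchScript actually generates (such as %a.1 and %a.5).

```python
import ast


def to_ssa(source, params):
    """Rewrite straight-line `x = y <op> z` statements so each name is assigned once."""
    version = {p: 0 for p in params}  # current version of every known variable

    def use(node):
        # Reads refer to the latest version of a variable; constants are printed as-is.
        if isinstance(node, ast.Name):
            return f"%{node.id}.{version[node.id]}"
        return repr(node.value)

    lines = []
    for stmt in ast.parse(source).body:
        target = stmt.targets[0].id
        lhs, rhs = use(stmt.value.left), use(stmt.value.right)
        op = type(stmt.value.op).__name__.lower()
        # Every assignment defines a brand-new name instead of overwriting the old one.
        version[target] = version.get(target, -1) + 1
        lines.append(f"%{target}.{version[target]} = {op}({lhs}, {rhs})")
    return lines


ssa = to_ssa("a = a + 2\na = a * 3\n", params=["a"])
for line in ssa:
    print(line)
```

The Python variable a is reassigned twice, but in the SSA output each assignment introduces a fresh name, and later uses refer to the most recent one. The same effect is visible in the real graph, where the input is %a.1 and the result of the addition is %a.5.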

Function

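There are two kinds of Function: BuiltinOpFunction and GraphFunction. The former is a wrapper around a native C++ function. The latter has the following main members:

  • FunctionSchema: describes the types and names of the input arguments and return values

  • Graph: the IR that defines the function's implementation. Covered in detail in the next section

  • GraphExecutor: actually executes the Graph that defines the function. Covered in detail in the next section

Method

A Method can be thought of simply as a wrapper around a Function. Methods make up all of a Module's executable methods.

CompilationUnit

A CompilationUnit is essentially a collection of Functions. It has several define functions that convert an AST into IR and store the result in its Function members.

StrongFunctionPtr

The relationship between StrongFunctionPtr, CompilationUnit, and Function is exactly what the following code describes.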

struct StrongFunctionPtr {
  StrongFunctionPtr(std::shared_ptr<CompilationUnit> cu, Function* function)
      : cu_(std::move(cu)), function_(function) {
    TORCH_INTERNAL_ASSERT(cu_);
    TORCH_INTERNAL_ASSERT(function_);
  }
  std::shared_ptr<CompilationUnit> cu_;
  Function* function_;
};

Module

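Module is the intermediary between TorchScript and LibTorch. Serialization and deserialization both operate on Modules. Module inherits from Object and holds a member of type ObjectPtr; ObjectPtr in turn holds a member type_ of type StrongTypePtr and a member slots_ of type std::vector<IValue>. The former stores all of the module's Methods, and the latter stores all of the module's data members. The rough structure is: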

  • Module : Object
    • ObjectPtr _ivalue_
      • StrongTypePtr type_
      • std::vector<IValue> slots_

Steps

The overall call flow is:

  • script_compile_function
    • get_python_cu
    • CompilationUnit::define

script_compile_function

// Expose script_compile_function to Python as torch._C._jit_script_compile
m.def(
  "_jit_script_compile",
  [](const std::string& qualname,
     const Def& def,
     const ResolutionCallback& rcb,
     const FunctionDefaults& defaults) {
    C10_LOG_API_USAGE_ONCE("torch.script.compile");
    const auto name = c10::QualifiedName(qualname);
    TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
    return script_compile_function(name, def, defaults, rcb);
  });
// StrongFunctionPtr corresponds to torch.jit.ScriptFunction on the Python side
static StrongFunctionPtr script_compile_function(
    const c10::QualifiedName& name,
    const Def& def,
    const FunctionDefaults& defaults,
    const ResolutionCallback& rcb) {
  auto cu = get_python_cu();
  // As explained below, each Function returned by define carries a creator function that converts its AST to IR
  auto defined_functions = cu->define(
      QualifiedName(name.prefix()),
      /*properties=*/{},
      /*propResolvers=*/{},
      {def},
      {pythonResolver(rcb)},
      nullptr,
      true);
  TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
  auto& defined = defined_functions[0];
  defined->setSchema(getSchemaWithNameAndDefaults(
      def.range(), defined->getSchema(), def.name().name(), defaults));
  // ret bundles a CompilationUnit member with a Function member; the Function carries the creator function
  StrongFunctionPtr ret(std::move(cu), defined);
  didFinishEmitFunction(ret);
  return ret;
}

get_python_cu

// Fetch the CompilationUnit object used by the Python frontend
inline std::shared_ptr<CompilationUnit> get_python_cu() {
  return py::module::import("torch.jit._state")
      .attr("_python_cu")
      .cast<std::shared_ptr<CompilationUnit>>();
}

CompilationUnit::define

// for historic reasons, these are defined in ir_emitter.cpp
// Returns the list of Functions just defined.
std::vector<Function*> CompilationUnit::define(
  const c10::optional<c10::QualifiedName>& prefix,
  const std::vector<Property>& properties,
  const std::vector<ResolverPtr>& propResolvers,
  const std::vector<Def>& definitions,  // the torch ASTs
  const std::vector<ResolverPtr>& defResolvers,  // determines how we handle free variables in each definition
  const Self* self,  // if non-null, the first argument to each def, is bound to this value
  bool shouldMangle = false  /* see [name mangling] */) {
  TORCH_INTERNAL_ASSERT(definitions.size() == defResolvers.size());
  TORCH_INTERNAL_ASSERT(properties.size() == propResolvers.size());
  std::vector<Function*> functions;
  std::unordered_map<std::string, Function*> function_table;

  // Records fn in function_table, functions and with register_function.
  // This is done several times below, so this lambda helps avoid repeating
  // code.
  auto record_function = [&](std::unique_ptr<Function> fn) {
    function_table[fn->name()] = fn.get();
    functions.emplace_back(fn.get());
    this->register_function(std::move(fn));
  };

  for (size_t i = 0; i < properties.size(); i++) {
    PropertyPair property_fns = define_property(
        prefix,
        properties[i],
        propResolvers[i],
        self,
        function_table,
        shouldMangle);

    auto& getter_fn = property_fns.getGetter();
    auto& setter_fn = property_fns.getSetter();

    record_function(std::move(getter_fn));

    if (setter_fn) {
      record_function(std::move(setter_fn));
    }
  }

  // The torch ASTs are passed to another overload of define
  for (size_t i = 0; i < definitions.size(); i++) {
    auto fn = define(
        prefix,
        definitions[i],
        defResolvers[i],
        self,
        function_table,
        shouldMangle,
        CompilationUnit::FunctionType::Method);

    // As the code below shows, fn carries a creator function that converts the AST to IR
    record_function(std::move(fn));
  }

  // We need to compile `__init__` first, since it can determine what attributes
  // are available to other methods. So reorder the definitions accordingly.
  for (auto& kv : function_table) {
    if (kv.first == "__init__") {
      kv.second->ensure_defined();
    }
  }

  for (Function* function : functions) {
    // Invoke the creator stored in the function to build IR from the AST and store it in the function
    function->ensure_defined();
  }

  return functions;
}

void GraphFunction::ensure_defined() {
  if (function_creator_) {
    auto creator = function_creator_;
    function_creator_ = placeholderCreator;
    creator(*this);
    function_creator_ = nullptr;
  }
  check_single_output();
}
std::unique_ptr<Function> CompilationUnit::define(
    const c10::optional<QualifiedName>& prefix,
    const Def& def,
    const ResolverPtr& resolver,
    const Self* self,
    const std::unordered_map<std::string, Function*>& function_table,
    bool shouldMangle,
    CompilationUnit::FunctionType type) const {
  TORCH_INTERNAL_ASSERT(resolver);
  auto _resolver = resolver;
  
  // script_compile_function passes nullptr here
  if (!self) {
    // if self is defined, then these are methods and do not go into the
    // global namespace otherwise, they get defined together so we add them to
    // the function table so the methods can see each other
    _resolver =
      std::make_shared<FunctionResolver>(resolver.get(), function_table);
  }
	
  // creator converts def (which holds the torch AST) into IR via FunctionResolver and stores it in method
  auto creator = [def, _resolver, self](Function& method) {
    // Store the function name so that it can be referenced if there is an error
    // while compiling this function
    std::string call_name = method.qualname().name();
    if (self) {
      auto atoms = method.qualname().atoms();
      // There should be at least a ClassName.method_name
      TORCH_INTERNAL_ASSERT(atoms.size() >= 2);
      call_name = atoms.at(atoms.size() - 2) + "." + atoms.at(atoms.size() - 1);
    }
    ErrorReport::CallStack call(call_name, def.range());
    to_ir(def, _resolver, self, method);
  };

  auto name = prefix ? QualifiedName(*prefix, def.name().name())
    : QualifiedName(def.name().name());
  if (shouldMangle) {
    // If `shouldMangle` is set, we should generate a unique name for this
    // function if there is already an existing one.
    if (auto fn = find_function(name)) {
      name = mangle(name);
    }
  }
	
  // Store creator in fn
  auto fn = torch::make_unique<GraphFunction>(
    std::move(name), std::make_shared<Graph>(), creator);
  if (self) {
    // Register this as a method on `self`'s type
    if (type == CompilationUnit::FunctionType::Hook) {
      self->getClassType()->addForwardHook(fn.get());
    } else if (type == CompilationUnit::FunctionType::PreHook) {
      self->getClassType()->addForwardPreHook(fn.get());
    } else {
      self->getClassType()->addMethod(fn.get());
    }
  }

  return fn;
}

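Inside define, the creator stored in each function is invoked (through ensure_defined); it builds IR from the AST and stores it in the function. The creator does this by calling to_ir.

to_ir

The script-mode pipeline can be split roughly into two steps: first an AST is generated (represented as a tree object), then the IR emitter performs semantic analysis on that tree and lowers it into a Module. Most of the second step is done by the constructor of the struct to_ir. In emitStatements, each statement is emitted into the Graph according to its kind.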
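
The deferred-creator pattern used by define and ensure_defined can be sketched in a few lines of Python. This is an illustration only: LazyFunction and _placeholder_creator are hypothetical names, the "graph" is just a list of strings, and the assumption that the placeholder exists to catch a function being compiled while it is already compiling is a reading of the code above, not something the excerpt states.

```python
def _placeholder_creator(fn):
    # Reached only if ensure_defined is re-entered while the real creator is running.
    raise RuntimeError(f"recursive definition of {fn.name}")


class LazyFunction:
    """Holds a creator callback and runs it at most once, on demand."""

    def __init__(self, name, creator):
        self.name = name
        self.graph = []          # stand-in for the IR
        self._creator = creator

    def ensure_defined(self):
        if self._creator is not None:
            creator = self._creator
            self._creator = _placeholder_creator  # guard against re-entry
            creator(self)                         # builds the IR into self.graph
            self._creator = None                  # later calls become no-ops


calls = []


def creator(fn):
    calls.append(fn.name)
    fn.graph.append("%2 : int = prim::Constant[value=2]()")


fn = LazyFunction("my_func", creator)
fn.ensure_defined()
fn.ensure_defined()  # second call does nothing
print(calls, fn.graph)
```

Deferring the work this way lets all functions in a definition batch be registered first, so that they can refer to each other by name before any of them is actually compiled.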

struct to_ir {
  to_ir(
      const Def& def,  // AST
      ResolverPtr resolver_,
      const Self* self,
      Function& method) // method being constructed
      : method(method),
        graph(method.graph()),
        resolver(std::move(resolver_)),
        typeParser_(resolver),
        environment_stack(nullptr) {
    ......

    method.setSchema(emitDef(def, self, graph->block()));

    ......
  }
}

FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
  ......

  emitStatements(stmts_list.begin(), stmts_list.end());

  ......
}
void emitStatements(
    List<Stmt>::const_iterator begin,
    List<Stmt>::const_iterator end) {
  for (; begin != end; ++begin) {
    auto stmt = *begin;
    ErrorReport::CallStack::update_pending_range(stmt.range());
    switch (stmt.kind()) {
      case TK_IF:
        emitIf(If(stmt));
        break;
      case TK_WHILE:
        emitWhile(While(stmt));
        break;
      case TK_FOR:
        emitFor(For(stmt));
        break;
      case TK_ASSIGN:
        emitAssignment(Assign(stmt));
        break;
      case TK_AUG_ASSIGN:
        emitAugAssignment(AugAssign(stmt));
        break;
      case TK_EXPR_STMT: {
        auto expr = ExprStmt(stmt).expr();
        emitSugaredExpr(expr, 0);
      } break;
      case TK_RAISE:
        emitRaise(Raise(stmt));
        break;
      case TK_ASSERT:
        emitAssert(Assert(stmt));
        break;
      case TK_RETURN: {
        emitReturn(Return(stmt));
      } break;
      case TK_CONTINUE: {
        emitContinue(Continue(stmt));
      } break;
      case TK_BREAK: {
        emitBreak(Break(stmt));
      } break;
      case TK_PASS:
        // Emit nothing for pass
        break;
      case TK_DEF:
        emitClosure(Def(stmt));
        break;
      case TK_DELETE:
        emitDelete(Delete(stmt));
        break;
      case TK_WITH:
        emitWith(With(stmt));
        break;
      default:
        throw ErrorReport(stmt)
            << "Unrecognized statement kind " << kindToString(stmt.kind());
    }
    // Found an exit statement in this block. The remaining statements aren't
    // reachable so we don't emit them.
    if (exit_blocks.count(environment_stack->block()))
      return;
  }
}

The emitX functions call graph->insert to insert the various Nodes into the Graph. Their details are fairly involved and are not covered here.
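
The control flow of emitStatements (dispatch on the statement kind, emit nothing for pass, stop once an exit statement has been emitted) can be mimicked on Python's own AST. The sketch below is a loose analogy rather than a port: emit_statements is a hypothetical function, it records kind names instead of inserting Nodes, and it treats return as the only exit statement.

```python
import ast


def emit_statements(stmts):
    """Walk statements in order, dispatching on their kind, and stop after a return."""
    emitted = []
    for stmt in stmts:
        if isinstance(stmt, ast.Pass):
            continue  # emit nothing for pass
        elif isinstance(stmt, ast.Assign):
            emitted.append("assign")
        elif isinstance(stmt, ast.Return):
            emitted.append("return")
            break  # remaining statements are unreachable, so they are not emitted
        else:
            raise NotImplementedError("Unrecognized statement kind " + type(stmt).__name__)
    return emitted


source = "def f(x):\n    pass\n    y = x\n    return y\n    z = 1\n"
emitted = emit_statements(ast.parse(source).body[0].body)
print(emitted)
```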

Module::save

// Register a save method on torch.jit.ScriptFunction that delegates to Module::save
py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
    .def(
        "save",
        [](const StrongFunctionPtr& self,
            const std::string& filename,
            const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
          Module module("__torch__.PlaceholderModule");
          // [issue 27343]
          // Modules have 'training' attributes by default, but due to
          // https://github.com/pytorch/pytorch/issues/27343, functions end
          // up having a training attribute when they are loaded. This adds
          // a fake 'training' attribute that shouldn't be used, but prevents
          // jitter on saving and loading. Once that issue is fixed this can
          // be deleted.
          module.register_attribute("training", BoolType::get(), true);
          // Add the StrongFunctionPtr to the Module object
          addFunctionToModule(module, self);
          module.save(filename, _extra_files);
        },
        py::arg("filename"),
        py::arg("_extra_files") = ExtraFilesMap());
void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
  // Make a graph with a fake self argument
  auto graph = func.function_->graph()->copy();
  auto v = graph->insertInput(0, "self");
  v->setType(module._ivalue()->type());
  const auto name = QualifiedName(*module.type()->name(), "forward");
  // Turn the StrongFunctionPtr into a method: a function named forward created from the copied graph
  auto method = module._ivalue()->compilation_unit()->create_function(name, graph);
  // Add that method to the Module object
  module.type()->addMethod(method);
}

void ClassType::addMethod(torch::jit::Function* method) {
  TORCH_CHECK(
      findMethod(method->name()) == nullptr,
      "Can't redefine method: ",
      method->name(),
      " on class: ",
      repr_str());
  methods_.push_back(method);
}
void Module::save(const std::string& filename, const ExtraFilesMap& extra_files) const {
  ExportModule(*this, filename, extra_files, false /* bytecode_format */);
}

void ExportModule(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& extra_files,
    bool bytecode_format,
    bool save_mobile_debug_info) {
  caffe2::serialize::PyTorchStreamWriter writer(filename);
  ScriptModuleSerializer serializer(writer);
  serializer.serialize(module, extra_files, bytecode_format, save_mobile_debug_info);
}

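As the code shows, saving first converts the StrongFunctionPtr into a Module object, which is then serialized by ScriptModuleSerializer. ScriptModuleSerializer and ScriptModuleDeserializer are the classes PyTorch uses for serialization and deserialization; for details on how a model is serialized, see TorchScript serialization.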
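
The file written by PyTorchStreamWriter is a ZIP archive, so the result of the save call can be inspected with the standard library. The helper below is a minimal sketch; list_archive_entries is a hypothetical name, and the exact entries inside the archive depend on the PyTorch version, so none are assumed here.

```python
import zipfile


def list_archive_entries(path):
    """Return the names of all entries stored in a ZIP archive such as a saved .pt file."""
    with zipfile.ZipFile(path) as archive:
        return archive.namelist()
```

After running the example at the top of the article, calling list_archive_entries("my_func.pt") would list what ScriptModuleSerializer wrote for the placeholder module that wraps my_func.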

References

PyTorch 源码解读之即时编译篇 (PyTorch Source Code Walkthrough: JIT)

TorchScript 如何实现Python -> C++ 代码转换 (How TorchScript Converts Python Code to C++)

PyTorch JIT Source Code Read Note

pytorch/OVERVIEW.md