TorchScript 原理篇(下) - 运行模型 | PyTorch

本文最后更新于:2021年9月7日

上一篇描述了如何生成 IR,本篇主要描述如何运行 IR。

下面是本篇分析使用的例子:

// load module
torch::jit::script::Module module = torch::jit::load("my_func.pt");

// run module
std::vector<torch::jit::IValue> inputs;
inputs.push_back(1);
torch::jit::IValue output = module.get_method("forward")(inputs);

// get output
int ret = output.toInt();

torch::jit::load

torch.jit.loadtorch::jit::load 底层都是通过 ScriptModuleDeserializer 类进行反序列化,这个函数会将序列化文件反序列化成 torch::jit::Module 对象。

Module::get_method

Module::get_method 会返回一个 Method 对象,其 operator() 重载如下:

IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
  stack.insert(stack.begin(), owner()._ivalue()); // self
  RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
  // 调用 function_ 的 operator()。每一个 Method 对象都有一个 Function 类型的成员变量 function_,本质是调用这个成员变量
  return (*function_)(std::move(stack), kwargs);
}

这里的 function_ 其实就是前面保存模型时涉及到的 GraphFunction 类型变量。

GraphExecutor

GraphFunction::operator()

IValue GraphFunction::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
  getSchema().checkAndNormalizeInputs(stack, kwargs);
  run(stack);
  return stack.front();
}

void GraphFunction::run(Stack& stack) {
  get_executor().run(stack);
}

可以看出,最后是调用了 GraphExecutor::run 函数。

GraphExecutor::run

void GraphExecutor::run(Stack& inputs) {
  return pImpl->run(inputs);
}

void GraphExecutorImplBase::run(Stack& stack) {
  ......
	
  // 获取 plan
  const ExecutionPlan& plan = getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts());
  // 运行 plan
  InterpreterState(plan.code).run(stack);
  last_executed_optimized_graph = plan.graph;
}
const ExecutionPlan& GraphExecutorImpl::getPlanFor(Stack& stack, size_t remaining_bailout_depth)
    override {
  return getGraphExecutorOptimize() ? getOrCompile(stack)
                                    : getOrCompileFallback();
}
const ExecutionPlan& GraphExecutorImpl::getOrCompile(const Stack& stack) {
  ......

  auto plan = compileSpec(spec);  
  
  ......
}

GraphExecutorImpl::compileSpec

ExecutionPlan GraphExecutorImpl::compileSpec(const ArgumentSpec& spec) {
  auto opt_graph = graph->copy();
  GRAPH_DUMP("Optimizing the following function:", opt_graph);
  arg_spec_creator_.specializeTypes(*opt_graph, spec);

  // Phase 0. Inline functions, then clean up any artifacts that the inliner
  //          left in that may inhibit optimization
  Inline(*opt_graph);
  GRAPH_DEBUG("After Inline, before LowerGradOf\n", *opt_graph);
  LowerGradOf(*opt_graph);
  GRAPH_DEBUG("After LowerGradOf, before specializeAutogradZero\n", *opt_graph);
  specializeAutogradZero(opt_graph);
  GRAPH_DEBUG("After specializeAutogradZero, before LowerSimpleTuples\n", *opt_graph);
  LowerSimpleTuples(opt_graph);
  GRAPH_DEBUG("After LowerSimpleTuples, before ConstantPooling\n", *opt_graph);
  ConstantPooling(opt_graph);
  GRAPH_DEBUG("After ConstantPooling, before runRequiredPasses\n", *opt_graph);

  // Phase 1. Specialize to input definedness (this is very important for
  //          gradient graphs), and run required passes to bring the graph
  //          to an executable form.
  runRequiredPasses(opt_graph);
  GRAPH_DEBUG("After runRequiredPasses, before ConstantPropagation\n", *opt_graph);

  // Phase 2. Propagate detailed information about the spec through the
  //          graph (enabled more specializations in later passes).
  //          Shape propagation sometimes depends on certain arguments being
  //          constants, and constant propagation doesn't need shape
  //          information anyway, so it's better to run it first.
  ConstantPropagation(opt_graph);
  GRAPH_DEBUG("After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
  PropagateInputShapes(opt_graph);
  GRAPH_DEBUG("After PropagateInputShapes, before PropagateRequiresGrad\n", *opt_graph);
  PropagateRequiresGrad(opt_graph);
  GRAPH_DEBUG("After PropagateRequiresGrad, before runOptimization\n", *opt_graph);

  // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
  //          that we can still execute using autograd).
  runOptimization(opt_graph);

  // Phase 4. If this graph will be differentiated, we need to slice out the
  //          symbolically differentiable subgraphs for further optimizations.
  // Phase 5. Apply non-differentiable optimizations to the graphs we've found
  //          (or the whole graph if we know we won't need its derivative).
  if (needsGradient(opt_graph)) {
    auto diff_nodes = CreateAutodiffSubgraphs(
        opt_graph,
        autodiff_subgraph_inlining ? autodiffSubgraphNodeThreshold : 1);
    GRAPH_DEBUG("After CreateAutodiffSubgraphs\n", *opt_graph);
    size_t idx = 0;
    for (Node* dnode : diff_nodes) {
      GRAPH_DEBUG("Optimizing diff node ", idx);
      auto diff_graph = std::move(dnode->g(attr::Subgraph));
      Gradient gradient = differentiate(diff_graph);
      GRAPH_DEBUG("Forward graph:\n", *(gradient.f));
      GRAPH_DEBUG("Backward graph:\n", *(gradient.df));
      // Run post differentiation optimizations, Autodiff will replace some
      // parts of graph with new graph, these new graphs usually consists of
      // control flows and miss shape information on nodes, so we run shape
      // prop and differentiable optimizations to ensure the graph is
      // optimized
      PropagateInputShapes(gradient.f);
      GRAPH_DEBUG("After PropagateInputShapes\n", *(gradient.f));
      runOptimization(gradient.f);
      // run non diff optimization on the forward graph
      runNondiffOptimization(gradient.f);
      packGradient(gradient, dnode);
      GRAPH_DEBUG("Finished optimizing diff node ", idx++);
    }
    InlineAutodiffSubgraphs(
        opt_graph,
        autodiff_subgraph_inlining ? autodiffSubgraphInlineThreshold : 1);
    GRAPH_DEBUG("After InlineAutodiffSubgraphs\n", *opt_graph);
  } else {
    runNondiffOptimization(opt_graph);
  }
  // Make sure there are no leftovers from any passes.
  EliminateDeadCode(opt_graph);
  GRAPH_DUMP("After compileSpec optimizations:", opt_graph);
  return ExecutionPlan(opt_graph, function_name_);
}

上面这段代码完成了 Graph 的优化,主要做了下面这几件事。为了方便理解,笔者将官方文档的一个示例搬了过来。(笔者注:这部分涉及大量的编译原理知识,看起来着实吃力,先简单总结下,等好好看看编译原理再仔细研究下这一块的代码。)

示例代码如下:

@torch.jit.script
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy

这段代码对应的未经优化的 Graph 如下:

graph(%x : Tensor,
      %hx : Tensor,
      %cx : Tensor,
      %w_ih : Tensor,
      %w_hh : Tensor,
      %b_ih : Tensor,
      %b_hh : Tensor):
  %7 : int = prim::Constant[value=4]()
  %8 : int = prim::Constant[value=1]()
  %9 : Tensor = aten::t(%w_ih)
  %10 : Tensor = aten::mm(%x, %9)
  %11 : Tensor = aten::t(%w_hh)
  %12 : Tensor = aten::mm(%hx, %11)
  %13 : Tensor = aten::add(%10, %12, %8)
  %14 : Tensor = aten::add(%13, %b_ih, %8)
  %gates : Tensor = aten::add(%14, %b_hh, %8)
  %16 : Tensor[] = aten::chunk(%gates, %7, %8)
  %ingate.1 : Tensor, %forgetgate.1 : Tensor, %cellgate.1 : Tensor, %outgate.1 : Tensor = prim::ListUnpack(%16)
  %ingate : Tensor = aten::sigmoid(%ingate.1)
  %forgetgate : Tensor = aten::sigmoid(%forgetgate.1)
  %cellgate : Tensor = aten::tanh(%cellgate.1)
  %outgate : Tensor = aten::sigmoid(%outgate.1)
  %25 : Tensor = aten::mul(%forgetgate, %cx)
  %26 : Tensor = aten::mul(%ingate, %cellgate)
  %cy : Tensor = aten::add(%25, %26, %8)
  %28 : Tensor = aten::tanh(%cy)
  %hy : Tensor = aten::mul(%outgate, %28)
  %30 : (Tensor, Tensor) = prim::TupleConstruct(%hy, %cy)
  return (%30)

特化变量类型

获取 Graph 的输入参数的属性 ArgumentSpec,包括以下几种属性:

  • dtype
  • rank (not shape)
  • requires_grad
  • device type (CPU, CUDA)
  • defined (whether the Tensor exists or is a placeholder)
# post specialization, inputs are now specialized types
graph(%x : Float(*, *),
      %hx : Float(*, *),
      %cx : Float(*, *),
      %w_ih : Float(*, *),
      %w_hh : Float(*, *),
      %b_ih : Float(*),
      %b_hh : Float(*)):
  %7 : int = prim::Constant[value=4]()
  %8 : int = prim::Constant[value=1]()
  %9 : Tensor = aten::t(%w_ih)
  %10 : Tensor = aten::mm(%x, %9)
  %11 : Tensor = aten::t(%w_hh)
  %12 : Tensor = aten::mm(%hx, %11)
  %13 : Tensor = aten::add(%10, %12, %8)
  %14 : Tensor = aten::add(%13, %b_ih, %8)
  %gates : Tensor = aten::add(%14, %b_hh, %8)
  %16 : Tensor[] = aten::chunk(%gates, %7, %8)
  %ingate.1 : Tensor, %forgetgate.1 : Tensor, %cellgate.1 : Tensor, %outgate.1 : Tensor = prim::ListUnpack(%16)
  %ingate : Tensor = aten::sigmoid(%ingate.1)
  %forgetgate : Tensor = aten::sigmoid(%forgetgate.1)
  %cellgate : Tensor = aten::tanh(%cellgate.1)
  %outgate : Tensor = aten::sigmoid(%outgate.1)
  %25 : Tensor = aten::mul(%forgetgate, %cx)
  %26 : Tensor = aten::mul(%ingate, %cellgate)
  %cy : Tensor = aten::add(%25, %26, %8)
  %28 : Tensor = aten::tanh(%cy)
  %hy : Tensor = aten::mul(%outgate, %28)
  %30 : (Tensor, Tensor) = prim::TupleConstruct(%hy, %cy)
  return (%30)

可以看到输入参数的类型已经被重新标注了。

运行必须的 pass

pass 是编译器中的一个概念,意为“遍历一遍 IR,可以同时对它做一些操作”。这一步是为解释器生成合法 Graph 所必需的转换步骤。举个例子,一些 pass(如微分)会引入并非由算子定义的 Node,这就需要运行 pass 去清除这些多余的 NodeLowerGradOfspecializeAutogradZero 就是做了这件事。

通过 Graph 传播详细信息

这一步会传播常量和 ArgumentSpec 等信息,尽可能进行预计算。

graph(%x : Float(*, *),
      %hx : Float(*, *),
      %cx : Float(*, *),
      %w_ih : Float(*, *),
      %w_hh : Float(*, *),
      %b_ih : Float(*),
      %b_hh : Float(*)):
  %8 : int = prim::Constant[value=1]()
  %9 : Float(*, *) = aten::t(%w_ih)
  %10 : Float(*, *) = aten::mm(%x, %9)
  %11 : Float(*, *) = aten::t(%w_hh)
  %12 : Float(*, *) = aten::mm(%hx, %11)
  %13 : Float(*, *) = aten::add(%10, %12, %8)
  %14 : Float(*, *) = aten::add(%13, %b_ih, %8)
  %gates : Float(*, *) = aten::add(%14, %b_hh, %8)
  %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%gates)
  %ingate : Float(*, *) = aten::sigmoid(%31)
  %forgetgate : Float(*, *) = aten::sigmoid(%32)
  %cellgate : Float(*, *) = aten::tanh(%33)
  %outgate : Float(*, *) = aten::sigmoid(%34)
  %25 : Float(*, *) = aten::mul(%forgetgate, %cx)
  %26 : Float(*, *) = aten::mul(%ingate, %cellgate)
  %cy : Float(*, *) = aten::add(%25, %26, %8)
  %28 : Float(*, *) = aten::tanh(%cy)
  %hy : Float(*, *) = aten::mul(%outgate, %28)
  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
  return (%30)

可以看到,除了输入参数,中间变量的类型也被重新标注了。

保留梯度的优化

这些优化不会破坏 autograd。主要包括以下优化:

  • 消除 dead code
  • 消除公共子表达式
  • 将冗余常量池化为单个值
  • 窥孔优化,包括将一些代数操作重写为更简单的操作
  • 展开小循环
  • 展开循环产生的批处理矩阵乘法
graph(%x : Float(*, *),
      %hx : Float(*, *),
      %cx : Float(*, *),
      %w_ih : Float(*, *),
      %w_hh : Float(*, *),
      %b_ih : Float(*),
      %b_hh : Float(*)):
  %8 : int = prim::Constant[value=1]()
  %9 : Float(*, *) = aten::t(%w_ih)
  %10 : Float(*, *) = aten::mm(%x, %9)
  %11 : Float(*, *) = aten::t(%w_hh)
  %12 : Float(*, *) = aten::mm(%hx, %11)
  %13 : Float(*, *) = aten::add(%10, %12, %8)
  %14 : Float(*, *) = aten::add(%13, %b_ih, %8)
  %gates : Float(*, *) = aten::add(%14, %b_hh, %8)
  %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%gates)
  %ingate : Float(*, *) = aten::sigmoid(%31)
  %forgetgate : Float(*, *) = aten::sigmoid(%32)
  %cellgate : Float(*, *) = aten::tanh(%33)
  %outgate : Float(*, *) = aten::sigmoid(%34)
  %25 : Float(*, *) = aten::mul(%forgetgate, %cx)
  %26 : Float(*, *) = aten::mul(%ingate, %cellgate)
  %cy : Float(*, *) = aten::add(%25, %26, %8)
  %28 : Float(*, *) = aten::tanh(%cy)
  %hy : Float(*, *) = aten::mul(%outgate, %28)
  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
  return (%30)

这一步对于这个例子没有变化。

非微分相关优化

在不需要梯度的情况下(即仅推理时),可以直接应用优化来生成不带梯度的 Graph。以 FuseGraph 例,它寻找相邻的 point-wise 操作和诸如 splitconcat 之类的reviewing 操作,并在图中创建 prim::FusionGroup 节点来替换这些操作。被注册以执行 prim::FusionGroup 的节点将为每个唯一的 Node 生成一个新的 CUDA kernel 函数,取代原来分离执行的方式。

注意融合组编译的两个阶段:首先,FuseGraph pass 将 Graph 拆分为可融合的子图,并将生成的 Graph 返回给图执行器。然后,当 Graph 转化为 Code 时,会查找 FusionGroup 节点对应的 Operation,并生成一个新的 CUDA kernel 函数。其他编译器应该以类似的方式工作,首先将一个新的算子引入到编译代码应该运行的图中,然后注册一个运算符来实现执行实际编译的节点。

graph(%x : Float(*, *),
      %hx : Float(*, *),
      %cx : Float(*, *),
      %w_ih : Float(*, *),
      %w_hh : Float(*, *),
      %b_ih : Float(*),
      %b_hh : Float(*)):
  %9 : Float(*, *) = aten::t(%w_ih)
  %10 : Float(*, *) = aten::mm(%x, %9)
  %11 : Float(*, *) = aten::t(%w_hh)
  %12 : Float(*, *) = aten::mm(%hx, %11)
  %77 : Tensor[] = prim::ListConstruct(%b_hh, %b_ih, %10, %12)
  %78 : Tensor[] = aten::broadcast_tensors(%77)
  %79 : Tensor, %80 : Tensor, %81 : Tensor, %82 : Tensor = prim::ListUnpack(%78)
  %hy : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %82, %81, %80, %79)
  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
  return (%30);
with prim::FusionGroup_0 = graph(%13 : Float(*, *),
      %71 : Tensor,
      %76 : Tensor,
      %81 : Tensor,
      %86 : Tensor):
  %87 : Float(*, *), %88 : Float(*, *), %89 : Float(*, *), %90 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%86)
  %82 : Float(*, *), %83 : Float(*, *), %84 : Float(*, *), %85 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%81)
  %77 : Float(*, *), %78 : Float(*, *), %79 : Float(*, *), %80 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%76)
  %72 : Float(*, *), %73 : Float(*, *), %74 : Float(*, *), %75 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%71)
  %69 : int = prim::Constant[value=1]()
  %70 : Float(*, *) = aten::add(%77, %72, %69)
  %66 : Float(*, *) = aten::add(%78, %73, %69)
  %62 : Float(*, *) = aten::add(%79, %74, %69)
  %58 : Float(*, *) = aten::add(%80, %75, %69)
  %54 : Float(*, *) = aten::add(%70, %82, %69)
  %50 : Float(*, *) = aten::add(%66, %83, %69)
  %46 : Float(*, *) = aten::add(%62, %84, %69)
  %42 : Float(*, *) = aten::add(%58, %85, %69)
  %38 : Float(*, *) = aten::add(%54, %87, %69)
  %34 : Float(*, *) = aten::add(%50, %88, %69)
  %30 : Float(*, *) = aten::add(%46, %89, %69)
  %26 : Float(*, *) = aten::add(%42, %90, %69)
  %ingate : Float(*, *) = aten::sigmoid(%38)
  %forgetgate : Float(*, *) = aten::sigmoid(%34)
  %cellgate : Float(*, *) = aten::tanh(%30)
  %outgate : Float(*, *) = aten::sigmoid(%26)
  %14 : Float(*, *) = aten::mul(%forgetgate, %13)
  %11 : Float(*, *) = aten::mul(%ingate, %cellgate)
  %cy : Float(*, *) = aten::add(%14, %11, %69)
  %4 : Float(*, *) = aten::tanh(%cy)
  %hy : Float(*, *) = aten::mul(%outgate, %4)
  return (%hy, %cy)

可以看到,sigmoid、tanh 等 point-wise 的操作被融合成 prim::FusionGroup_0

在不需要梯度的情况下,优化过程到此结束。在 compileSpec 函数的最后,从 Graph 构造一个 ExecutionPlan 对象并交给 InterpreterState 执行。

InterpreterState

几个类

ExecutionPlan

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(
      std::shared_ptr<Graph> graph,
      std::string function_name,
      size_t remaining_bailout_depth = 0)
      : code(graph, std::move(function_name), remaining_bailout_depth),
        graph(std::move(graph)) {}

  operator bool() const {
    return static_cast<bool>(graph);
  }

  Code code;
  std::shared_ptr<Graph> graph;
};

ExecutionPlan 对象的构造仅仅是将 Graph 转换成 Code

Code

CodeCodeImpl 的 wrapper。这里只列出一部分代码。

struct CodeImpl {
  PreprocessGraph preprocess_;
  std::vector<Instruction> instructions_;

  CodeImpl(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      size_t remaining_bailout_depth,
      bool emit_instructions = true)
      : function_name_(std::move(function_name)),
        preprocess_(*graph),
        current_node_(preprocess_.graph->return_node()),
        remaining_bailout_depth_(remaining_bailout_depth) {
    graph_ = preprocess_.graph;
    n_outputs = graph_->outputs().size();
    if (n_outputs == 1) {
      return_type_ = graph->outputs().at(0)->type();
    } else {
      return_type_ = TupleType::create(fmap(graph->outputs(), [](const Value* v) { return v->type(); }));
    }
    n_inputs = graph_->inputs().size();
    if (emit_instructions) {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      run();
    }
  }
  
  virtual void run() {
    emitCodeForBlock(graph_->block());
    insertInstruction(RET);
    // we deferred the emission of bailout blocks so they appear at the end
    // emit them now and patch up the jumps
    insertBailoutBlocks();
  }
}

Code 通过 PreprocessGraph、emitCodeForBlock、insertInstruction 和 insertBailoutBlocks 四个步骤将 Graph 转为一系列指令。

Operation

大多数 Operation 都有一个关联的 FunctionSchema,它描述了有多少输入将被 pop 以及有多少将被 push。堆栈概念使得定义具有可变数量输入和输出的运算符变得容易,而无需为每个单独的 Operation 分配输入和输出向量。下面的例子展示了 Operation 的使用过程:

using Stack = std::vector<IValue>;
using Operation = std::function<void(Stack*)>;

// schema: example_add(Tensor a, Tensor b) -> Tensor
void example_add(Stack* stack) {
  Tensor a, b, c;
  pop(stack, a, b);
  c = a + b;
  push(stack, c);
}

Operator

Operator 可以简单理解为 OperationFunctionSchema 的 wrapper。

步骤

PreprocessGraph

PreprocessGraph::PreprocessGraph(Graph& g) : graph(g.copy()) {
  insertEnterMethodCalls(*graph);
  dropUnused(graph->block());
  // fill in move_flags by scanning blocks;
  insertLastUses(*graph);
  can_emit_inline = std::move(CanEmitInline(*graph.get()).can_emit_inline_);
}
  • insertEnterMethodCalls:在 prim::Enter 节点之后插入显式 prim::MethodCall 节点以实际调用对象上的 __enter__ 方法
  • dropUnused:插入 prim::Drop 节点以终止任何未被使用的引用
  • insertLastUses:确保每个 Value 在定义它的 Block 中都有消费者。对于大多数节点来说,这已经是正确的。例外情况是:
    • 从未被使用的 Value :添加一个 prim::Drop 节点,该节点在定义后立即使用该值
    • 最后一次被使用的 Value:在最后一次使用的控制流节点之后插入一个 prim::Drop

CodeImpl::emitCodeForBlock

void interpreter::CodeImpl::emitCodeForBlock(Block* block) {
  emitNodeAtBlockLevel(block->param_node());
  for (auto node : block->nodes()) {
    emitNodeAtBlockLevel(node);
  }
  emitNodeAtBlockLevel(block->return_node());
}
void interpreter::CodeImpl::emitNodeAtBlockLevel(Node* node) {
  WithCurrentNode guard(&current_node_, node);
  switch (node->kind()) {
    case prim::Constant:
      emitConstant(node);
      break;
    case prim::Return:
      emitLoadInputs(node->inputs());
      break;
    default:
      if (!preprocess_.can_emit_inline[node]) {
        emitNode(node);
        emitStoreOutputs(node);
      }
      break;
  }
}
void interpreter::CodeImpl::emitNode(Node* node) {
  WithCurrentNode guard(&current_node_, node);
  switch (node->kind()) {
    default:
      emitOperator(node);
      break;
    case prim::Drop:
      emitDrop(node->inputs());
      break;
    case prim::Constant:
      emitConstant(node);
      break;
    case prim::If: ...
    case prim::Loop: ...
    case aten::wait: ...
    case prim::Param: ...
    case prim::CallMethod: ...
    case prim::TypeCheck: ...
    case prim::BailOut: ...
    case prim::profile_ivalue: ...
    case prim::profile: ...
    case prim::GetAttr: ...
    case prim::SetAttr: ...
    case prim::ListUnpack: ...
    case prim::TupleConstruct: ...
    case prim::ListConstruct: ...
    case prim::DictConstruct: ...
    case prim::CreateObject: ...
    case prim::isinstance: ...
    case prim::TupleSlice: ...
    case prim::fork: ...
    case aten::warn: ...
    case prim::Enter: ...
    case prim::Exit: ...
  }
}

emitNode 函数根据节点的类型将指令添加到 interpreter::CodeImpl::instructions_ 中。

emitCallemitOperator 为例:

void interpreter::CodeImpl::emitCall(Function* func, at::ArrayRef<Value*> inputs) {
  emitLoadInputs(inputs);
  insertInstruction(CALL, function_table_.size());
  function_table_.emplace_back(func);
}
virtual void interpreter::CodeImpl::emitOperator(Node* node) {
  emitLoadInputs(node->inputs());
  const Operator& op = node->getOperator();
  if (op.hasOperation() && op.schema().is_vararg()) {
    insertInstruction(OPN, operator_table_.size(), node->inputs().size());
  } else {
    insertInstruction(OP, operator_table_.size());
  }
  operator_table_.emplace_back(op.getOperation(node));
}

这里出现了两个变量:function_table_operator_table_,前者存放需要调用的函数,后者存放 Operator 对应的 Operation

CodeImpl::insertInstruction

void interpreter::CodeImpl::insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) {
  instructions_.emplace_back(
      op,
      safe_narrow_cast<int32_t, int64_t>(X),
      safe_narrow_cast<uint16_t, uint64_t>(N));

  ......
}

可以看到,insertInstructionOperator 对应的 OpCode 填入了 instructions_ 中。

CodeImpl::insertBailoutBlocks

这里涉及到一个概念,什么是 bailout ?介绍这个概念之前,先引入另一个节点 prim::profileprim::profile 节点会被 ProfilingRecord::instrumentBlock 插入到每个 Value 的使用者中。prim::profile 节点的作用可以简单理解为记录和推导假设 Tensor 的形状,这有利于代码生成器能够生成更高效的代码。当在实际运行时,假设的 Tensor 形状有可能是错的,为了防止由此导致的运行失败,需要运行原始代码,即不依赖于形状假设的代码。CodeImpl::insertBailoutBlocks 构建原始计算图的去优化版本。bailout 这个单词的中文意思是“紧急救助”,在 LibTorch 中的作用也类似于此。

InterpreterStateImpl::run

InterpreterStateInterpreterStateImpl 的 wrapper。而 InterpreterStateImpl 则是一个可以执行指令的虚拟机。

void InterpreterStateImpl::run(Stack& stack) {
  if (runImpl(stack)) {
    future_->wait();

    auto num_outputs = frames.front().function->n_outputs;
    if (num_outputs == 1) {
      push(stack, future_->value());
    } else {
      auto tuple = future_->value().toTuple();
      for (const IValue& value : tuple->elements()) {
        push(stack, value);
      }
    }
  }
}
bool runImpl(Stack& stack) {
  // if we have never run before, then we might have to return the
  // stack when we suspend, record where it starts so we return the right
  // stack
  if (stack_start_ == -1) {
    TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
    stack_start_ = stack.size() - frames.back().function->n_inputs;
  } else {
    // during restarts, all of the stack is always our own, so we leave
    // nothing
    stack_start_ = 0;
  }

  TLSCurrentInterpreterGuard g(this);
  if (frames.back().pc == 0 && stack_start_ == 0) {
    checkAndStartRecordFunction(frames.back(), stack);
  }
  try {
    while (true) {
      Frame& frame = frames.back();
      // std::cout << "RUNNING ";
      // frames.back().function->dump(std::cout, frame.pc);
      Instruction inst = frame.function->instructions_[frame.pc];