TorchScript 原理篇(下) - 运行模型 | PyTorch

上一篇描述了如何生成 IR,本篇主要描述如何运行 IR。

概述

下面是本篇分析使用的例子:

1
2
3
4
5
6
7
8
9
10
// load module
torch::jit::script::Module module = torch::jit::load("my_func.pt");

// run module
std::vector<torch::jit::IValue> inputs;
inputs.push_back(1);
torch::jit::IValue output = module.get_method("forward")(inputs);

// get output
int ret = output.toInt();

torch::jit::load

torch.jit.loadtorch::jit::load 底层都是通过 ScriptModuleDeserializer 类进行反序列化,这个函数会将序列化文件反序列化成 torch::jit::Module 对象。

Module::get_method

Module::get_method 会返回一个 Method 对象,其 operator() 重载如下:

1
2
3
4
5
6
IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
stack.insert(stack.begin(), owner()._ivalue()); // self
RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
// 调用 function_ 的 operator()。每一个 Method 对象都有一个 Function 类型的成员变量 function_,本质是调用这个成员变量
return (*function_)(std::move(stack), kwargs);
}

这里的 function_ 其实就是前面保存模型时涉及到的 GraphFunction 类型变量。

GraphExecutor

GraphFunction::operator()

1
2
3
4
5
6
7
8
9
IValue GraphFunction::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
getSchema().checkAndNormalizeInputs(stack, kwargs);
run(stack);
return stack.front();
}

void GraphFunction::run(Stack& stack) {
get_executor().run(stack);
}

可以看出,最后是调用了 GraphExecutor::run 函数。

GraphExecutor::run

1
2
3
4
5
6
7
8
9
10
11
12
13
void GraphExecutor::run(Stack& inputs) {
return pImpl->run(inputs);
}

void GraphExecutorImplBase::run(Stack& stack) {
......

// 获取 plan
const ExecutionPlan& plan = getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts());
// 运行 plan
InterpreterState(plan.code).run(stack);
last_executed_optimized_graph = plan.graph;
}
1
2
3
4
5
const ExecutionPlan& GraphExecutorImpl::getPlanFor(Stack& stack, size_t remaining_bailout_depth)
override {
return getGraphExecutorOptimize() ? getOrCompile(stack)
: getOrCompileFallback();
}
1
2
3
4
5
6
7
const ExecutionPlan& GraphExecutorImpl::getOrCompile(const Stack& stack) {
......

auto plan = compileSpec(spec);

......
}

GraphExecutorImpl::compileSpec

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
ExecutionPlan GraphExecutorImpl::compileSpec(const ArgumentSpec& spec) {
auto opt_graph = graph->copy();
GRAPH_DUMP("Optimizing the following function:", opt_graph);
arg_spec_creator_.specializeTypes(*opt_graph, spec);

// Phase 0. Inline functions, then clean up any artifacts that the inliner
// left in that may inhibit optimization
Inline(*opt_graph);
GRAPH_DEBUG("After Inline, before LowerGradOf\n", *opt_graph);
LowerGradOf(*opt_graph);
GRAPH_DEBUG("After LowerGradOf, before specializeAutogradZero\n", *opt_graph);
specializeAutogradZero(opt_graph);
GRAPH_DEBUG("After specializeAutogradZero, before LowerSimpleTuples\n", *opt_graph);
LowerSimpleTuples(opt_graph);
GRAPH_DEBUG("After LowerSimpleTuples, before ConstantPooling\n", *opt_graph);
ConstantPooling(opt_graph);
GRAPH_DEBUG("After ConstantPooling, before runRequiredPasses\n", *opt_graph);

// Phase 1. Specialize to input definedness (this is very important for
// gradient graphs), and run required passes to bring the graph
// to an executable form.
runRequiredPasses(opt_graph);
GRAPH_DEBUG("After runRequiredPasses, before ConstantPropagation\n", *opt_graph);

// Phase 2. Propagate detailed information about the spec through the
// graph (enabled more specializations in later passes).
// Shape propagation sometimes depends on certain arguments being
// constants, and constant propagation doesn't need shape
// information anyway, so it's better to run it first.
ConstantPropagation(opt_graph);
GRAPH_DEBUG("After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
PropagateInputShapes(opt_graph);
GRAPH_DEBUG("After PropagateInputShapes, before PropagateRequiresGrad\n", *opt_graph);
PropagateRequiresGrad(opt_graph);
GRAPH_DEBUG("After PropagateRequiresGrad, before runOptimization\n", *opt_graph);

// Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
// that we can still execute using autograd).
runOptimization(opt_graph);

// Phase 4. If this graph will be differentiated, we need to slice out the
// symbolically differentiable subgraphs for further optimizations.
// Phase 5. Apply non-differentiable optimizations to the graphs we've found
// (or the whole graph if we know we won't need its derivative).
if (needsGradient(opt_graph)) {
auto diff_nodes = CreateAutodiffSubgraphs(
opt_graph,
autodiff_subgraph_inlining ? autodiffSubgraphNodeThreshold : 1);
GRAPH_DEBUG("After CreateAutodiffSubgraphs\n", *opt_graph);
size_t idx = 0;
for (Node* dnode : diff_nodes) {
GRAPH_DEBUG("Optimizing diff node ", idx);
auto diff_graph = std::move(dnode->g(attr::Subgraph));
Gradient gradient = differentiate(diff_graph);
GRAPH_DEBUG("Forward graph:\n", *(gradient.f));
GRAPH_DEBUG("Backward graph:\n", *(gradient.df));
// Run post differentiation optimizations, Autodiff will replace some
// parts of graph with new graph, these new graphs usually consists of
// control flows and miss shape information on nodes, so we run shape
// prop and differentiable optimizations to ensure the graph is
// optimized
PropagateInputShapes(gradient.f);
GRAPH_DEBUG("After PropagateInputShapes\n", *(gradient.f));
runOptimization(gradient.f);
// run non diff optimization on the forward graph
runNondiffOptimization(gradient.f);
packGradient(gradient, dnode);
GRAPH_DEBUG("Finished optimizing diff node ", idx++);
}
InlineAutodiffSubgraphs(
opt_graph,
autodiff_subgraph_inlining ? autodiffSubgraphInlineThreshold : 1);
GRAPH_DEBUG("After InlineAutodiffSubgraphs\n", *opt_graph);
} else {
runNondiffOptimization(opt_graph);
}
// Make sure there are no leftovers from any passes.
EliminateDeadCode(opt_graph);
GRAPH_DUMP("After compileSpec optimizations:", opt_graph);
return ExecutionPlan(opt_graph, function_name_);
}

上面这段代码完成了 Graph 的优化,主要做了下面这几件事。为了方便理解,笔者将官方文档的一个示例搬了过来。(笔者注:这部分涉及大量的编译原理知识,看起来着实吃力,先简单总结下,等好好看看编译原理再仔细研究下这一块的代码。)

示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
@torch.jit.script
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
return hy, cy

这段代码对应的未经优化的 Graph 如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
graph(%x : Tensor,
%hx : Tensor,
%cx : Tensor,
%w_ih : Tensor,
%w_hh : Tensor,
%b_ih : Tensor,
%b_hh : Tensor):
%7 : int = prim::Constant[value=4]()
%8 : int = prim::Constant[value=1]()
%9 : Tensor = aten::t(%w_ih)
%10 : Tensor = aten::mm(%x, %9)
%11 : Tensor = aten::t(%w_hh)
%12 : Tensor = aten::mm(%hx, %11)
%13 : Tensor = aten::add(%10, %12, %8)
%14 : Tensor = aten::add(%13, %b_ih, %8)
%gates : Tensor = aten::add(%14, %b_hh, %8)
%16 : Tensor[] = aten::chunk(%gates, %7, %8)
%ingate.1 : Tensor, %forgetgate.1 : Tensor, %cellgate.1 : Tensor, %outgate.1 : Tensor = prim::ListUnpack(%16)
%ingate : Tensor = aten::sigmoid(%ingate.1)
%forgetgate : Tensor = aten::sigmoid(%forgetgate.1)
%cellgate : Tensor = aten::tanh(%cellgate.1)
%outgate : Tensor = aten::sigmoid(%outgate.1)
%25 : Tensor = aten::mul(%forgetgate, %cx)
%26 : Tensor = aten::mul(%ingate, %cellgate)
%cy : Tensor = aten::add(%25, %26, %8)
%28 : Tensor = aten::tanh(%cy)
%hy : Tensor = aten::mul(%outgate, %28)
%30 : (Tensor, Tensor) = prim::TupleConstruct(%hy, %cy)
return (%30)

特化变量类型

获取 Graph 的输入参数的属性 ArgumentSpec,包括以下几种属性:

  • dtype
  • rank (not shape)
  • requires_grad
  • device type (CPU, CUDA)
  • defined (whether the Tensor exists or is a placeholder)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# post specialization, inputs are now specialized types
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%7 : int = prim::Constant[value=4]()
%8 : int = prim::Constant[value=1]()
%9 : Tensor = aten::t(%w_ih)
%10 : Tensor = aten::mm(%x, %9)
%11 : Tensor = aten::t(%w_hh)
%12 : Tensor = aten::mm(%hx, %11)
%13 : Tensor = aten::add(%10, %12, %8)
%14 : Tensor = aten::add(%13, %b_ih, %8)
%gates : Tensor = aten::add(%14, %b_hh, %8)
%16 : Tensor[] = aten::chunk(%gates, %7, %8)
%ingate.1 : Tensor, %forgetgate.1 : Tensor, %cellgate.1 : Tensor, %outgate.1 : Tensor = prim::ListUnpack(%16)
%ingate : Tensor = aten::sigmoid(%ingate.1)
%forgetgate : Tensor = aten::sigmoid(%forgetgate.1)
%cellgate : Tensor = aten::tanh(%cellgate.1)
%outgate : Tensor = aten::sigmoid(%outgate.1)
%25 : Tensor = aten::mul(%forgetgate, %cx)
%26 : Tensor = aten::mul(%ingate, %cellgate)
%cy : Tensor = aten::add(%25, %26, %8)
%28 : Tensor = aten::tanh(%cy)
%hy : Tensor = aten::mul(%outgate, %28)
%30 : (Tensor, Tensor) = prim::TupleConstruct(%hy, %cy)
return (%30)

可以看到输入参数的类型已经被重新标注了。

运行必须的 pass

pass 是编译器中的一个概念,意为“遍历一遍 IR,可以同时对它做一些操作”。这一步是为解释器生成合法 Graph 所必需的转换步骤。举个例子,一些 pass(如微分)会引入并非由算子定义的 Node,这就需要运行 pass 去清除这些多余的 NodeLowerGradOfspecializeAutogradZero 就是做了这件事。

通过 Graph 传播详细信息

这一步会传播常量和 ArgumentSpec 等信息,尽可能进行预计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%8 : int = prim::Constant[value=1]()
%9 : Float(*, *) = aten::t(%w_ih)
%10 : Float(*, *) = aten::mm(%x, %9)
%11 : Float(*, *) = aten::t(%w_hh)
%12 : Float(*, *) = aten::mm(%hx, %11)
%13 : Float(*, *) = aten::add(%10, %12, %8)
%14 : Float(*, *) = aten::add(%13, %b_ih, %8)
%gates : Float(*, *) = aten::add(%14, %b_hh, %8)
%31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%gates)
%ingate : Float(*, *) = aten::sigmoid(%31)
%forgetgate : Float(*, *) = aten::sigmoid(%32)
%cellgate : Float(*, *) = aten::tanh(%33)
%outgate : Float(*, *) = aten::sigmoid(%34)
%25 : Float(*, *) = aten::mul(%forgetgate, %cx)
%26 : Float(*, *) = aten::mul(%ingate, %cellgate)
%cy : Float(*, *) = aten::add(%25, %26, %8)
%28 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate, %28)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30)

可以看到,除了输入参数,中间变量的类型也被重新标注了。

保留梯度的优化

这些优化不会破坏 autograd。主要包括以下优化:

  • 消除 dead code
  • 消除公共子表达式
  • 将冗余常量池化为单个值
  • 窥孔优化,包括将一些代数操作重写为更简单的操作
  • 展开小循环
  • 展开循环产生的批处理矩阵乘法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%8 : int = prim::Constant[value=1]()
%9 : Float(*, *) = aten::t(%w_ih)
%10 : Float(*, *) = aten::mm(%x, %9)
%11 : Float(*, *) = aten::t(%w_hh)
%12 : Float(*, *) = aten::mm(%hx, %11)
%13 : Float(*, *) = aten::add(%10, %12, %8)
%14 : Float(*, *) = aten::add(%13, %b_ih, %8)
%gates : Float(*, *) = aten::add(%14, %b_hh, %8)
%31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *), %34 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%gates)
%ingate : Float(*, *) = aten::sigmoid(%31)
%forgetgate : Float(*, *) = aten::sigmoid(%32)
%cellgate : Float(*, *) = aten::tanh(%33)
%outgate : Float(*, *) = aten::sigmoid(%34)
%25 : Float(*, *) = aten::mul(%forgetgate, %cx)
%26 : Float(*, *) = aten::mul(%ingate, %cellgate)
%cy : Float(*, *) = aten::add(%25, %26, %8)
%28 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate, %28)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30)

这一步对于这个例子没有变化。

非微分相关优化

在不需要梯度的情况下(即仅推理时),可以直接应用优化来生成不带梯度的 Graph。以 FuseGraph 例,它寻找相邻的 point-wise 操作和诸如 splitconcat 之类的reviewing 操作,并在图中创建 prim::FusionGroup 节点来替换这些操作。被注册以执行 prim::FusionGroup 的节点将为每个唯一的 Node 生成一个新的 CUDA kernel 函数,取代原来分离执行的方式。

注意融合组编译的两个阶段:首先,FuseGraph pass 将 Graph 拆分为可融合的子图,并将生成的 Graph 返回给图执行器。然后,当 Graph 转化为 Code 时,会查找 FusionGroup 节点对应的 Operation,并生成一个新的 CUDA kernel 函数。其他编译器应该以类似的方式工作,首先将一个新的算子引入到编译代码应该运行的图中,然后注册一个运算符来实现执行实际编译的节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%9 : Float(*, *) = aten::t(%w_ih)
%10 : Float(*, *) = aten::mm(%x, %9)
%11 : Float(*, *) = aten::t(%w_hh)
%12 : Float(*, *) = aten::mm(%hx, %11)
%77 : Tensor[] = prim::ListConstruct(%b_hh, %b_ih, %10, %12)
%78 : Tensor[] = aten::broadcast_tensors(%77)
%79 : Tensor, %80 : Tensor, %81 : Tensor, %82 : Tensor = prim::ListUnpack(%78)
%hy : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %82, %81, %80, %79)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30);
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
with prim::FusionGroup_0 = graph(%13 : Float(*, *),
%71 : Tensor,
%76 : Tensor,
%81 : Tensor,
%86 : Tensor):
%87 : Float(*, *), %88 : Float(*, *), %89 : Float(*, *), %90 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%86)
%82 : Float(*, *), %83 : Float(*, *), %84 : Float(*, *), %85 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%81)
%77 : Float(*, *), %78 : Float(*, *), %79 : Float(*, *), %80 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%76)
%72 : Float(*, *), %73 : Float(*, *), %74 : Float(*, *), %75 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%71)
%69 : int = prim::Constant[value=1]()
%70 : Float(*, *) = aten::add(%77, %72, %69)
%66 : Float(*, *) = aten::add(%78, %73, %69)
%62 : Float(*, *) = aten::add(%79, %74, %69)
%58 : Float(*, *) = aten::add(%80, %75, %69)
%54 : Float(*, *) = aten::add(%70, %82, %69)
%50 : Float(*, *) = aten::add(%66, %83, %69)
%46 : Float(*, *) = aten::add(%62, %84, %69)
%42 : Float(*, *) = aten::add(%58, %85, %69)
%38 : Float(*, *) = aten::add(%54, %87, %69)
%34 : Float(*, *) = aten::add(%50, %88, %69)
%30 : Float(*, *) = aten::add(%46, %89, %69)
%26 : Float(*, *) = aten::add(%42, %90, %69)
%ingate : Float(*, *) = aten::sigmoid(%38)
%forgetgate : Float(*, *) = aten::sigmoid(%34)
%cellgate : Float(*, *) = aten::tanh(%30)
%outgate : Float(*, *) = aten::sigmoid(%26)
%14 : Float(*, *) = aten::mul(%forgetgate, %13)
%11 : Float(*, *) = aten::mul(%ingate, %cellgate)
%cy : Float(*, *) = aten::add(%14, %11, %69)
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate, %4)
return (%hy, %cy)

可以看到,sigmoid、tanh 等 point-wise 的操作被融合成 prim::FusionGroup_0

在不需要梯度的情况下,优化过程到此结束。在 compileSpec 函数的最后,从 Graph 构造一个 ExecutionPlan 对象并交给 InterpreterState 执行。

InterpreterState

几个类

ExecutionPlan

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
struct ExecutionPlan {
ExecutionPlan() = default;
ExecutionPlan(
std::shared_ptr<Graph> graph,
std::string function_name,
size_t remaining_bailout_depth = 0)
: code(graph, std::move(function_name), remaining_bailout_depth),
graph(std::move(graph)) {}

operator bool() const {
return static_cast<bool>(graph);
}

Code code;
std::shared_ptr<Graph> graph;
};

ExecutionPlan 对象的构造仅仅是将 Graph 转换成 Code

Code

CodeCodeImpl 的 wrapper。这里只列出一部分代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
struct CodeImpl {
PreprocessGraph preprocess_;
std::vector<Instruction> instructions_;

CodeImpl(
const std::shared_ptr<Graph>& graph,
std::string function_name,
size_t remaining_bailout_depth,
bool emit_instructions = true)
: function_name_(std::move(function_name)),
preprocess_(*graph),
current_node_(preprocess_.graph->return_node()),
remaining_bailout_depth_(remaining_bailout_depth) {
graph_ = preprocess_.graph;
n_outputs = graph_->outputs().size();
if (n_outputs == 1) {
return_type_ = graph->outputs().at(0)->type();
} else {
return_type_ = TupleType::create(fmap(graph->outputs(), [](const Value* v) { return v->type(); }));
}
n_inputs = graph_->inputs().size();
if (emit_instructions) {
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
run();
}
}

virtual void run() {
emitCodeForBlock(graph_->block());
insertInstruction(RET);
// we deferred the emission of bailout blocks so they appear at the end
// emit them now and patch up the jumps
insertBailoutBlocks();
}
}

Code 通过 PreprocessGraph、emitCodeForBlock、insertInstruction 和 insertBailoutBlocks 四个步骤将 Graph 转为一系列指令。

Operation

大多数 Operation 都有一个关联的 FunctionSchema,它描述了有多少输入将被 pop 以及有多少将被 push。堆栈概念使得定义具有可变数量输入和输出的运算符变得容易,而无需为每个单独的 Operation 分配输入和输出向量。下面的例子展示了 Operation 的使用过程:

1
2
3
4
5
6
7
8
9
10
using Stack = std::vector<IValue>;
using Operation = std::function<void(Stack*)>;

// schema: example_add(Tensor a, Tensor b) -> Tensor
void example_add(Stack* stack) {
Tensor a, b, c;
pop(stack, a, b);
c = a + b;
push(stack, c);
}

Operator

Operator 可以简单理解为 OperationFunctionSchema 的 wrapper。

步骤

PreprocessGraph

1
2
3
4
5
6
7
PreprocessGraph::PreprocessGraph(Graph& g) : graph(g.copy()) {
insertEnterMethodCalls(*graph);
dropUnused(graph->block());
// fill in move_flags by scanning blocks;
insertLastUses(*graph);
can_emit_inline = std::move(CanEmitInline(*graph.get()).can_emit_inline_);
}
  • insertEnterMethodCalls:在 prim::Enter 节点之后插入显式 prim::MethodCall 节点以实际调用对象上的 __enter__ 方法
  • dropUnused:插入 prim::Drop 节点以终止任何未被使用的引用
  • insertLastUses:确保每个 Value 在定义它的 Block 中都有消费者。对于大多数节点来说,这已经是正确的。例外情况是:
    • 从未被使用的 Value :添加一个 prim::Drop 节点,该节点在定义后立即使用该值
    • 最后一次被使用的 Value:在最后一次使用的控制流节点之后插入一个 prim::Drop

CodeImpl::emitCodeForBlock

1
2
3
4
5
6
7
void interpreter::CodeImpl::emitCodeForBlock(Block* block) {
emitNodeAtBlockLevel(block->param_node());
for (auto node : block->nodes()) {
emitNodeAtBlockLevel(node);
}
emitNodeAtBlockLevel(block->return_node());
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void interpreter::CodeImpl::emitNodeAtBlockLevel(Node* node) {
WithCurrentNode guard(&current_node_, node);
switch (node->kind()) {
case prim::Constant:
emitConstant(node);
break;
case prim::Return:
emitLoadInputs(node->inputs());
break;
default:
if (!preprocess_.can_emit_inline[node]) {
emitNode(node);
emitStoreOutputs(node);
}
break;
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
void interpreter::CodeImpl::emitNode(Node* node) {
WithCurrentNode guard(&current_node_, node);
switch (node->kind()) {
default:
emitOperator(node);
break;
case prim::Drop:
emitDrop(node->inputs());
break;
case prim::Constant:
emitConstant(node);
break;
case prim::If: ...
case prim::Loop: ...
case aten::wait: ...
case prim::Param: ...
case prim::CallMethod: ...
case prim::TypeCheck: ...
case prim::BailOut: ...
case prim::profile_ivalue: ...
case prim::profile: ...
case prim::GetAttr: ...
case prim::SetAttr: ...
case prim::ListUnpack: ...
case prim::TupleConstruct: ...
case prim::ListConstruct: ...
case prim::DictConstruct: ...
case prim::CreateObject: ...
case prim::isinstance: ...
case prim::TupleSlice: ...
case prim::fork: ...
case aten::warn: ...
case prim::Enter: ...
case prim::Exit: ...
}
}

emitNode 函数根据节点的类型将指令添加到 interpreter::CodeImpl::instructions_ 中。

emitCallemitOperator 为例:

1
2
3
4
5
void interpreter::CodeImpl::emitCall(Function* func, at::ArrayRef<Value*> inputs) {
emitLoadInputs(inputs);
insertInstruction(CALL, function_table_.size());
function_table_.emplace_back(func);
}
1
2
3
4
5
6
7
8
9
10
virtual void interpreter::CodeImpl::emitOperator(Node* node) {
emitLoadInputs(node->inputs());
const Operator& op = node->getOperator();
if (op.hasOperation() && op.schema().is_vararg()) {
insertInstruction(OPN, operator_table_.size(), node->inputs().size());
} else {
insertInstruction(OP, operator_table_.size());
}
operator_table_.emplace_back(op.getOperation(node));
}

这里出现了两个变量:function_table_operator_table_,前者存放需要调用的函数,后者存放 Operator 对应的 Operation

CodeImpl::insertInstruction

1
2
3
4
5
6
7
8
void interpreter::CodeImpl::insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) {
instructions_.emplace_back(
op,
safe_narrow_cast<int32_t, int64_t>(X),
safe_narrow_cast<uint16_t, uint64_t>(N));

......
}

可以看到,insertInstructionOperator 对应的 OpCode 填入了 instructions_ 中。

CodeImpl::insertBailoutBlocks

这里涉及到一个概念,什么是 bailout ?介绍这个概念之前,先引入另一个节点 prim::profileprim::profile 节点会被 ProfilingRecord::instrumentBlock 插入到每个 Value 的使用者中。prim::profile 节点的作用可以简单理解为记录和推导假设 Tensor 的形状,这有利于代码生成器能够生成更高效的代码。当在实际运行时,假设的 Tensor 形状有可能是错的,为了防止由此导致的运行失败,需要运行原始代码,即不依赖于形状假设的代码。CodeImpl::insertBailoutBlocks 构建原始计算图的去优化版本。bailout 这个单词的中文意思是“紧急救助”,在 LibTorch 中的作用也类似于此。

InterpreterStateImpl::run

InterpreterStateInterpreterStateImpl 的 wrapper。而 InterpreterStateImpl 则是一个可以执行指令的虚拟机。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void InterpreterStateImpl::run(Stack& stack) {
if (runImpl(stack)) {
future_->wait();

auto num_outputs = frames.front().function->n_outputs;
if (num_outputs == 1) {
push(stack, future_->value());
} else {
auto tuple = future_->value().toTuple();
for (const IValue& value : tuple->elements()) {
push(stack, value);
}
}
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
bool runImpl(Stack& stack) {
// if we have never run before, then we might have to return the
// stack when we suspend, record where it starts so we return the right
// stack
if (stack_start_ == -1) {
TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
stack_start_ = stack.size() - frames.back().function->n_inputs;
} else {
// during restarts, all of the stack is always our own, so we leave
// nothing
stack_start_ = 0;
}

TLSCurrentInterpreterGuard g(this);
if (frames.back().pc == 0 && stack_start_ == 0) {
checkAndStartRecordFunction(frames.back(), stack);
}
try {
while (true) {
Frame& frame = frames.back();
// std::cout << "RUNNING ";
// frames.back().function->dump(std::cout, frame.pc);
Instruction inst = frame.function->instructions_[frame.pc];
switch (inst.op) {
case ENTER: {
const auto& obj = peek(stack, 0, 1);
TORCH_INTERNAL_ASSERT(obj.isObject());
entered_objects.push_back(obj);
++frame.pc;
} break;
case EXIT: {
auto obj = entered_objects.back().toObject();
auto& f = obj->type()->getMethod("__exit__");
push(stack, std::move(obj));
entered_objects.pop_back();
push(stack, IValue());
push(stack, IValue());
push(stack, IValue());
runGraphFunction(stack, &f);
} break;
case OP:
frame.function->operator_table_[inst.X](&stack);
++frame.pc;
break;
case OPN:
stack.push_back(inst.N);
frame.function->operator_table_[inst.X](&stack);
++frame.pc;
break;
case LOAD:
stack.emplace_back(reg(inst.X));
++frame.pc;
break;
case MOVE:
stack.emplace_back(std::move(reg(inst.X)));
++frame.pc;
break;
case STORE:
reg(inst.X) = pop(stack);
++frame.pc;
break;
case STOREN:
for (size_t i = inst.N; i > 0; --i) {
reg(inst.X + i - 1) = pop(stack);
}
++frame.pc;
break;
case DROP:
pop(stack);
++frame.pc;
break;
case DROPR:
reg(inst.X) = IValue();
++frame.pc;
break;
case LOADC:
stack.emplace_back(frame.function->constant_table_[inst.X]);
++frame.pc;
break;
......
case CALL: {
Function* fn = frame.function->function_table_[inst.X];
if (!fn->isGraphFunction()) {
runBuiltinFunction(stack, fn);
} else {
runGraphFunction(stack, fn);
}
} break;
......
}
}
} catch (std::exception& e) {
......
}
}

void runGraphFunction(Stack& stack, Function* fn) {
const Code& code =
// consider passing
// `frames.back().function->remaining_bailout_depth_` into
// `get_executor().getPlanFor()` to propagate caller's depth
// restrictions onto children while this strategy has a
// potential to reduce the number of compilations for too
// dynamic callers we might miss opportunities where a caller is
// dynamic but a callee gets stable arguments
fn->get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code;
++frames.back().pc;
enterFrame(code, stack.size() - code.num_inputs());
checkAndStartRecordFunction(frames.back(), stack);
}

void enterFrame(const Code& code, size_t base_pointer) {
frames.emplace_back(Frame{code.pImpl, 0, base_pointer, c10::nullopt});
registers.resize(registers.size() + code.pImpl->register_size_);
}

这个函数虽然复杂而且比较核心,但是也容易理解,它只是模仿了 CPU 处理指令的行为:

  1. 从 stack 中获取指令,并将指令、pc 等信息存入当前 frame 中
  2. 获取到当前 pc 指向的指令
  3. 根据指令的类型进行不同的操作。例如,对于 OpCode::OP,获取 operator_table_ 中存放的 Operation,然后执行它;对于 LOAD 将 reg 中的变量放入 stack 中
  4. 遇到子图调用,push 当前 frame 到 frames

warm up

从上面的分析可以看出,LibTorch 是 lazy initialization 模式,在 Module 第一次 run 的时候才会运行图优化。所以在使用 LibTorch 的时候要注意进行 warm up。需要注意的是,用来保存 Graph 优化结果的 map 所使用的的 key 是 ArgumentSpec。这就要求在使用的时候尽量保持输入的一致,这里的一致是指输入 Tensor 的和 ArgumentSpec 相关的属性要保持一致。

总结

纵观整个 LibTorch 运行流程,LibTorch 实际上实现了一个编译器。保存模型阶段对应编译器的前端(语法分析、类型检查、中间代码生成)。运行模型阶段对应编译器后端(代码优化、目标代码生成、目标代码优化)。除此之外,LibTorch 还实现了一个可以运行该编译器所生成代码的解释器。

参考

PyTorch 源码解读之即时编译篇

TorchScript 如何实现Python -> C++ 代码转换

PyTorch JIT Source Code Read Note

pytorch/OVERVIEW.md

TorchScript 原理篇(下) - 运行模型 | PyTorch

http://www.zh0ngtian.tech/posts/ec218ec1.html

作者

zhongtian

发布于

2021-07-29

更新于

2023-12-16

许可协议

评论