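The Type System in TorchScript IR | PyTorch
Last updated: November 20, 2022
An earlier translation, "JIT Technical Overview", briefly introduced the main classes in TorchScript IR, including Graph, Node, Block, Value, and Type. This post maps out how these classes relate to one another, from the perspective of the C++ interface.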
Module::attributes()
The torch::jit::Module::attributes() method returns all of a module's attributes (its member variables). Every attribute value is ultimately represented as a torch::jit::IValue.
graph TD
module(torch::jit::Module)==>|"attributes()"|attributes(torch::jit::attribute_list)
attributes==>|"attribute()"|attribute(torch::jit::detail::AttributePolicy)
attribute==>|"attribute()"|attribute_(torch::jit::NameValue)
attribute_==>|"name"|name(std::string)
attribute_==>|"value"|value(torch::jit::IValue)
style value fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
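The sketch below makes the name/value pairing concrete without depending on libtorch. With libtorch itself, one would typically iterate module.named_attributes() and read each entry's name and value fields. This code uses hypothetical stand-ins instead:

- ToyIValue is a std::variant standing in for torch::jit::IValue.
- ToyNameValue models the shape of torch::jit::NameValue.
- find_attribute is a hypothetical lookup helper.

None of these are libtorch APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for torch::jit::IValue: one type-erased container
// that can hold any of several runtime value kinds.
using ToyIValue = std::variant<int64_t, double, bool, std::string>;

// Hypothetical stand-in shaped like torch::jit::NameValue:
// an attribute name paired with its value.
struct ToyNameValue {
  std::string name;
  ToyIValue value;
};

// Hypothetical stand-in for a module's attribute list.
using ToyAttributeList = std::vector<ToyNameValue>;

// Linear lookup by attribute name; returns std::nullopt when absent.
std::optional<ToyIValue> find_attribute(const ToyAttributeList& attrs,
                                        const std::string& name) {
  for (const auto& attr : attrs) {
    if (attr.name == name) {
      return attr.value;
    }
  }
  return std::nullopt;
}
```

Because every attribute shares the same container type, a single list can hold attributes of any kind. The caller recovers the concrete type only when it reads a value (here via std::get).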
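Module::get_methods()
The torch::jit::Module::get_methods() method returns all of a module's methods (member functions). The diagram below highlights the most important classes with a heavy border. Concrete data, such as the inputs and outputs of a statement or the parameters of a function, is ultimately represented in the IR as torch::jit::Value or torch::jit::IValue (drawn with dashed borders).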
graph TD
module(torch::jit::Module)==>|"get_methods()"|methods(std::vector of torch::jit::Method)
methods==>method(torch::jit::Method)
method==>|"graph()"|graph_(torch::jit::Graph)
graph_==>|"inputs() & outputs()"|ios(std::vector of torch::jit::Value)
ios==>io(torch::jit::Value)
graph_==>|"nodes()"|nodes(torch::jit::graph_node_list)
nodes==>|"node()"|node(torch::jit::Node)
node==>|"kind()"|kind(torch::jit::NodeKind)
kind==>|"ns()"|ns(torch::jit::Symbol // like an interned string)
node==>|"inputs() & outputs()"|ios
node==>|"blocks()"|blocks(std::vector of torch::jit::Block)
blocks==>block(torch::jit::Block)
block==>|"nodes()"|nodes
method==>|"function()"|function(torch::jit::Function)
function==>|"getSchema()"|schema(c10::FunctionSchema)
function==>|"torch::jit::toGraphFunction()"|graph_function(torch::jit::GraphFunction)
graph_function==>|"graph()"|graph_
schema==>|"arguments() & returns()"|ars(std::vector of c10::Argument)
ars==>ar(c10::Argument)
ar==>|"name()"|name(std::string)
ar==>|"default_value()"|ivalue(c10::optional of torch::jit::IValue)
style graph_ fill:#fff,stroke:#000,stroke-width:4px
style io fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
style node fill:#fff,stroke:#000,stroke-width:4px
style block fill:#fff,stroke:#000,stroke-width:4px
style ivalue fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
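The Graph → Node → Block → Node cycle in the diagram means that walking a method's IR is naturally recursive. Control-flow nodes such as prim::If own sub-blocks, and those sub-blocks contain further nodes.

The minimal sketch below shows that recursion on hypothetical ToyNode and ToyBlock structs, not the real classes. With libtorch, the same walk would start roughly from method.graph()->nodes() and descend through node->blocks().

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

struct ToyBlock;

// Hypothetical, heavily simplified stand-in for torch::jit::Node.
struct ToyNode {
  std::string kind;                               // e.g. "aten::add", "prim::If"
  std::vector<std::shared_ptr<ToyBlock>> blocks;  // sub-blocks owned by control flow
};

// Hypothetical stand-in for torch::jit::Block: an ordered list of nodes.
struct ToyBlock {
  std::vector<ToyNode> nodes;
};

// Count every node reachable from a block, descending into the sub-blocks of
// control-flow nodes (the Block -> Node -> Block recursion in the diagram).
std::size_t count_nodes(const ToyBlock& block) {
  std::size_t total = 0;
  for (const auto& node : block.nodes) {
    ++total;
    for (const auto& sub : node.blocks) {
      total += count_nodes(*sub);
    }
  }
  return total;
}

// Collect node kinds in pre-order: a node first, then the contents of its blocks.
void collect_kinds(const ToyBlock& block, std::vector<std::string>& out) {
  for (const auto& node : block.nodes) {
    out.push_back(node.kind);
    for (const auto& sub : node.blocks) {
      collect_kinds(*sub, out);
    }
  }
}
```

A top-level block with one aten::add node and one prim::If node yields five nodes in total when the two branches contain one and two nodes respectively. In pre-order, the branch contents appear right after the If node.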
torch::jit::Value
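torch::jit::Value is one of the central data classes in TorchScript IR and is connected to many other classes. The diagram below shows how several of these data classes relate to one another.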
graph TD
node(torch::jit::Node)==>|"inputs() & outputs()"|ios(std::vector of torch::jit::Value)
ios==>io(torch::jit::Value)
io==>|"node()"|node
io==>|"uses()"|uses(std::vector of torch::jit::Use)
uses==>use(torch::jit::Use)
use==>|"user"|node
style io fill:#fff,stroke:#000,stroke-width:4px,stroke-dasharray:5,5
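The node()/uses() pair in the diagram forms a def-use chain. Each Value knows the node that produces it, and each Use records a consuming node plus the input slot (offset) the value occupies there.

The following sketch models this with hypothetical ToyValue, ToyNode, and ToyUse types and shows why the offset is useful. A rewrite in the spirit of Value::replaceAllUsesWith can patch every consumer in place without searching its inputs. This is an illustration under those assumptions, not the libtorch implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

struct ToyNode;

// Hypothetical stand-in shaped like torch::jit::Use:
// which node consumes a value, and at which input position.
struct ToyUse {
  ToyNode* user;
  std::size_t offset;
};

// Hypothetical stand-in for torch::jit::Value.
struct ToyValue {
  std::string debug_name;
  ToyNode* node = nullptr;   // the node that produces this value
  std::vector<ToyUse> uses;  // every place this value is consumed
};

// Hypothetical stand-in for torch::jit::Node.
struct ToyNode {
  std::string kind;
  std::vector<ToyValue*> inputs;
  std::vector<ToyValue*> outputs;
};

// Append v as the next input of n and record the use on v.
void add_input(ToyNode& n, ToyValue& v) {
  v.uses.push_back(ToyUse{&n, n.inputs.size()});
  n.inputs.push_back(&v);
}

// Register v as an output of n, making n its producer.
void add_output(ToyNode& n, ToyValue& v) {
  v.node = &n;
  n.outputs.push_back(&v);
}

// Redirect every use of `from` to `to`: each recorded (user, offset) pair
// tells us exactly which input slot to overwrite.
void replace_all_uses_with(ToyValue& from, ToyValue& to) {
  if (&from == &to) {
    return;  // nothing to do; also avoids mutating the list we iterate
  }
  for (const ToyUse& u : from.uses) {
    u.user->inputs[u.offset] = &to;
    to.uses.push_back(u);
  }
  from.uses.clear();
}
```

When a node consumes the same value twice (for example a * a), the value gets two Use entries with offsets 0 and 1. Both entries are redirected by the replacement.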
Every node input and output in the IR appears as a torch::jit::Value. Through type(), each Value carries a type from TorchScript's rich type system.
graph LR
value(torch::jit::Value)==>|"debugName()"|debug_name(std::string)
value==>|"type()"|type(c10::TypePtr)
type==>list(c10::ListType)
type==>dict(c10::DictType)
type==>class_(c10::ClassType)
type==>tensor(c10::TensorType)
type==>int(c10::IntType)
type==>function(c10::FunctionType)
type==>type_etc(......)
list==>|"kind()"|kind_list(c10::TypeKind::ListType)
list==>|"getElementType()"|element_type(c10::TypePtr)
dict==>|"kind()"|kind_dict(c10::TypeKind::DictType)
dict==>|"getKeyType()"|key_type(c10::TypePtr)
dict==>|"getValueType()"|value_type(c10::TypePtr)
class_==>|"kind()"|kind_class_(c10::TypeKind::ClassType)
class_==>|"name"|name(std::string)
tensor==>|"kind()"|kind_tensor(c10::TypeKind::TensorType)
tensor==>|"sizes()"|sizes(c10::VaryingShape of int64_t)
tensor==>|"scalarType()"|scalar_type(c10::optional of at::ScalarType)
tensor==>etc_tensor(......)
int==>|"kind()"|kind_int(c10::TypeKind::IntType)
function==>|"expect()"|func_type_ptr(c10::FunctionTypePtr)
func_type_ptr==>|"function()"|jit_function(torch::jit::Function)
jit_function==>|"torch::jit::toGraphFunction()"|graph_function(torch::jit::GraphFunction)
graph_function==>|"graph()"|graph_(torch::jit::Graph)
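Container types such as ListType and DictType hold further TypePtrs (element, key, value). Types therefore form a tree that is usually inspected recursively by switching on kind().

The minimal sketch below assumes a hypothetical ToyType with a ToyTypeKind tag and a vector of contained types. It renders such a tree in the Python annotation style (List[int], Dict[str, Tensor]); the exact strings printed by c10's own type-printing methods may differ.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical subset of type kinds, loosely mirroring c10::TypeKind.
enum class ToyTypeKind { IntType, StringType, TensorType, ListType, DictType };

struct ToyType;
using ToyTypePtr = std::shared_ptr<const ToyType>;  // analogous to c10::TypePtr

// Hypothetical type node: List stores {element}, Dict stores {key, value}.
struct ToyType {
  ToyTypeKind kind;
  std::vector<ToyTypePtr> contained;
};

ToyTypePtr make_leaf(ToyTypeKind kind) {
  return std::make_shared<ToyType>(ToyType{kind, {}});
}

ToyTypePtr make_list(ToyTypePtr elem) {
  return std::make_shared<ToyType>(ToyType{ToyTypeKind::ListType, {std::move(elem)}});
}

ToyTypePtr make_dict(ToyTypePtr key, ToyTypePtr value) {
  return std::make_shared<ToyType>(
      ToyType{ToyTypeKind::DictType, {std::move(key), std::move(value)}});
}

// Recursively render a type by dispatching on its kind, the same pattern one
// would use with kind() / getElementType() / getKeyType() / getValueType().
std::string annotation_string(const ToyTypePtr& t) {
  switch (t->kind) {
    case ToyTypeKind::IntType:
      return "int";
    case ToyTypeKind::StringType:
      return "str";
    case ToyTypeKind::TensorType:
      return "Tensor";
    case ToyTypeKind::ListType:
      return "List[" + annotation_string(t->contained[0]) + "]";
    case ToyTypeKind::DictType:
      return "Dict[" + annotation_string(t->contained[0]) + ", " +
             annotation_string(t->contained[1]) + "]";
  }
  return "<unknown>";
}
```

Types are shared, immutable objects (hence the pointer-to-const), so a sub-type like Tensor can appear in many containers without being copied.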
torch::jit::IValue
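One difference between torch::jit::IValue and torch::jit::Value is that IValue mainly appears while a program is running, whereas Value exists only in the IR. torch::jit::IValue is the container for every kind of data that appears at run time.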
graph LR
ivalue(torch::jit::IValue)==>|"type()"|type(c10::TypePtr)
ivalue==>|"payload()"|payload(torch::jit::IValue::Payload)
payload==>|"payload()"|nt_payload(torch::jit::IValue::Payload::TriviallyCopyablePayload)
nt_payload==>|"as_int()"|as_int(int64_t)
nt_payload==>|"as_double()"|as_double(double)
nt_payload==>|"as_bool()"|as_bool(bool)
nt_payload==>|"as_intrusive_ptr()"|as_intrusive_ptr(c10::intrusive_ptr_target)
payload==>|"as_tensor()"|as_tensor(at::Tensor // Tensor's actual implementation)
ivalue==>|"tag()"|tag(torch::jit::IValue::Tag)
tag==>tag_tensor(torch::jit::IValue::Tag::Tensor)
tag==>tag_int(torch::jit::IValue::Tag::Int)
tag==>tag_etc(......)
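IValue's tag + payload layout is a classic tagged union. The tag records which union member is live, and typed accessors check the tag before reading.

The sketch below models only the trivially copyable part of the payload (int, double, bool) in a hypothetical ToyIValue class. The real IValue also stores at::Tensor and intrusive_ptr-managed objects in its payload and maintains their reference counts, which this sketch omits. The accessor names follow IValue's isInt()/toInt() convention, but the class itself is an illustration, not the libtorch code.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical, simplified tagged union in the style of torch::jit::IValue.
class ToyIValue {
 public:
  enum class Tag { None, Int, Double, Bool };

  ToyIValue() : tag_(Tag::None) { payload_.as_int = 0; }
  explicit ToyIValue(int64_t v) : tag_(Tag::Int) { payload_.as_int = v; }
  // Plain int literals would otherwise be ambiguous between int64_t/double/bool.
  explicit ToyIValue(int v) : ToyIValue(static_cast<int64_t>(v)) {}
  explicit ToyIValue(double v) : tag_(Tag::Double) { payload_.as_double = v; }
  explicit ToyIValue(bool v) : tag_(Tag::Bool) { payload_.as_bool = v; }

  Tag tag() const { return tag_; }
  bool isInt() const { return tag_ == Tag::Int; }
  bool isDouble() const { return tag_ == Tag::Double; }
  bool isBool() const { return tag_ == Tag::Bool; }

  // Checked accessors: reading the wrong union member throws instead of
  // silently reinterpreting bits.
  int64_t toInt() const { check(Tag::Int); return payload_.as_int; }
  double toDouble() const { check(Tag::Double); return payload_.as_double; }
  bool toBool() const { check(Tag::Bool); return payload_.as_bool; }

 private:
  void check(Tag expected) const {
    if (tag_ != expected) {
      throw std::runtime_error("ToyIValue: tag mismatch");
    }
  }

  // Only one member is meaningful at a time; tag_ says which.
  union Payload {
    int64_t as_int;
    double as_double;
    bool as_bool;
  } payload_;
  Tag tag_;
};
```

Keeping the payload in a union makes every value the same small size regardless of its kind. This is why one container type can flow through the interpreter's stack for all data.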