0. 简介
c10::IValue
像一个数据容器,但是它又不用来直接存储数据,只是一层数据的封装。
怎么理解呢?c10::IValue
可以存储torchscript
里很多类型的数据,比如c10::IValue
存储可能是一个Tensor
,一组Tensor
,或者是一个Moudle
,甚至是一个int
,所以c10::IValue
更像是一种封装,对不同的数据类型进行了一次统一的封装,然后很多很多函数的接口都可以使用这种统一的数据类型了。
如果你用过opencv
,那么你可以觉得眼熟,cv::InputArray
,cv::OutputArray
不就是这么干的么,比如常用的cv::resize
函数,它的输入、输出数据就是cv::InputArray
,cv::OutputArray
,而不是直接使用cv::Mat
,这其实就是一种封装的思想。
void resize( InputArray src, OutputArray dst,Size dsize, double fx = 0, double fy = 0,int interpolation = INTER_LINEAR );
1. 类的构造
先看一下c10::IValue
的定义:
class c10::IValue {Payload payload;Tag tag;bool is_intrusive_ptr;}union Payload {// We use a nested union here so that we can make the copy easy// and efficient in the non-tensor (i.e., trivially copyable)// case. Specifically, we do not have to do a switch-on-tag to// figure out which union member to assign; we can just use// TriviallyCopyablePayload::operator=.union TriviallyCopyablePayload {TriviallyCopyablePayload() : as_int(0) {}int64_t as_int;double as_double;bool as_bool;// Invariant: never nullptr; null state is represented as// c10::UndefinedTensorImpl::singleton() for consistency of// representation with Tensor.c10::intrusive_ptr_target* as_intrusive_ptr;struct {DeviceType type;DeviceIndex index;} as_device;} u;at::Tensor as_tensor;Payload() : u() {}~Payload() {}};
c10::IValue
只有3个成员变量,一个用于存储数据的payload
,一个表示数据类型的tag
,还有一个指示是不是others类型的is_intrusive_ptr
,当然,还有很多很多成员函数,详情看这里,或者..../libtorch/include/ATen/core/ivalue.h
文件
Payload payload
:c10::Payload
是一个union
类型,c10::IValues
在IValue::Payload
中包含这些数据的值,它将基本类型(int64_t, bool, double, Device)
和Tensor
作为值,并将所有其他类型保存在c10::intrusive_ptr_target指针里边
。Tag tag
:c10::Tag
是一个enum
类型,表示c10::IValue
里保存的是什么类型数据,可以支持下面这些类型
#define TORCH_FORALL_TAGS(_) \_(None) \_(Tensor) \_(Storage) \_(Double) \_(ComplexDouble) \_(Int) \_(Bool) \_(Tuple) \_(String) \_(Blob) \_(GenericList) \_(GenericDict) \_(Future) \_(Device) \_(Stream) \_(Object) \_(PyObject) \_(Uninitialized) \_(Capsule) \_(RRef) \_(Quantizer) \_(Generator) \_(Enum)
bool is_intrusive_ptr
:一个bool值,是否为intrusive class
,这个intrusive class
是啥意思我也没太理解,大概可能就是非 [基本类型(int64_t, bool, double, Device)
和Tensor
],其他都是intrusive class
,比如Tuple
,String
之类的。如果为True的话,就得去c10::intrusive_ptr_target
指针里读取数据了。
2. 用法
c10::IValue
最主要的用法应该就是把数据取出来了,这一点从c10::IValue
的成员函数也能看出来,一大半函数都是isXXX,toXXX之类的,转化为其他类型
简单说几个用法:
- 判断
c10::IValue
里边存储的什么类型
c10::IValue a = torch::ones({1, 3, 640, 640});
auto b = a.type().get()->kind();
auto c = c10::typeKindToString(b);
std::cout << c << std::endl;
- 获取数据
使用c10::IValue::toXXX()
函数
torch::Tensor t = ivalue.toTensor(); //TensorType
bool t = ivalue.toBool(); //BoolType
auto t = ivalue.toList(); //ListType