framework::Tensor
表示的是張量。他的一些接口和使用方法
形狀相關的
- dims()
獲取每個維度的一些信息,可以使用接口dims()
比如
framework::Tensor test_tensor;
tensor.dims().size(); //表示有多少維
int batch_size = tensor.dims()[0]; //獲取某一維度的大小
- Resize(shape)
std::vector<int64_t> shape_vec({dim0, dim1, dim2});
framework::DDim shape(framework::make_ddim(shape_vec));
tensor.Resize(shape);
- numel()
int nums = tensor.numel(); //表示有多少個元素
- Slice(i, j), 這個注意只能是切最外層的,返回的是從i到j-1的子張量。(從0開始計算的)
tensor.Slice(i, j);
- 拼接
是用內容拷貝實現的,目前實現在一個op里,concat - 內存拷貝
auto xxstride = framework::stride(xx.dims());
StridedMemcpy<T>(context.template device_context<DeviceContext>(),
輸入的數據, 輸入stride, 拷貝部分dims(),
輸出stride, 輸出的指針);