Skip to content

Commit 58ab25e

Browse files
authored
[FFI] Add ffi::Tensor.strides() (#18276)
* ffi::Tensor strides
1 parent 3c36ce2 commit 58ab25e

File tree

5 files changed

+27
-7
lines changed

5 files changed

+27
-7
lines changed

ffi/examples/packaging/src/extension.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
* The library is written in C++ and can be compiled into a shared library.
2525
* The shared library can then be loaded into python and used to call the functions.
2626
*/
27+
#include <tvm/ffi/container/tensor.h>
2728
#include <tvm/ffi/dtype.h>
2829
#include <tvm/ffi/error.h>
2930
#include <tvm/ffi/function.h>
@@ -43,7 +44,7 @@ namespace ffi = tvm::ffi;
4344
*/
4445
void RaiseError(ffi::String msg) { TVM_FFI_THROW(RuntimeError) << msg; }
4546

46-
void AddOne(DLTensor* x, DLTensor* y) {
47+
void AddOne(ffi::Tensor x, ffi::Tensor y) {
4748
// implementation of a library function
4849
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
4950
DLDataType f32_dtype{kDLFloat, 32, 1};

ffi/examples/quick_start/src/add_one_cpu.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
19+
#include <tvm/ffi/container/tensor.h>
2020
#include <tvm/ffi/dtype.h>
2121
#include <tvm/ffi/error.h>
2222
#include <tvm/ffi/function.h>
2323

2424
namespace tvm_ffi_example {
2525

26-
void AddOne(DLTensor* x, DLTensor* y) {
26+
void AddOne(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
2727
// implementation of a library function
2828
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
2929
DLDataType f32_dtype{kDLFloat, 32, 1};

ffi/examples/quick_start/src/add_one_cuda.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19+
#include <tvm/ffi/container/tensor.h>
1920
#include <tvm/ffi/dtype.h>
2021
#include <tvm/ffi/error.h>
2122
#include <tvm/ffi/extra/c_env_api.h>
@@ -30,7 +31,7 @@ __global__ void AddOneKernel(float* x, float* y, int n) {
3031
}
3132
}
3233

33-
void AddOneCUDA(DLTensor* x, DLTensor* y) {
34+
void AddOneCUDA(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
3435
// implementation of a library function
3536
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
3637
DLDataType f32_dtype{kDLFloat, 32, 1};

ffi/include/tvm/ffi/container/tensor.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ class TensorObj : public Object, public DLTensor {
151151
protected:
152152
// backs up the shape/strides
153153
Optional<Shape> shape_data_;
154-
Optional<Shape> stride_data_;
154+
Optional<Shape> strides_data_;
155155

156156
static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
157157
TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
@@ -189,7 +189,7 @@ class TensorObjFromNDAlloc : public TensorObj {
189189
this->strides = const_cast<int64_t*>(strides.data());
190190
this->byte_offset = 0;
191191
this->shape_data_ = std::move(shape);
192-
this->stride_data_ = std::move(strides);
192+
this->strides_data_ = std::move(strides);
193193
alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
194194
}
195195

@@ -208,7 +208,7 @@ class TensorObjFromDLPack : public TensorObj {
208208
if (tensor_->dl_tensor.strides == nullptr) {
209209
Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
210210
this->strides = const_cast<int64_t*>(strides.data());
211-
this->stride_data_ = std::move(strides);
211+
this->strides_data_ = std::move(strides);
212212
}
213213
}
214214

@@ -244,6 +244,18 @@ class Tensor : public ObjectRef {
244244
}
245245
return *(obj->shape_data_);
246246
}
247+
/*!
248+
* \brief Get the strides of the Tensor.
249+
* \return The strides of the Tensor.
250+
*/
251+
tvm::ffi::Shape strides() const {
252+
TensorObj* obj = get_mutable();
253+
TVM_FFI_ICHECK(obj->strides != nullptr);
254+
if (!obj->strides_data_.has_value()) {
255+
obj->strides_data_ = tvm::ffi::Shape(obj->strides, obj->strides + obj->ndim);
256+
}
257+
return *(obj->strides_data_);
258+
}
247259
/*!
248260
* \brief Get the data type of the Tensor.
249261
* \return The data type of the Tensor.

ffi/tests/cpp/test_tensor.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,15 @@ inline Tensor Empty(Shape shape, DLDataType dtype, DLDevice device) {
3535
TEST(Tensor, Basic) {
3636
Tensor nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
3737
Shape shape = nd.shape();
38+
Shape strides = nd.strides();
3839
EXPECT_EQ(shape.size(), 3);
3940
EXPECT_EQ(shape[0], 1);
4041
EXPECT_EQ(shape[1], 2);
4142
EXPECT_EQ(shape[2], 3);
43+
EXPECT_EQ(strides.size(), 3);
44+
EXPECT_EQ(strides[0], 6);
45+
EXPECT_EQ(strides[1], 3);
46+
EXPECT_EQ(strides[2], 1);
4247
EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1}));
4348
for (int64_t i = 0; i < shape.Product(); ++i) {
4449
reinterpret_cast<float*>(nd->data)[i] = static_cast<float>(i);
@@ -47,6 +52,7 @@ TEST(Tensor, Basic) {
4752
Any any0 = nd;
4853
Tensor nd2 = any0.as<Tensor>().value();
4954
EXPECT_EQ(nd2.shape(), shape);
55+
EXPECT_EQ(nd2.strides(), strides);
5056
EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1}));
5157
for (int64_t i = 0; i < shape.Product(); ++i) {
5258
EXPECT_EQ(reinterpret_cast<float*>(nd2->data)[i], i);

0 commit comments

Comments
 (0)