@@ -151,7 +151,7 @@ class TensorObj : public Object, public DLTensor {
151151 protected:
152152 // backs up the shape/strides
153153 Optional<Shape> shape_data_;
154- Optional<Shape> stride_data_ ;
154+ Optional<Shape> strides_data_ ;
155155
156156 static void DLManagedTensorDeleter (DLManagedTensor* tensor) {
157157 TensorObj* obj = static_cast <TensorObj*>(tensor->manager_ctx );
@@ -189,7 +189,7 @@ class TensorObjFromNDAlloc : public TensorObj {
189189 this ->strides = const_cast <int64_t *>(strides.data ());
190190 this ->byte_offset = 0 ;
191191 this ->shape_data_ = std::move (shape);
192- this ->stride_data_ = std::move (strides);
192+ this ->strides_data_ = std::move (strides);
193193 alloc_.AllocData (static_cast <DLTensor*>(this ), std::forward<ExtraArgs>(extra_args)...);
194194 }
195195
@@ -208,7 +208,7 @@ class TensorObjFromDLPack : public TensorObj {
208208 if (tensor_->dl_tensor .strides == nullptr ) {
209209 Shape strides = Shape (details::MakeStridesFromShape (ndim, shape));
210210 this ->strides = const_cast <int64_t *>(strides.data ());
211- this ->stride_data_ = std::move (strides);
211+ this ->strides_data_ = std::move (strides);
212212 }
213213 }
214214
@@ -244,6 +244,18 @@ class Tensor : public ObjectRef {
244244 }
245245 return *(obj->shape_data_ );
246246 }
247+ /* !
248+ * \brief Get the strides of the Tensor.
249+ * \return The strides of the Tensor.
250+ */
251+ tvm::ffi::Shape strides () const {
252+ TensorObj* obj = get_mutable ();
253+ TVM_FFI_ICHECK (obj->strides != nullptr );
254+ if (!obj->strides_data_ .has_value ()) {
255+ obj->strides_data_ = tvm::ffi::Shape (obj->strides , obj->strides + obj->ndim );
256+ }
257+ return *(obj->strides_data_ );
258+ }
247259 /* !
248260 * \brief Get the data type of the Tensor.
249261 * \return The data type of the Tensor.
0 commit comments