From b362f86038b1a5dbb31a659e82b622799697b585 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 14 Mar 2024 09:57:05 -0500 Subject: [PATCH 1/3] [TIR] LowerTVMBuiltin may use device_type from PrimFunc annotation If an allocation occurs within a host function, it may not have a device/host split. --- src/tir/transforms/lower_tvm_builtin.cc | 36 +++++++++++++----- .../test_tir_transform_lower_tvm_builtin.py | 37 +++++++++++++++++-- 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 6da2f873b728..486830e90a50 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -38,6 +38,19 @@ namespace tir { // These information are needed during codegen. class BuiltinLower : public StmtExprMutator { public: + static PrimFunc Build(PrimFunc func) { + Optional device_type = NullOpt; + if (auto target = func->GetAttr(tvm::attr::kTarget)) { + device_type = Integer(target.value()->kind->default_device_type); + } + + BuiltinLower mutator(device_type); + func.CopyOnWrite()->body = mutator.VisitBodyAndRealizeAlloca(func->body); + return func; + } + + BuiltinLower(Optional device_type = NullOpt) : device_type_(device_type) {} + // NOTE: Right now, we make the following scoping requirement // for memory allocated by the following primitives // - tvm_stack_make_array @@ -284,13 +297,17 @@ class BuiltinLower : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::device_id) { - ICHECK(!device_id_); + auto cache = device_id_; device_id_ = op->value; - return this->VisitStmt(op->body); + Stmt out = this->VisitStmt(op->body); + device_id_ = cache; + return out; } else if (op->attr_key == attr::device_type) { - ICHECK(!device_type_); + auto cache = device_type_; device_type_ = op->value; - return this->VisitStmt(op->body); + Stmt out = this->VisitStmt(op->body); + device_type_ = cache; + return out; } else { return StmtExprMutator::VisitStmt_(op); } @@ -656,13 +673,12 @@ class BuiltinLower : public StmtExprMutator { namespace transform { Pass LowerTVMBuiltin() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - if (IsHostFunc(f).value_or(false)) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - f.CopyOnWrite()->body = BuiltinLower().Build(f->body); - VLOG(2) << "LowerTVMBuiltin: " << f; + auto pass_func = [](PrimFunc func, IRModule m, PassContext ctx) { + if (IsHostFunc(func).value_or(false)) { + func = BuiltinLower::Build(func); + VLOG(2) << "LowerTVMBuiltin: " << func; } - return f; + return func; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index de1020ef2078..754ce032404d 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -260,11 +260,13 @@ def expected(): class TestLowerAllocateRequiresDeviceID(tvm.testing.CompareBeforeAfter): + """If device id is missing, error.""" + transform = tvm.tir.transform.LowerTVMBuiltin() def before(): T.func_attr({"target": T.target("llvm")}) - T.attr("dummy", "device_id", 0) + T.attr("dummy", "device_type", 2) # kDLCuda ptr = T.allocate([16], "float32") buf = T.decl_buffer(16, "float32", data=ptr) buf[0] = 0.0 @@ -273,16 +275,45 @@ def before(): class 
TestLowerAllocateRequiresDeviceType(tvm.testing.CompareBeforeAfter): + """If device type is missing, error. + + The device type can be inferred either from the `"device_type"` + statement attribute, or from the `"target"` function attribute. + Here, we provide neither. The `"tir.is_host_func"` attribute is + provided as otherwise the function would be skipped altogether by + LowerTVMBuiltin. + """ + transform = tvm.tir.transform.LowerTVMBuiltin() def before(): - T.func_attr({"target": T.target("llvm")}) + T.func_attr({"tir.is_host_func": True}) T.attr("dummy", "device_id", 0) + ptr = T.allocate([1024 * 1024], "float32") + buf = T.decl_buffer(1024 * 1024, "float32", data=ptr) + buf[0] = 0.0 + + expected = tvm.TVMError + + +class TestLowerCPUAllocWithFunctionAttr(tvm.testing.CompareBeforeAfter): + """CPU allocations can be handled at codegen time + + Like `TestLowerCPUAllocation`, but the device type is taken from + the function attribute. The `AttrStmt` can override the device + type for allocations within its scope, but it defaults to the + function's target. + """ + + transform = tvm.tir.transform.LowerTVMBuiltin() + + def before(): + T.func_attr({"target": T.target("llvm")}) ptr = T.allocate([16], "float32") buf = T.decl_buffer(16, "float32", data=ptr) buf[0] = 0.0 - expected = tvm.TVMError + expected = before if __name__ == "__main__": From 3145b353b8b98eb231b1a74d314c16dd300cc1a9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 15 Mar 2024 14:53:16 -0500 Subject: [PATCH 2/3] lint fix --- src/tir/transforms/lower_tvm_builtin.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 486830e90a50..1a3888a7cd48 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -49,7 +49,7 @@ class BuiltinLower : public StmtExprMutator { return func; } - BuiltinLower(Optional device_type = NullOpt) : device_type_(device_type) {} + explicit BuiltinLower(Optional device_type = NullOpt) : device_type_(device_type) {} // NOTE: Right now, we make the following scoping requirement // for memory allocated by the following primitives From 3af29e5683bed37898f42bd9436d0f4ed4f12bb4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 14 Mar 2024 09:58:13 -0500 Subject: [PATCH 3/3] [Relax] Implement operators to inspect DLTensor::strides and offset A follow-up PR to https://github.com/apache/tvm/pull/16563. This PR implements similar operators to inspect the runtime values of `DLTensor::strides` and `DLTensor::byte_offset`. In addition, while the element offset is not explicitly present in the `DLTensor` struct, a Relax operator is implemented to infer it from the `byte_offset` and `data_type` fields, for use when interacting with the TIR `BufferNode::elem_offset` field.
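For illustration, a condensed usage sketch, mirroring the tests added in this patch (the module and function names here are only examples, not part of the TVM API): the new accessors are exposed on relax tensor expressions as `A.strides[axis]`, `A.byte_offset`, and `A.elem_offset`, and legalize to host-side PrimFuncs that read the DLTensor fields at runtime.

    # Sketch based on test_strides_of_compact_tensor below; assumes a
    # host-side build with the default CPU target.
    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class mod:
        @R.function
        def main(A: R.Tensor, axis: R.Prim("int64")):
            # Reads DLTensor::strides[axis] at runtime; if the strides
            # pointer is NULL, the stride is computed from the shape.
            return A.strides[axis]

    vm = relax.VirtualMachine(relax.build(mod), tvm.cpu())
    arg = tvm.nd.empty([128, 256], "int32")
    assert vm["main"](arg, 0) == 256  # compact layout: stride of axis 0

`A.byte_offset` and `A.elem_offset` follow the same pattern, taking only the tensor; `elem_offset` is computed at runtime as `byte_offset // bytes_per_element`.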
--- python/tvm/relax/expr.py | 97 +++++++ .../relax/transform/legalize_ops/__init__.py | 1 + .../transform/legalize_ops/inspect_op.py | 128 +++++++++ src/relax/op/tensor/inspect.cc | 180 ++++++++++--- src/relax/op/tensor/inspect.h | 39 +++ tests/python/relax/test_op_inspect.py | 252 ++++++++++++++++++ tests/python/relax/test_op_unpack.py | 127 --------- 7 files changed, 667 insertions(+), 157 deletions(-) create mode 100644 python/tvm/relax/transform/legalize_ops/inspect_op.py create mode 100644 tests/python/relax/test_op_inspect.py delete mode 100644 tests/python/relax/test_op_unpack.py diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 12f08f4dbf1a..4dca710e7781 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -280,6 +280,33 @@ def shape(self) -> "_DLTensorShapeProxy": self._check_for_tensor_struct_info() return _DLTensorShapeProxy(self) + @property + def strides(self) -> "_DLTensorStrideProxy": + """Returns a proxy object for accessing DLTensor::strides""" + self._check_for_tensor_struct_info() + return _DLTensorStrideProxy(self) + + @property + def byte_offset(self) -> "Expr": + """Returns a proxy object for accessing DLTensor::byte_offset""" + self._check_for_tensor_struct_info() + op = tvm.ir.Op.get("relax.inspect.tensor_byte_offset") + return tvm.relax.Call(op, [self]) + + @property + def elem_offset(self) -> "Expr": + """Returns a proxy object for accessing a DLTensor's elem_offset + + This parameter is not stored in the DLTensor, but is instead + derived from the DLTensor's byte offset and datatype. This is + exposed in Relax for ease of use, and for translation into the + `tir::BufferNode::elem_offset` field when interacting with TIR + buffers. + """ + self._check_for_tensor_struct_info() + op = tvm.ir.Op.get("relax.inspect.tensor_elem_offset") + return tvm.relax.Call(op, [self]) + class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric): """A proxy object for unpacking DLDatatype from DLTensor @@ -431,6 +458,76 @@ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr: return tvm.relax.Call(op, [self.tensor, axis]) +class _DLTensorStrideProxy(tvm.runtime.ObjectGeneric): + """A proxy object for unpacking the strides from DLTensor + + Exposes accessors for the `DLTensor::strides` field. Accessing + these fields will produce `relax.Call` expressions, representing + the field's runtime value. If the datatype of the tensor is known + at compile-time, the `relax.Call` will be normalized into a + `relax.PrimValue`, with no runtime cost. + + Parameters + ---------- + tensor: relax.Expr + + The relax tensor (or a variable referring to a relax tensor), + whose runtime strides is being inspected. + """ + + def __init__(self, tensor): + self.tensor = tensor + + def asobject(self): + """Provide expected in error message + + This method is called when `_DLTensorStrideProxy` is used in a + context that requires a `relax.Expr`. This usage is not + supported, and raising an error here can provide suggested + fixes that are not present in the default error message from + `tvm.runtime.convert_to_object`. + """ + raise TypeError( + f"{self.tensor}.strides cannot be converted to a relax expression, " + f"and should be used as a proxy object to access the runtime strides of the DLTensor. 
" + f"The DLTensor::ndim field can be accessed as len({self.tensor}), " + f"and the DLTensor::strides array can be accessed as {self.tensor}.strides[i]" + ) + + def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr: + """Returns the extent of a tensor axis + + Parameters + ---------- + axis: Union[int, PrimExpr, Expr] + + The tensor axis whose extent should be returned. For ease + of use, any python integers or TIR expressions are + converted to `relax.Expr`. + + Returns + ------- + extent: Expr + + The extent of the tensor's axis. + """ + + if not isinstance(axis, tvm.relax.Expr): + axis = tvm.relax.PrimValue(axis) + + if axis.struct_info_ is not None and not isinstance( + axis.struct_info_, tvm.relax.PrimStructInfo + ): + raise TypeError( + f"The index used to access {self.tensor}.strides " + f'must have struct info R.Prim("int64"), ' + f"but index {axis} had struct info {axis.struct_info_}." + ) + + op = tvm.ir.Op.get("relax.inspect.tensor_stride_i") + return tvm.relax.Call(op, [self.tensor, axis]) + + @tvm._ffi.register_object("relax.expr.Call") class Call(ExprWithOp): """Function call node in Relax. diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py index e3b3213a38b5..b4aba0291fc1 100644 --- a/python/tvm/relax/transform/legalize_ops/__init__.py +++ b/python/tvm/relax/transform/legalize_ops/__init__.py @@ -23,6 +23,7 @@ from . import grad from . import image from . import index +from . import inspect_op from . import linear_algebra from . import manipulate from . import nn diff --git a/python/tvm/relax/transform/legalize_ops/inspect_op.py b/python/tvm/relax/transform/legalize_ops/inspect_op.py new file mode 100644 index 000000000000..5f1b36667a52 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/inspect_op.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Legalization functions for DLTensor inspection.""" + +import enum + +from tvm.script import tir as T + +from ...block_builder import BlockBuilder +from ...expr import Call, Expr +from .common import register_legalize + + +class TVMStructFieldKind(enum.IntEnum): + """Equivalent to tvm::tir::builtin::TVMStructFieldKind + + This does not use `enum.auto()` to define the values, because + `enum.auto()` starts from 1, and this must match the C++ + definition which starts from 0. 
+ """ + + kArrAddr = 0 + kArrData = 1 + kArrShape = 2 + kArrStrides = 3 + kArrNDim = 4 + kArrTypeCode = 5 + kArrTypeBits = 6 + kArrTypeLanes = 7 + kArrByteOffset = 8 + kArrDeviceId = 9 + kArrDeviceType = 10 + kArrKindBound_ = 11 + kTVMValueContent = 12 + kTVMValueKindBound_ = 13 + + +@register_legalize("relax.inspect.tensor_stride_i") +def _tensor_stride_i(bb: BlockBuilder, call: Call) -> Expr: + @T.prim_func(private=True) + def _get_tensor_stride_i(dlpack_handle: T.handle, axis: T.int64) -> T.int64: + T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)}) + assert T.int64(0) <= axis, "Specified axis may not be negative" + ndim: T.int32 = T.tvm_struct_get( + dlpack_handle, 0, int(TVMStructFieldKind.kArrNDim), "int32" + ) + assert axis < T.Cast( + "int64", ndim + ), "Specified axis may not be larger than the tensor's dimensionality" + stride_ptr: T.handle("int64") = T.tvm_struct_get( + dlpack_handle, 0, int(TVMStructFieldKind.kArrStrides), "handle" + ) + + if T.isnullptr(stride_ptr): + shape_ptr: T.handle("int64") = T.tvm_struct_get( + dlpack_handle, 0, int(TVMStructFieldKind.kArrShape), "handle" + ) + shape = T.decl_buffer(ndim, "int64", data=shape_ptr) + + product = T.decl_buffer([], "int64") + product[()] = 1 + + # TODO(Lunderberg): Add a TIR lowering pass to allow + # ranges to start somewhere other than zero. This loop + # could then iterate on `range(axis+1, ndim)`. + for dim_offset in range(ndim - (axis + 1)): + dim = dim_offset + (axis + 1) + product[()] = product[()] * shape[dim] + + return product[()] + else: + strides = T.decl_buffer(ndim, "int64", data=stride_ptr) + stride: T.int64 = strides[axis] + return stride + + gvar = bb.add_func(_get_tensor_stride_i, "_get_tensor_stride_i") + return Call(gvar, call.args) + + +@register_legalize("relax.inspect.tensor_byte_offset") +def _tensor_byte_offset(bb: BlockBuilder, call: Call) -> Expr: + @T.prim_func(private=True) + def _get_tensor_byte_offset(dlpack_handle: T.handle) -> T.int64: + T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)}) + byte_offset: T.uint64 = T.tvm_struct_get( + dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64" + ) + return byte_offset + + gvar = bb.add_func(_get_tensor_byte_offset, "_get_tensor_byte_offset") + return Call(gvar, call.args) + + +@register_legalize("relax.inspect.tensor_elem_offset") +def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr: + @T.prim_func(private=True) + def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64: + T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)}) + byte_offset: T.uint64 = T.tvm_struct_get( + dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64" + ) + scalar_bits: T.uint8 = T.tvm_struct_get( + dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeBits), "uint8" + ) + lanes: T.uint16 = T.tvm_struct_get( + dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeLanes), "uint16" + ) + bytes_per_element = T.ceildiv(scalar_bits.astype("uint64") * lanes.astype("uint64"), 8) + elem_offset = byte_offset // bytes_per_element + return elem_offset + + gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset") + return Call(gvar, call.args) diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index a40b2af5eff4..186fc9fa8690 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -29,6 +29,8 @@ #include #include +#include + namespace tvm { namespace relax { namespace inspect { @@ -50,6 +52,42 @@ TensorStructInfo 
GetTensorArgInfo(const Call& call) { return tensor_sinfo.value(); } +std::tuple GetTensorArgInfoWithIndex(const Call& call) { + CHECK_EQ(call->args.size(), 2) << "TypeError: " + << "Operator " << call->op << " expects two arguments, " + << "but received " << call->args.size() + << " arguments: " << call->args; + const auto& arg = call->args[0]; + const auto& axis = call->args[1]; + + auto tensor_sinfo = arg->struct_info_.as(); + CHECK(tensor_sinfo) << "TypeError: " + << "Operator " << call->op << " expects arguments (tensor, axis), " + << "but the first argument " << arg << " in expression " << call + << " has struct info " << arg->struct_info_; + + auto axis_sinfo = axis->struct_info_.as(); + CHECK(axis_sinfo) << "TypeError: " + << "Operator " << call->op << " expects arguments (tensor, axis), " + << "but the second argument " << arg << " in expression " << call + << " has struct info " << axis->struct_info_; + + auto int_imm_axis = axis_sinfo->value.as(); + + if (int_imm_axis) { + CHECK_GE(int_imm_axis->value, 0); + } + if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) { + CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim) + << "ValueError: " + << "Expression " << call << " attempts to access " << arg << ".shape[" + << int_imm_axis->value << "]" + << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " elements"; + } + + return {GetRef(tensor_sinfo), GetRef(axis_sinfo)}; +} + DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; } tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType field_dtype) { @@ -244,39 +282,11 @@ Expr tensor_shape_i(Expr expr) { StructInfo InferStructInfoTensorShape(const Call& call, const BlockBuilder&) { auto dlpack_type = DataType::Int(64); - CHECK_EQ(call->args.size(), 2) << "TypeError: " - << "Operator " << call->op << " expects two arguments, " - << "but received " << call->args.size() - << " arguments: " << call->args; - const auto& arg = call->args[0]; - const auto& axis = call->args[1]; - - auto tensor_sinfo = arg->struct_info_.as(); - CHECK(tensor_sinfo) << "TypeError: " - << "Operator " << call->op << " expects arguments (tensor, axis), " - << "but the first argument " << arg << " in expression " << call - << " has struct info " << arg->struct_info_; - - auto axis_sinfo = axis->struct_info_.as(); - CHECK(axis_sinfo) << "TypeError: " - << "Operator " << call->op << " expects arguments (tensor, axis), " - << "but the second argument " << arg << " in expression " << call - << " has struct info " << axis->struct_info_; + auto [tensor_sinfo, axis_sinfo] = GetTensorArgInfoWithIndex(call); + auto tensor_shape = tensor_sinfo->GetShape(); auto int_imm_axis = axis_sinfo->value.as(); - if (int_imm_axis) { - CHECK_GE(int_imm_axis->value, 0); - } - if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) { - CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim) - << "ValueError: " - << "Expression " << call << " attempts to access " << arg << ".shape[" - << int_imm_axis->value << "]" - << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " elements"; - } - - auto tensor_shape = tensor_sinfo->GetShape(); if (int_imm_axis && tensor_shape.defined()) { return PrimStructInfo(tensor_shape.value()[int_imm_axis->value]); } else { @@ -346,6 +356,116 @@ TVM_REGISTER_OP("relax.inspect.tensor_shape_i") .set_attr("FNormalize", NormalizeToKnownPrimValue) .set_attr("FPurity", Bool(true)); +//// relax.tensor_stride_i + +Expr tensor_stride_i(Expr expr) { + static const Op& op = 
Op::Get("relax.inspect.tensor_stride_i"); + return Call(op, {expr}); +} + +StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) { + auto dlpack_type = DataType::Int(64); + + auto [tensor_sinfo, axis_sinfo] = GetTensorArgInfoWithIndex(call); + + auto opt_tensor_shape = tensor_sinfo->GetShape(); + auto int_imm_axis = axis_sinfo->value.as(); + + if (int_imm_axis && opt_tensor_shape.defined()) { + // As of 2024-03-14, Relax does not have an explicit + // representation for striding in `TensorStructInfo`. The + // `FLegalize` function for most operators is implemented in terms + // of `topi`, and is then converted from TE to `tir::PrimFunc` + // using `tvm::tir::CreatePrimFunc`. The `te::Tensor` is + // converted to a `tir::Buffer` in `RewriteStageToBlock`, and uses + // the default empty list for the strides. The empty strides + // represent a compact data array. + // + // Therefore, while Relax does not explicitly represent the + // striding of a tensor, it implicitly requires compact striding + // for any legalizable Tensor. + auto tensor_shape = opt_tensor_shape.value(); + PrimExpr stride = IntImm(DataType::Int(64), 1); + for (size_t axis = int_imm_axis->value + 1; axis < tensor_shape.size(); axis++) { + stride = stride * tensor_shape[axis]; + } + return PrimStructInfo(stride); + } else { + return PrimStructInfo(dlpack_type); + } +} + +TVM_REGISTER_OP("relax.inspect.tensor_stride_i") + .set_num_inputs(2) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .add_argument("axis", "Prim(int64)", "The axis whose extent should be returned") + .set_attr("FInferStructInfo", InferStructInfoTensorStride) + .set_attr("RequiresArgumentShapes", Bool(false)) + .set_attr("FNormalize", NormalizeToKnownPrimValue) + .set_attr("FPurity", Bool(true)); + +//// relax.tensor_byte_offset + +Expr tensor_byte_offset(Expr expr) { + static const Op& op = Op::Get("relax.inspect.tensor_byte_offset"); + return Call(op, {expr}); +} + +StructInfo InferStructInfoTensorByteOffset(const Call& call, const BlockBuilder&) { + auto dlpack_type = DataType::UInt(64); + + auto tensor_sinfo = GetTensorArgInfo(call); + + auto opt_tensor_shape = tensor_sinfo->GetShape(); + if (opt_tensor_shape.defined()) { + // Relax implicitly requires that the byte offset is zero for any + // legalizable tensor. See InferStructInfoTensorStride for full + // explanation. + return PrimStructInfo(IntImm(dlpack_type, 0)); + } else { + return PrimStructInfo(dlpack_type); + } +} + +TVM_REGISTER_OP("relax.inspect.tensor_byte_offset") + .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .set_attr("FInferStructInfo", InferStructInfoTensorByteOffset) + .set_attr("RequiresArgumentShapes", Bool(false)) + .set_attr("FNormalize", NormalizeToKnownPrimValue) + .set_attr("FPurity", Bool(true)); + +//// relax.tensor_elem_offset + +Expr tensor_elem_offset(Expr expr) { + static const Op& op = Op::Get("relax.inspect.tensor_elem_offset"); + return Call(op, {expr}); +} + +StructInfo InferStructInfoTensorElemOffset(const Call& call, const BlockBuilder&) { + auto dlpack_type = DataType::UInt(64); + + auto tensor_sinfo = GetTensorArgInfo(call); + + auto opt_tensor_shape = tensor_sinfo->GetShape(); + if (opt_tensor_shape.defined()) { + // Relax implicitly requires that the element offset is zero for + // any legalizable tensor. See InferStructInfoTensorStride for + // full explanation. 
+ return PrimStructInfo(IntImm(dlpack_type, 0)); + } else { + return PrimStructInfo(dlpack_type); + } +} + +TVM_REGISTER_OP("relax.inspect.tensor_elem_offset") + .set_num_inputs(1) + .add_argument("tensor", "Tensor", "The tensor to be inspected") + .set_attr("FInferStructInfo", InferStructInfoTensorElemOffset) + .set_attr("RequiresArgumentShapes", Bool(false)) + .set_attr("FNormalize", NormalizeToKnownPrimValue) + .set_attr("FPurity", Bool(true)); + } // namespace inspect } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h index 0225b00fb307..2aa20a13813f 100644 --- a/src/relax/op/tensor/inspect.h +++ b/src/relax/op/tensor/inspect.h @@ -85,6 +85,45 @@ Expr tensor_ndim(Expr expr); */ Expr tensor_shape_i(Expr expr, Expr axis); +/* \brief Return the DLTensor::strides[i] field + * + * The `int64_t* DLTensor::strides` is allowed to be NULL, which + * represents a compact packing of the data. In this case, the + * returned stride is computed from the `DLTensor::shape`. + * + * \param expr The relax expression to be inspected. Must have + * `TensorStructInfo`. + * + * \param axis The axis to inspect. Must be within the range `0 <= + * axis < tensor_ndim(expr)`, or else the results are undefined. + * + * \returns The int64_t extent of the specified tensor axis, with + * `PrimStructInfo(DataType::Int(64))`. + */ +Expr tensor_stride_i(Expr expr, Expr axis); + +/* \brief Return the DLTensor::byte_offset field + * + * \param expr The relax expression to be inspected. Must have + * `TensorStructInfo`. + * + * \returns The uint64_t byte offset, with `PrimStructInfo(DataType::UInt(64))`. + */ +Expr tensor_byte_offset(Expr expr); + +/* \brief Return the element offset of a DLTensor + * + * While the DLTensor does not directly contain the element offset, it + * can be inferred from the `DLTensor::byte_offset` and + * `DLTensor::data_type` fields. + * + * \param expr The relax expression to be inspected. Must have + * `TensorStructInfo`. + * + * \returns The uint64_t element offset, with `PrimStructInfo(DataType::UInt(64))`. + */ +Expr tensor_elem_offset(Expr expr); + } // namespace inspect } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_op_inspect.py b/tests/python/relax/test_op_inspect.py new file mode 100644 index 000000000000..18d7a88f051a --- /dev/null +++ b/tests/python/relax/test_op_inspect.py @@ -0,0 +1,252 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import ctypes + +import numpy as np +import pytest + +import tvm.testing + +from tvm import relax +from tvm.ir import Op +from tvm.script import ir as I, relax as R + +# Parameterization for reading dtype of DLTensor. Chosen to have +# multiple distinct type codes, number of lanes, and widths. 
+dtype = tvm.testing.parameter( + "int32", + "int64", + "float32", + "float32x4", + "bfloat", + "e4m3_float8", +) +shape = tvm.testing.parameter( + [], + [16], + [128, 256], + [1] * 64, +) + +elem_offset = tvm.testing.parameter(0, 64, 128) + + +def test_tensor_dtype_code(dtype): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.dtype.type_code + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty([16], dtype) + res = vm["main"](arg) + + expected_type_code = tvm.runtime.DataType(dtype).type_code + assert res == expected_type_code + + +def test_tensor_dtype_bits(dtype): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.dtype.bits + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty([16], dtype) + res = vm["main"](arg) + + expected_type_bits = tvm.runtime.DataType(dtype).bits + assert res == expected_type_bits + + +def test_tensor_dtype_lanes(dtype): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.dtype.lanes + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty([16], dtype) + res = vm["main"](arg) + + expected_type_lanes = tvm.runtime.DataType(dtype).lanes + assert res == expected_type_lanes + + +def test_tensor_ndim(shape): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.ndim + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty(shape, "int32") + res = vm["main"](arg) + + assert res == len(shape) + + +def test_tensor_shape(shape): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor, axis: R.Prim("int64")): + return A.shape[axis] + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty(shape, "int32") + + res = [vm["main"](arg, i) for i, _ in enumerate(shape)] + + tvm.ir.assert_structural_equal(res, shape) + + +def _get_compact_striding(shape): + strides = [] + product = 1 + for dim in reversed(shape): + strides.append(product) + product *= dim + return list(reversed(strides)) + + +def test_strides_of_compact_tensor(shape): + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor, axis: R.Prim("int64")): + return A.strides[axis] + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + arg = tvm.nd.empty(shape, "int32") + + res = [vm["main"](arg, i) for i, _ in enumerate(shape)] + expected = _get_compact_striding(shape) + + tvm.ir.assert_structural_equal(res, expected) + + +def test_strides_of_non_compact_tensor(): + backing_shape = [64, 64] + view_shape = [16, 16] + expected_strides = [backing_shape[0], 1] + + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor, axis: R.Prim("int64")): + return A.strides[axis] + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + backing_ndarray = tvm.nd.empty(backing_shape, "int32") + + # Manually overwrite the DLTensor fields to make a view into the + # tensor. 
+ view = backing_ndarray.handle[0] + np_shape = np.array([16, 16], "int64") + view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long)) + np_strides = np.array([64, 1], "int64") + view.strides = np_strides.ctypes.data_as(ctypes.POINTER(ctypes.c_long)) + backing_ndarray.handle[0] = view + + res = [vm["main"](backing_ndarray, i) for i, _ in enumerate(view_shape)] + + tvm.ir.assert_structural_equal(res, expected_strides) + + +def test_byte_offset(elem_offset): + backing_shape = [64, 64] + view_shape = [16, 16] + byte_offset = elem_offset * 4 + + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.byte_offset + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + backing_ndarray = tvm.nd.empty(backing_shape, "int32") + + # Manually overwrite the DLTensor fields to make a view into the + # tensor. + view = backing_ndarray.handle[0] + np_shape = np.array(view_shape, "int64") + view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long)) + view.byte_offset = byte_offset + backing_ndarray.handle[0] = view + + res = vm["main"](backing_ndarray) + + assert res == byte_offset + + +def test_elem_offset(elem_offset, dtype): + tvm_dtype = tvm.runtime.DataType(dtype) + + backing_shape = [64, 64] + view_shape = [16, 16] + element_bytes = (tvm_dtype.bits * tvm_dtype.lanes) // 8 + byte_offset = elem_offset * element_bytes + + @I.ir_module + class mod: + @R.function + def main(A: R.Tensor): + return A.elem_offset + + built = relax.build(mod) + vm = relax.VirtualMachine(built, tvm.cpu()) + + backing_ndarray = tvm.nd.empty(backing_shape, dtype) + + # Manually overwrite the DLTensor fields to make a view into the + # tensor. + view = backing_ndarray.handle[0] + np_shape = np.array(view_shape, "int64") + view.shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_long)) + view.byte_offset = byte_offset + backing_ndarray.handle[0] = view + + res = vm["main"](backing_ndarray) + + assert res == elem_offset + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_unpack.py b/tests/python/relax/test_op_unpack.py deleted file mode 100644 index 03e4e0fc85e4..000000000000 --- a/tests/python/relax/test_op_unpack.py +++ /dev/null @@ -1,127 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm.testing - -from tvm import relax -from tvm.ir import Op -from tvm.script import ir as I, relax as R - -# Parameterization for reading dtype of DLTensor. Chosen to have -# multiple distinct type codes, number of lanes, and widths. 
-dtype = tvm.testing.parameter( - "int32", - "int64", - "float32", - "float32x4", - "bfloat", - "e4m3_float8", -) -shape = tvm.testing.parameter( - [], - [16], - [128, 256], - [1] * 64, -) - - -def test_tensor_dtype_code(dtype): - @I.ir_module - class mod: - @R.function - def main(A: R.Tensor): - return A.dtype.type_code - - built = relax.build(mod) - vm = relax.VirtualMachine(built, tvm.cpu()) - - arg = tvm.nd.empty([16], dtype) - res = vm["main"](arg) - - expected_type_code = tvm.runtime.DataType(dtype).type_code - assert res == expected_type_code - - -def test_tensor_dtype_bits(dtype): - @I.ir_module - class mod: - @R.function - def main(A: R.Tensor): - return A.dtype.bits - - built = relax.build(mod) - vm = relax.VirtualMachine(built, tvm.cpu()) - - arg = tvm.nd.empty([16], dtype) - res = vm["main"](arg) - - expected_type_bits = tvm.runtime.DataType(dtype).bits - assert res == expected_type_bits - - -def test_tensor_dtype_lanes(dtype): - @I.ir_module - class mod: - @R.function - def main(A: R.Tensor): - return A.dtype.lanes - - built = relax.build(mod) - vm = relax.VirtualMachine(built, tvm.cpu()) - - arg = tvm.nd.empty([16], dtype) - res = vm["main"](arg) - - expected_type_lanes = tvm.runtime.DataType(dtype).lanes - assert res == expected_type_lanes - - -def test_tensor_ndim(shape): - @I.ir_module - class mod: - @R.function - def main(A: R.Tensor): - return A.ndim - - built = relax.build(mod) - vm = relax.VirtualMachine(built, tvm.cpu()) - - arg = tvm.nd.empty(shape, "int32") - res = vm["main"](arg) - - assert res == len(shape) - - -def test_tensor_shape(shape): - @I.ir_module - class mod: - @R.function - def main(A: R.Tensor, axis: R.Prim("int64")): - return A.shape[axis] - - built = relax.build(mod) - vm = relax.VirtualMachine(built, tvm.cpu()) - - arg = tvm.nd.empty(shape, "int32") - - res = [vm["main"](arg, i) for i, _ in enumerate(shape)] - - tvm.ir.assert_structural_equal(res, shape) - - -if __name__ == "__main__": - tvm.testing.main()