From b5a7bae2971354e4a87634b3fbfd0a1318f70e7e Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 26 Apr 2024 19:49:42 -0500
Subject: [PATCH 1/3] [Relax] Allow PrimValue as index in relax.op.take

Prior to this commit, `relax.op.take` only allowed tensors as the
`indices` argument.  This commit extends `R.take` to also allow the
index to be a `relax::PrimValue`.
---
 include/tvm/relax/block_builder.h             |   2 +-
 include/tvm/topi/transform.h                  |  43 +++--
 src/relax/ir/block_builder.cc                 |   2 +-
 src/relax/op/op_common.cc                     |  48 ++++--
 src/relax/op/op_common.h                      |  21 +++
 src/relax/op/tensor/index.cc                  |  22 ++-
 tests/python/relax/test_op_index.py           |  18 ++
 tests/python/relax/test_op_take.py            | 158 ++++++++++++++++++
 ...sform_legalize_ops_index_linear_algebra.py |  97 +++++++++++
 9 files changed, 380 insertions(+), 31 deletions(-)
 create mode 100644 tests/python/relax/test_op_take.py

diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h
index a1e5a6bc3125..7ca9aab6d5aa 100644
--- a/include/tvm/relax/block_builder.h
+++ b/include/tvm/relax/block_builder.h
@@ -116,7 +116,7 @@ class BlockBuilderNode : public Object {
    * \brief Report an error during transformation construction.
    * \param diagnostic The diagnostic information.
    */
-  virtual void ReportFatal(const Diagnostic& diagnostic) = 0;
+  [[noreturn]] virtual void ReportFatal(const Diagnostic& diagnostic) = 0;

   //-------------------------------
   // Scope management
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index a1f66a70ca3d..30d2bd7a0b7d 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1036,7 +1036,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
  *
  * \return A Tensor whose op member is the take operation
  */
-inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
+inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch_dims, int axis,
                    std::string mode = "clip", std::string name = "T_take",
                    std::string tag = kInjective) {
   if (axis < 0) {
@@ -1045,22 +1045,30 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
   ICHECK_GE(axis, 0) << "axis out of bounds";
   ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
   auto axis_dim = a->shape[axis];
-  int indices_len = static_cast<int>(indices->shape.size());
+  auto indices_shape = [&]() -> Array<PrimExpr> {
+    if (auto tensor = indices.as<TensorNode>()) {
+      return tensor->shape;
+    } else {
+      return {};
+    }
+  }();
+
+  int indices_len = static_cast<int>(indices_shape.size());
   int batch_dims_ = batch_dims;
   if (batch_dims_ != 0) {
-    ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
-    ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";
+    ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
+    ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";

     if (batch_dims_ < 0) {
-      batch_dims_ = indices->shape.size() + batch_dims_;
+      batch_dims_ = indices_len + batch_dims_;
     }

     ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
     ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
     for (int i = 0; i < batch_dims_; ++i) {
       auto addr1 = a->shape[i];
-      auto addr2 = indices->shape[i];
+      auto addr2 = indices_shape[i];
       auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
       auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
       ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
@@ -1077,13 +1085,24 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
   for (int i = batch_dims_; i < axis; ++i) {
     out_shape.push_back(a->shape[i]);
   }
-  for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
-    out_shape.push_back(indices->shape[i]);
+  for (size_t i = static_cast<size_t>(batch_dims_); i < indices_len; ++i) {
+    out_shape.push_back(indices_shape[i]);
   }
   for (size_t i = axis + 1; i < a->shape.size(); ++i) {
     out_shape.push_back(a->shape[i]);
   }

+  auto get_index = [&](const Array<PrimExpr>& indices_position) -> PrimExpr {
+    if (auto tensor = indices.as<Tensor>()) {
+      return tensor.value()(indices_position);
+    } else if (auto prim = indices.as<PrimExpr>()) {
+      ICHECK_EQ(indices_position.size(), 0);
+      return prim.value();
+    } else {
+      LOG(FATAL) << "Variant did not contain either allowed type";
+    }
+  };
+
   if (mode == "clip") {
     if (batch_dims_ == 0) {
       return compute(
@@ -1097,7 +1116,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
+          auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
           real_indices.push_back(idx);
           for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
@@ -1120,7 +1139,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
+          auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
           real_indices.push_back(idx);
           for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
@@ -1141,7 +1160,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          real_indices.push_back(indices(indices_position));
+          real_indices.push_back(get_index(indices_position));
           for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
           }
@@ -1160,7 +1179,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
-          auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
+          auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim);
           real_indices.push_back(idx);
           for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
             real_indices.push_back(out_index[j]);
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 0c40c4e62a48..e9a513c317d6 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -149,7 +149,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
     }
   }

-  void ReportFatal(const Diagnostic& diagnostic) final {
+  [[noreturn]] void ReportFatal(const Diagnostic& diagnostic) final {
     // TODO(relax-team): Print more context information by looking
     // into the diagnostic->loc and surrounding IRModule.
    // We do not materialzie DiagnosticContext to avoid double referencing to
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index b35bd4b5a31c..1ccc93e0b230 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -35,24 +35,44 @@ Array<Expr> GetCallArgs(const Call& call) {
   return args;
 }

-Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) {
+void CheckNumArguments(const Call& call, const BlockBuilder& ctx) {
   Op op = Downcast<Op>(call->op);
-  int n_input = op->arguments.size();
-  if (static_cast<int>(call->args.size()) != n_input) {
+  int expected_input = op->arguments.size();
+  if (static_cast<int>(call->args.size()) != expected_input) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << op << " op should have " << n_input << " arguments");
+                     << "Operator " << op << " expects " << expected_input << " arguments"
+                     << ", but was called with " << call->args.size() << " arguments");
   }
+}
+
+TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) {
+  Op op = Downcast<Op>(call->op);
+
+  ICHECK_EQ(op->arguments.size(), call->args.size())
+      << "Failure caught by this check "
+      << "should have previously been caught by `CheckNumArguments`";
+  ICHECK_LT(i_arg, op->arguments.size());
+
+  auto arg = call->args[i_arg];
+  auto sinfo = GetStructInfo(arg);
+
+  if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
+    return tensor_sinfo.value();
+  } else {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Operator " << op << " requires argument " << i_arg << " ("
+                     << op->arguments[i_arg]->name << ") to be a tensor. "
+                     << "However, the argument " << arg << " is instead of type " << sinfo);
+  }
+}
+
+Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) {
+  CheckNumArguments(call, ctx);
+
+  Op op = Downcast<Op>(call->op);
   Array<TensorStructInfo> input_tensor_sinfo;
-  input_tensor_sinfo.reserve(n_input);
-  for (int i = 0; i < n_input; ++i) {
-    const auto* sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
-    if (sinfo == nullptr) {
-      ctx->ReportFatal(Diagnostic::Error(call)
-                       << op << " requires the input " << op->arguments[i]->name
-                       << " to be Tensor. However, the given one has a "
-                       << call->args[i]->struct_info_->GetTypeKey());
-    }
-    input_tensor_sinfo.push_back(GetRef<TensorStructInfo>(sinfo));
+  for (int i = 0; i < call->args.size(); ++i) {
+    input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx));
   }
   return input_tensor_sinfo;
 }
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 5e19edb47c45..94474ce78444 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -44,6 +44,27 @@ namespace relax {

 /************ Op input struct info getter ************/

+/*!
+ * \brief Check that the operator was called with the expected number of arguments
+ *
+ * Verify that the number of arguments matches the expected number for
+ * the operator.
+ *
+ * \param call The context Call to the operator.
+ *
+ * \param ctx The error reporting context.
+ */
+void CheckNumArguments(const Call& call, const BlockBuilder& ctx);
+
+/*!
+ * \brief Get the tensor struct info of the operator input.
+ * \param call The context Call to the operator.
+ * \param i_arg The index of the argument to check
+ * \param ctx The error reporting context.
+ * \return The tensor struct info of the argument
+ */
+TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx);
+
 /*!
  * \brief Get the tensor struct info of the operator input.
  * \param call The context Call to the operator.
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 7ab98e94684a..7d871a62c5da 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -44,9 +44,25 @@ Expr take(Expr x, Expr indices, Optional<Integer> axis) {
 TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take);

 StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
-  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
-  TensorStructInfo data_sinfo = input_sinfo[0];
-  TensorStructInfo indices_sinfo = input_sinfo[1];
+  CheckNumArguments(call, ctx);
+  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+
+  // StructInfo inference when the index is a PrimValue is equivalent
+  // to that of a scalar (0-d) tensor.
+  TensorStructInfo indices_sinfo = [&]() {
+    auto arg = call->args[1];
+    auto sinfo = GetStructInfo(arg);
+    if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
+      return tensor_sinfo.value();
+    } else if (auto prim_sinfo = sinfo.as<PrimStructInfoNode>()) {
+      return TensorStructInfo(ShapeExpr(Array<PrimExpr>{}), prim_sinfo->dtype);
+    } else {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "Operator " << call->op << " requires the indices argument to be "
+                       << "either a tensor or a scalar value. "
+                       << "However, argument " << arg << " has struct info " << sinfo);
+    }
+  }();

   if (indices_sinfo->IsUnknownDtype()) {
     // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?
diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py
index e3c9e4a596ac..1455b4182ae6 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -194,6 +194,24 @@ def test_take_infer_struct_info():
     _check_inference(bb, relax.op.take(y3, idx7), relax.TensorStructInfo(dtype="", ndim=2))


+def test_take_infer_struct_info_scalar_tensor_index():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
+    idx = relax.Var("idx", R.Tensor([], "int64"))
+
+    _check_inference(bb, relax.op.take(x0, idx, axis=0), relax.TensorStructInfo([10], "float32"))
+    _check_inference(bb, relax.op.take(x0, idx, axis=1), relax.TensorStructInfo([4], "float32"))
+
+
+def test_take_infer_struct_info_prim_value_index():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
+    idx = relax.Var("idx", R.Prim("int64"))
+
+    _check_inference(bb, relax.op.take(x0, idx, axis=0), relax.TensorStructInfo([10], "float32"))
+    _check_inference(bb, relax.op.take(x0, idx, axis=1), relax.TensorStructInfo([4], "float32"))
+
+
 def test_take_infer_struct_info_shape_symbolic():
     bb = relax.BlockBuilder()
     m = tir.Var("m", "int64")
diff --git a/tests/python/relax/test_op_take.py b/tests/python/relax/test_op_take.py
new file mode 100644
index 000000000000..babf91869a41
--- /dev/null
+++ b/tests/python/relax/test_op_take.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+import numpy as np
+
+axis = tvm.testing.parameter(0, 1)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_scalar_tensor_as_index(target, dev, axis):
+    """The index of R.take may be a scalar tensor
+
+    Using a scalar tensor as the index reduces the dimension of the
+    output.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16, 16], "float16")):
+            output = R.take(A, R.const(1), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take(1, axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_1d_tensor_as_index(target, dev, axis):
+    """The index of R.take may be a non-scalar tensor
+
+    In general, `R.take` outputs a tensor of dimension
+    `data.ndim + indices.ndim - 1`.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16, 16], "float16")):
+            output = R.take(A, R.const([1]), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take([1], axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_2d_tensor_as_index(target, dev, axis):
+    """The index of R.take may be a 2-d tensor"""
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16, 16], "float16")):
+            output = R.take(A, R.const([[1, 3], [5, 7]]), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take([[1, 3], [5, 7]], axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_constant_prim_value_as_index(target, dev, axis):
+    """The index of R.take may be a R.prim_value
+
+    The `R.prim_value` produces output equivalent to a scalar
+    tensor.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([16, 16], "float16")):
+            output = R.take(A, R.prim_value(1), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take(1, axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_dynamic_prim_value_as_index(target, dev, axis):
+    """The index of R.take may be a dynamic R.prim_value
+
+    The `R.prim_value` produces output equivalent to a scalar
+    tensor.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor(["n", "n"], "float16")):
+            n = T.int64()
+            output = R.take(A, R.prim_value(n - 1), axis=axis)
+            return output
+
+    built = tvm.relax.build(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[16, 16]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np_input.take(15, axis=axis)
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 0d1e969b35e3..d0aaddb1ca52 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -55,6 +55,68 @@ def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"
     tvm.ir.assert_structural_equal(mod, Expected)


+def test_take_prim_value():
+    # fmt: off
+    @tvm.script.ir_module
+    class Take:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4), "float32"), index: R.Prim("int64")) -> R.Tensor((2, 4), "float32"):
+            gv: R.Tensor((2, 4), "float32") = R.take(x, index, axis=1)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4), "float32"), index: R.Prim("int64")) -> R.Tensor((2, 4), "float32"):
+            gv = R.call_tir(Expected.take, (x, index), R.Tensor((2, 4), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), index: T.int64, T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+                with T.block("T_take"):
+                    ax0, ax2 = T.axis.remap("SS", [i0, i2])
+                    T.reads(rxplaceholder[ax0, index, ax2])
+                    T.writes(T_take[ax0, ax2])
+                    T_take[ax0, ax2] = rxplaceholder[ax0, index, ax2]
+    # fmt: on
+
+    mod = LegalizeOps()(Take)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_take_const_prim_value():
+    # fmt: off
+    @tvm.script.ir_module
+    class Take:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 4), "float32"):
+            gv: R.Tensor((2, 4), "float32") = R.take(x, R.prim_value(0), axis=1)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 4), "float32"):
+            gv = R.call_tir(Expected.take, (x,), R.Tensor((2, 4), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+                with T.block("T_take"):
+                    ax0, ax2 = T.axis.remap("SS", [i0, i2])
+                    T.reads(rxplaceholder[ax0, T.int64(0), ax2])
+                    T.writes(T_take[ax0, ax2])
+                    T_take[ax0, ax2] = rxplaceholder[ax0, T.int64(0), ax2]
+    # fmt: on
+
+    mod = LegalizeOps()(Take)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_take_symbolic():
     # fmt: off
     @tvm.script.ir_module
@@ -96,6 +158,41 @@ def take(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_take:
     tvm.ir.assert_structural_equal(mod, Expected)


+def test_take_symbolic_prim_value():
+    # fmt: off
+    @tvm.script.ir_module
+    class Take:
+        @R.function
+        def main(x: R.Tensor((2, "n", 4), "float32")) -> R.Tensor((2, 4), "float32"):
+            n = T.int64()
+            gv: R.Tensor((2, 4), "float32") = R.take(x, R.prim_value(n-1), axis=1)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, "n", 4), "float32")) -> R.Tensor((2, 4), "float32"):
+            gv = R.call_tir(Expected.take, (x,), R.Tensor((2, 4), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def take(x_handle: T.handle, T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")):
+            n = T.int64()
+            rxplaceholder = T.match_buffer(x_handle, (T.int64(2), n, T.int64(4)), "float32")
+
+            T.func_attr({"tir.noalias": True})
+            for i0, i2 in T.grid(T.int64(2), T.int64(4)):
+                with T.block("T_take"):
+                    ax0, ax2 = T.axis.remap("SS", [i0, i2])
+                    T.reads(rxplaceholder[ax0, n-1, ax2])
+                    T.writes(T_take[ax0, ax2])
+                    T_take[ax0, ax2] = rxplaceholder[ax0, n-1, ax2]
+    # fmt: on
+
+    mod = LegalizeOps()(Take)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_strided_slice():
     # fmt: off
     @tvm.script.ir_module

From 26aee6a1e9eef1c236e91fea8f3d822737d4e405 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Sat, 27 Apr 2024 08:19:36 -0500
Subject: [PATCH 2/3] Avoid comparison between signed/unsigned

---
 include/tvm/topi/transform.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 30d2bd7a0b7d..3292ce57ba5c 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1085,7 +1085,7 @@ inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch
   for (int i = batch_dims_; i < axis; ++i) {
     out_shape.push_back(a->shape[i]);
   }
-  for (size_t i = static_cast<size_t>(batch_dims_); i < indices_len; ++i) {
+  for (int i = batch_dims_; i < indices_len; ++i) {
     out_shape.push_back(indices_shape[i]);
   }
   for (size_t i = axis + 1; i < a->shape.size(); ++i) {

From 1d2c4760a48c3f6d3ad17e8ac99b95a424394c92 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Sat, 27 Apr 2024 10:31:02 -0500
Subject: [PATCH 3/3] Resolve/silence gcc warnings

---
 src/relax/op/op_common.cc    | 6 +++++-
 src/relax/op/tensor/index.cc | 4 ++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index 1ccc93e0b230..56bf708f5e06 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -63,6 +63,10 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const
                      << "Operator " << op << " requires argument " << i_arg << " ("
                      << op->arguments[i_arg]->name << ") to be a tensor. "
                      << "However, the argument " << arg << " is instead of type " << sinfo);
+    // Unreachable, but [[noreturn]] attribute on virtual function
+    // `ReportFatal` is insufficient to silence -Wreturn-type, as
+    // child class might not be [[noreturn]].
+    return TensorStructInfo();
   }
 }

@@ -71,7 +75,7 @@ Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBu
   Op op = Downcast<Op>(call->op);

   Array<TensorStructInfo> input_tensor_sinfo;
-  for (int i = 0; i < call->args.size(); ++i) {
+  for (size_t i = 0; i < call->args.size(); ++i) {
     input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx));
   }
   return input_tensor_sinfo;
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 7d871a62c5da..d052c2a64f9c 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -61,6 +61,10 @@ StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
                        << "Operator " << call->op << " requires the indices argument to be "
                        << "either a tensor or a scalar value. "
                        << "However, argument " << arg << " has struct info " << sinfo);
+      // Unreachable, but [[noreturn]] attribute on virtual function
+      // `ReportFatal` is insufficient to silence -Wreturn-type, as
+      // child class might not be [[noreturn]].
+      return TensorStructInfo();
     }
   }();