2 changes: 1 addition & 1 deletion include/tvm/relax/block_builder.h
@@ -116,7 +116,7 @@ class BlockBuilderNode : public Object {
* \brief Report an error during transformation construction.
* \param diagnostic The diagnostic information.
*/
virtual void ReportFatal(const Diagnostic& diagnostic) = 0;
[[noreturn]] virtual void ReportFatal(const Diagnostic& diagnostic) = 0;

//-------------------------------
// Scope management
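For context, a minimal standalone sketch (illustrative names, not TVM API) of the pattern this change supports: the base class marks the pure virtual as [[noreturn]], but call sites still keep an explicit return afterwards, because the attribute on the base declaration is not sufficient to silence -Wreturn-type when an override might not carry it (see the comments in op_common.cc and index.cc below).

#include <stdexcept>

struct Builder {
  // Documents that this call never returns normally.
  [[noreturn]] virtual void ReportFatal(const char* msg) = 0;
  virtual ~Builder() = default;
};

struct BuilderImpl : Builder {
  [[noreturn]] void ReportFatal(const char* msg) final {
    throw std::runtime_error(msg);  // never returns normally
  }
};

int RequirePositive(Builder* ctx, int value) {
  if (value > 0) return value;
  ctx->ReportFatal("value must be positive");
  return 0;  // unreachable in practice; kept to satisfy -Wreturn-type
}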
43 changes: 31 additions & 12 deletions include/tvm/topi/transform.h
@@ -1036,7 +1036,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub
*
* \return A Tensor whose op member is the take operation
*/
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis,
inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int batch_dims, int axis,
std::string mode = "clip", std::string name = "T_take",
std::string tag = kInjective) {
if (axis < 0) {
@@ -1045,22 +1045,30 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
ICHECK_GE(axis, 0) << "axis out of bounds";
ICHECK_LT(axis, a->shape.size()) << "axis out of bounds";
auto axis_dim = a->shape[axis];
int indices_len = static_cast<int>(indices->shape.size());
auto indices_shape = [&]() -> Array<PrimExpr> {
if (auto tensor = indices.as<TensorNode>()) {
return tensor->shape;
} else {
return {};
}
}();

int indices_len = static_cast<int>(indices_shape.size());

int batch_dims_ = batch_dims;
if (batch_dims_ != 0) {
ICHECK_GE(batch_dims_, -static_cast<int>(indices->shape.size())) << "batch_dims out of bounds";
ICHECK_LE(batch_dims_, indices->shape.size()) << "batch_dims out of bounds";
ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds";
ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds";

if (batch_dims_ < 0) {
batch_dims_ = indices->shape.size() + batch_dims_;
batch_dims_ = indices_len + batch_dims_;
}

ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
for (int i = 0; i < batch_dims_; ++i) {
auto addr1 = a->shape[i];
auto addr2 = indices->shape[i];
auto addr2 = indices_shape[i];
auto v1 = static_cast<IntImm*>(&addr1)->get()->value;
auto v2 = static_cast<IntImm*>(&addr2)->get()->value;
ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]";
@@ -1077,13 +1085,24 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
for (int i = batch_dims_; i < axis; ++i) {
out_shape.push_back(a->shape[i]);
}
for (size_t i = static_cast<size_t>(batch_dims_); i < indices->shape.size(); ++i) {
out_shape.push_back(indices->shape[i]);
for (int i = batch_dims_; i < indices_len; ++i) {
out_shape.push_back(indices_shape[i]);
}
for (size_t i = axis + 1; i < a->shape.size(); ++i) {
out_shape.push_back(a->shape[i]);
}

auto get_index = [&](const Array<PrimExpr>& indices_position) -> PrimExpr {
if (auto tensor = indices.as<Tensor>()) {
return tensor.value()(indices_position);
} else if (auto prim = indices.as<PrimExpr>()) {
ICHECK_EQ(indices_position.size(), 0);
return prim.value();
} else {
LOG(FATAL) << "Variant did not contain either allowed type";
}
};

if (mode == "clip") {
if (batch_dims_ == 0) {
return compute(
@@ -1097,7 +1116,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
@@ -1120,7 +1139,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1);
auto idx = tvm::min(tvm::max(0, get_index(indices_position)), axis_dim - 1);
real_indices.push_back(idx);
for (size_t j = axis + indices_len - batch_dims_; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
@@ -1141,7 +1160,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
real_indices.push_back(indices(indices_position));
real_indices.push_back(get_index(indices_position));
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
@@ -1160,7 +1179,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = truncmod(truncmod(indices(indices_position), axis_dim) + axis_dim, axis_dim);
auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
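A hedged usage sketch of the widened signature (not part of this diff; assumes the usual tvm::te placeholder API): a Tensor index behaves as before, while a scalar PrimExpr index yields an empty indices_shape, so the take axis is dropped from the output shape.

#include <tvm/te/operation.h>
#include <tvm/topi/transform.h>

void TakeVariantExample() {
  using namespace tvm;
  te::Tensor a = te::placeholder({4, 10}, DataType::Float(32), "a");

  // Tensor indices, as before: output shape is (4, 3).
  te::Tensor idx = te::placeholder({3}, DataType::Int(64), "idx");
  te::Tensor by_tensor = topi::take(a, idx, /*batch_dims=*/0, /*axis=*/1);

  // Scalar PrimExpr index (new): indices_shape is empty, so the take
  // axis disappears and the output shape is (4,).
  PrimExpr scalar = IntImm(DataType::Int(64), 3);
  te::Tensor by_scalar = topi::take(a, scalar, /*batch_dims=*/0, /*axis=*/1);

  (void)by_tensor;
  (void)by_scalar;
}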
2 changes: 1 addition & 1 deletion src/relax/ir/block_builder.cc
@@ -149,7 +149,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
}
}

void ReportFatal(const Diagnostic& diagnostic) final {
[[noreturn]] void ReportFatal(const Diagnostic& diagnostic) final {
// TODO(relax-team): Print more context information by looking
// into the diagnostic->loc and surrounding IRModule.
// We do not materialize DiagnosticContext to avoid double referencing to
52 changes: 38 additions & 14 deletions src/relax/op/op_common.cc
@@ -35,24 +35,48 @@ Array<Expr> GetCallArgs(const Call& call) {
return args;
}

Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) {
void CheckNumArguments(const Call& call, const BlockBuilder& ctx) {
Op op = Downcast<Op>(call->op);
int n_input = op->arguments.size();
if (static_cast<int>(call->args.size()) != n_input) {
int expected_input = op->arguments.size();
if (static_cast<int>(call->args.size()) != expected_input) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " op should have " << n_input << " arguments");
<< "Operator " << op << " expects " << expected_input << " arguments"
<< ", but was called with " << call->args.size() << " arguments");
}
}

TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) {
Op op = Downcast<Op>(call->op);

ICHECK_EQ(op->arguments.size(), call->args.size())
<< "Failure caught by this check "
<< "should have previously been caught by `CheckNumArguments`";
ICHECK_LT(i_arg, op->arguments.size());

auto arg = call->args[i_arg];
auto sinfo = GetStructInfo(arg);

if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
return tensor_sinfo.value();
} else {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Operator " << op << " requires argument " << i_arg << " ("
<< op->arguments[i_arg]->name << ") to be a tensor. "
<< "However, the argument " << arg << " is instead of type " << sinfo);
// Unreachable, but the [[noreturn]] attribute on the virtual function
// `ReportFatal` is insufficient to silence -Wreturn-type, as a
// child class's override might not be [[noreturn]].
return TensorStructInfo();
}
}

Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) {
CheckNumArguments(call, ctx);

Op op = Downcast<Op>(call->op);
Array<TensorStructInfo> input_tensor_sinfo;
input_tensor_sinfo.reserve(n_input);
for (int i = 0; i < n_input; ++i) {
const auto* sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
if (sinfo == nullptr) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " requires the input " << op->arguments[i]->name
<< " to be Tensor. However, the given one has a "
<< call->args[i]->struct_info_->GetTypeKey());
}
input_tensor_sinfo.push_back(GetRef<TensorStructInfo>(sinfo));
for (size_t i = 0; i < call->args.size(); ++i) {
input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx));
}
return input_tensor_sinfo;
}
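The value of the split is that an operator whose arguments are not all tensors can validate arity once and then fetch tensor struct info per argument. A sketch of the intended call pattern (hypothetical operator, not in this PR, mirroring InferStructInfoTake in index.cc below):

StructInfo InferStructInfoExample(const Call& call, const BlockBuilder& ctx) {
  CheckNumArguments(call, ctx);
  // Only argument 0 must be a tensor; argument 1 may be a non-tensor
  // (e.g. a PrimValue), so it is inspected directly instead of being
  // forced through GetInputTensorStructInfo.
  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
  StructInfo arg1_sinfo = GetStructInfo(call->args[1]);
  (void)arg1_sinfo;  // a real operator would branch on this
  return data_sinfo;
}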
21 changes: 21 additions & 0 deletions src/relax/op/op_common.h
@@ -44,6 +44,27 @@ namespace relax {

/************ Op input struct info getter ************/

/*!
* \brief Check that the operator call has the expected number of arguments.
*
* Verify that the number of arguments matches the expected number for
* the operator.
*
* \param call The context Call to the operator.
*
* \param ctx The error reporting context.
*/
void CheckNumArguments(const Call& call, const BlockBuilder& ctx);

/*!
* \brief Get the tensor struct info of the operator input.
* \param call The context Call to the operator.
* \param i_arg The index of the argument to check.
* \param ctx The error reporting context.
* \return The tensor struct info of the argument.
*/
TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx);

/*!
* \brief Get the tensor struct info of the operator input.
* \param call The context Call to the operator.
26 changes: 23 additions & 3 deletions src/relax/op/tensor/index.cc
@@ -44,9 +44,29 @@ Expr take(Expr x, Expr indices, Optional<Integer> axis) {
TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take);

StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
TensorStructInfo data_sinfo = input_sinfo[0];
TensorStructInfo indices_sinfo = input_sinfo[1];
CheckNumArguments(call, ctx);
TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);

// StructInfo inference when the index is a PrimValue is equivalent
// to that of a scalar (0-d) tensor.
TensorStructInfo indices_sinfo = [&]() {
auto arg = call->args[1];
auto sinfo = GetStructInfo(arg);
if (auto tensor_sinfo = sinfo.as<TensorStructInfo>()) {
return tensor_sinfo.value();
} else if (auto prim_sinfo = sinfo.as<PrimStructInfoNode>()) {
return TensorStructInfo(ShapeExpr(Array<PrimExpr>{}), prim_sinfo->dtype);
} else {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Operator " << call->op << " requires the indices argument to be "
<< "either a tensor or a scalar value. "
<< "However, argument " << arg << " has struct info " << sinfo);
// Unreachable, but the [[noreturn]] attribute on the virtual function
// `ReportFatal` is insufficient to silence -Wreturn-type, as a
// child class's override might not be [[noreturn]].
return TensorStructInfo();
}
}();

if (indices_sinfo->IsUnknownDtype()) {
// TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?
18 changes: 18 additions & 0 deletions tests/python/relax/test_op_index.py
@@ -194,6 +194,24 @@ def test_take_infer_struct_info():
_check_inference(bb, relax.op.take(y3, idx7), relax.TensorStructInfo(dtype="", ndim=2))


def test_take_infer_struct_info_scalar_tensor_index():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
idx = relax.Var("idx", R.Tensor([], "int64"))

_check_inference(bb, relax.op.take(x0, idx, axis=0), relax.TensorStructInfo([10], "float32"))
_check_inference(bb, relax.op.take(x0, idx, axis=1), relax.TensorStructInfo([4], "float32"))


def test_take_infer_struct_info_prim_value_index():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((4, 10), "float32"))
idx = relax.Var("idx", R.Prim("int64"))

_check_inference(bb, relax.op.take(x0, idx, axis=0), relax.TensorStructInfo([10], "float32"))
_check_inference(bb, relax.op.take(x0, idx, axis=1), relax.TensorStructInfo([4], "float32"))


def test_take_infer_struct_info_shape_symbolic():
bb = relax.BlockBuilder()
m = tir.Var("m", "int64")