Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ struct FlipAttrs : public tvm::AttrsNode<FlipAttrs> {
}
}; // struct FlipAttrs

/*! \brief Attributes used in gather_elements operators */
struct GatherElementsAttrs : public tvm::AttrsNode<GatherElementsAttrs> {
Integer axis;

TVM_DECLARE_ATTRS(GatherElementsAttrs, "relax.attrs.GatherElementsAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis along which to index.");
}
}; // struct GatherElementsAttrs

/*! \brief Attributes used in gather_nd operators */
struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;
TVM_DECLARE_ATTRS(GatherNDAttrs, "relax.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dims.");
}
}; // struct GatherNDAttrs

/*! \brief Attributes used in scatter_elements operators */
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
Expand Down
22 changes: 20 additions & 2 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,24 @@ def _impl_v13(cls, bb, inputs, attr, params):
return relax.op.take(data, indices, axis)


class GatherElements(OnnxOpConverter):
"""Convert an onnx GatherElements node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
axis = attr.get("axis", 0)
return relax.op.gather_elements(inputs[0], inputs[1], axis)


class GatherND(OnnxOpConverter):
"""Convert an onnx GatherND node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
batch_dims = attr.get("batch_dims", 0)
return relax.op.gather_nd(inputs[0], inputs[1], batch_dims)


class Scatter(OnnxOpConverter):
"""Convert an onnx Scatter node into an equivalent Relax expression."""

Expand Down Expand Up @@ -3070,8 +3088,8 @@ def _get_convert_map():
"Squeeze": Squeeze,
"Constant": Constant,
"Gather": Gather,
# "GatherElements": GatherElements,
# "GatherND": GatherND,
"GatherElements": GatherElements,
"GatherND": GatherND,
"Scatter": Scatter,
"ScatterElements": ScatterElements,
"ScatterND": ScatterND,
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
expand_dims,
flatten,
flip,
gather_elements,
gather_nd,
layout_transform,
one_hot,
permute_dims,
Expand Down
73 changes: 73 additions & 0 deletions python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,79 @@ def flip(data, axis):
return _ffi_api.flip(data, axis) # type: ignore


def gather_elements(data: Expr, indices: Expr, axis: int = 0) -> Expr:
"""Gather elements from data according to indices along the specified axis.

Parameters
----------
data : relax.Expr
The input data to the operator.

indices : relax.Expr
The indices tensor, must have integer type.

axis : int
The axis along which to index. Default is 0.

Returns
-------
ret : relax.Expr
The computed result.

Examples
--------
.. code-block:: python

data = [[1, 2], [3, 4]]
indices = [[0, 0], [1, 0]]
axis = 1
output = [[1, 1], [4, 3]]

data = [[1, 2, 3], [4, 5, 6]]
indices = [[1, 1, 1]]
axis = 0
output = [[4, 5, 6]]
"""
return _ffi_api.gather_elements(data, indices, axis) # type: ignore


def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr:
"""Update data at positions defined by indices with values in updates.

Parameters
----------
data : relax.Expr
The input data to the operator.

indices : relax.Expr
The indices tensor, must have integer type.

batch_dims : int
The number of batch dimensions. Default is 0.

Returns
-------
ret : relax.Expr
The computed result.

Examples
--------
.. code-block:: python

batch_dims = 0
data = [[0,1],[2,3]] # data_shape = [2, 2]
indices = [[0,0],[1,1]] # indices_shape = [2, 2]
output = [0,3] # output_shape = [2]

batch_dims = 1
data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2]
indices = [[1],[0]] # indices_shape = [2, 1]
output = [[2,3],[4,5]] # output_shape = [2, 2]

"""
return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore


def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
):
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis))


@register_legalize("relax.gather_elements")
def _gather_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.gather, call.args[0], int(call.attrs.axis), call.args[1])


@register_legalize("relax.gather_nd")
def _gather_nd(bb: BlockBuilder, call: Call) -> Expr:
def te_gather_nd(data, indices, batch_dims):
indices_ndim = len(indices.shape)
axes = [indices_ndim - 1] + list(range(indices_ndim - 1))
indices = topi.transpose(indices, axes)
return topi.gather_nd(data, indices, batch_dims)

return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))


@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@
floor_mod,
full,
full_like,
gather_elements,
gather_nd,
grad,
greater,
greater_equal,
Expand Down Expand Up @@ -772,6 +774,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"func_ret_struct_info",
"func_ret_value",
"function",
"gather_elements",
"gather_nd",
"gpu",
"grad",
"greater",
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def gather(data, axis, indices):
return cpp.gather(data, axis, indices)


def gather_nd(a, indices):
def gather_nd(a, indices, batch_dims=0):
"""Gather elements from a n-dimension array..

Parameters
Expand All @@ -540,7 +540,7 @@ def gather_nd(a, indices):
-------
ret : tvm.te.Tensor
"""
return cpp.gather_nd(a, indices)
return cpp.gather_nd(a, indices, batch_dims)


def matmul(a, b, transp_a=False, transp_b=False):
Expand Down
163 changes: 163 additions & 0 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,169 @@ TVM_REGISTER_OP("relax.flip")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlip)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.gather_elements */
TVM_REGISTER_NODE_TYPE(GatherElementsAttrs);

Expr gather_elements(Expr data, Expr indices, int axis) {
auto attrs = make_object<GatherElementsAttrs>();
attrs->axis = Integer(axis);
static const Op& op = Op::Get("relax.gather_elements");
return Call(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements);

StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) {
const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
const auto* attrs = call->attrs.as<GatherElementsAttrs>();

if (data_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherElements requires the input data to be a Tensor. However, the given one is "
<< call->args[0]->struct_info_->GetTypeKey());
}
if (indices_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherElements requires the input indices to be a Tensor. However, the given one is "
<< call->args[1]->struct_info_->GetTypeKey());
}

if (!indices_sinfo->IsUnknownDtype() && !indices_sinfo->dtype.is_int()) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherElements requires the input indices to have int64 dtype. However, the "
<< "given indices dtype is " << indices_sinfo->dtype);
}

if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
}

int axis = attrs->axis.IntValue();
if (axis < -data_sinfo->ndim || axis >= data_sinfo->ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherElements requires axis to be within the input dimension range ["
<< -data_sinfo->ndim << ", " << data_sinfo->ndim - 1 << "]. However, the "
<< "given axis is " << axis);
}

if (data_sinfo->ndim != indices_sinfo->ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherElements requires data and indices to have the same rank. However, "
<< "data rank is " << data_sinfo->ndim << " while indices rank is "
<< indices_sinfo->ndim);
}
if (indices_sinfo->shape.defined()) {
return TensorStructInfo(indices_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice);
}
return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.gather_elements")
.set_attrs_type<GatherElementsAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherElements)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.gather_nd */
TVM_REGISTER_NODE_TYPE(GatherNDAttrs);

Expr gather_nd(Expr data, Expr indices, int batch_dims) {
auto attrs = make_object<GatherNDAttrs>();
attrs->batch_dims = Integer(batch_dims);
static const Op& op = Op::Get("relax.gather_nd");
return Call(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd);

StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) {
const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
const auto* attrs = call->attrs.as<GatherNDAttrs>();

if (data_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherND requires the input data to be a Tensor. However, the given one is "
<< call->args[0]->struct_info_->GetTypeKey());
}
if (indices_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherND requires the input indices to be a Tensor. However, the given one is "
<< call->args[1]->struct_info_->GetTypeKey());
}
ICHECK_GE(attrs->batch_dims.IntValue(), 0);
int batch_dims = attrs->batch_dims.IntValue();
int input_dims = data_sinfo->ndim;
if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(64)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherND requires the input indices to have int64 dtype. However, the "
<< "given indices dtype is " << indices_sinfo->dtype);
}

if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
}

if (batch_dims < 0 || batch_dims > data_sinfo->ndim) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherND batch_dims must be in range [0, data.ndim]. However, got batch_dims="
<< batch_dims << ", data.ndim=" << input_dims);
}

if (batch_dims > indices_sinfo->ndim - 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherND batch_dims cannot exceed indices.ndim-1. However, got batch_dims="
<< batch_dims << ", indices.ndim=" << indices_sinfo->ndim);
}

// Check if indices shape is known
const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (!indices_shape || !indices_shape->values.back()->IsInstance<IntImmNode>()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
}
int l = indices_shape->values.back().as<IntImmNode>()->value;
int output_ndim = indices_sinfo->ndim + input_dims - l - 1 - batch_dims;
if (!data_shape) {
return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice);
}

// In this condition, all input shapes are known
Array<PrimExpr> out_shape;
if (l > input_dims - batch_dims) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherND requires the last dimension of indices to be less than or "
"equal to the rank of data minus batch_dims. However, the given shapes are "
<< "indices: " << ShapeExpr(indices_shape->values) << ", data: "
<< ShapeExpr(data_shape->values) << ", with batch_dims=" << batch_dims);
}
for (int i = 0; i < indices_sinfo->ndim - 1; ++i) {
out_shape.push_back(indices_shape->values[i]);
}
for (int i = batch_dims + l; i < input_dims; ++i) {
out_shape.push_back(data_shape->values[i]);
}
ICHECK_EQ(out_shape.size(), output_ndim);
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.gather_nd")
.set_attrs_type<GatherNDAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherND)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.scatter_elements */
TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);

Expand Down
Loading
Loading