Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
"either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".");
}
}; // struct ScatterElementsAttrs

/*! \brief Attributes used in scatter_nd operators */
struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
String reduction;

TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") {
TVM_ATTR_FIELD(reduction).set_default("update").describe(
"Accumulation mode of the ScatterND, "
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
}
}; // struct ScatterNDAttrs

} // namespace relax
} // namespace tvm

Expand Down
32 changes: 31 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,36 @@ def _impl_v11(cls, bb, inputs, attr, params):
return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis)


class ScatterND(OnnxOpConverter):
"""Convert an onnx ScatterND node into an equivalent Relax expression."""

@staticmethod
def _reduction_check(attr, valid_reductions: List[str]):
reduction = attr.get("reduction", None)
reduction = reduction or b"update"
reduction = reduction.decode("utf-8")
reduction = "update" if reduction == "none" else reduction
assert (
reduction in valid_reductions
), f"Only {valid_reductions} reductions are supported, but {reduction} is gotten"

return reduction

@classmethod
def _impl_v11(cls, bb, inputs, attr, params):
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2])

@classmethod
def _impl_v16(cls, bb, inputs, attr, params):
reduction = cls._reduction_check(attr, ["update", "add", "mul"])
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"])
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)


class Size(OnnxOpConverter):
"""Convert an onnx Size node into an equivalent Relax expression."""

Expand Down Expand Up @@ -2729,7 +2759,7 @@ def _get_convert_map():
# "GatherND": GatherND,
"Scatter": Scatter,
"ScatterElements": ScatterElements,
# "ScatterND": ScatterND,
"ScatterND": ScatterND,
# "Compress": Compress,
"Size": Size,
# "EyeLike": EyeLike,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
repeat,
reshape,
scatter_elements,
scatter_nd,
split,
squeeze,
tile,
Expand Down
39 changes: 39 additions & 0 deletions python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,42 @@ def scatter_elements(

"""
return _ffi_api.scatter_elements(data, indices, updates, axis, reduction) # type: ignore


def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "update") -> Expr:
"""Scatter updates into an array according to indices.

Parameters
----------
data: relax.Expr
The input data to be updated.

indices: relax.Expr
The index positions to update in `data`.

updates: relax.Expr
Values to replace to.

reduction: str
Type of reduction to apply: update, add, mul, max, min.
It is "update" by default.

Returns
-------
result : relax.Expr
The result has the same shape as data.

Examples
--------
.. code-block:: python

# inputs
data = [1, 2, 3, 4, 5, 6, 7, 8]
indices = [[4], [3], [1], [7]]
updates = [9, 10, 11, 12]

# output
output = [1, 11, 3, 10, 9, 6, 7, 12]

"""
return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore
17 changes: 17 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,23 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.scatter_nd")
def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
# TODO(relax-team): Support native scatter_nd without te extern
def scatter_nd(data, indices, updates, reduction):
axes = list(range(len(indices.shape)))
indices = topi.transpose(indices, axes[-1:] + axes[:-1])
return topi.scatter_nd(data, indices, updates, reduction)

return bb.call_te(
scatter_nd,
call.args[0],
call.args[1],
call.args[2],
call.attrs.reduction,
)


@register_legalize("relax.layout_transform")
def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
def te_layout_transform(data, name):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
round,
rsqrt,
scatter_elements,
scatter_nd,
shape_of,
shape_to_tensor,
sigmoid,
Expand Down Expand Up @@ -736,6 +737,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"cumsum",
"einsum",
"scatter_elements",
"scatter_nd",
"dataflow",
"device",
"divide",
Expand Down
134 changes: 134 additions & 0 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1531,5 +1531,139 @@ TVM_REGISTER_OP("relax.scatter_elements")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterElements)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.scatter_nd */
TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);

Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) {
auto attrs = make_object<ScatterNDAttrs>();
attrs->reduction = std::move(reduction);
static const Op& op = Op::Get("relax.scatter_nd");
return Call(op, {data, indices, updates}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd);

StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) {
// `call->args` contains: [data, indices, updates]
arith::Analyzer* analyzer = ctx->GetAnalyzer();
ICHECK_EQ(call->args.size(), 3);
const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);

if (data_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "ScatterND op requires the input data to be a tensor. However, the given type is "
<< call->args[0]->GetTypeKey());
}
if (indices_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "ScatterND op requires the input indices to be a tensor. However, the given type is "
<< call->args[1]->GetTypeKey());
}
if (updates_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "ScatterND op requires the input updates to be a tensor. However, the given type is "
<< call->args[2]->GetTypeKey());
}

if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the input data and updates to have known dtype. "
"However, the given types are "
<< "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype);
}

if (data_sinfo->dtype != updates_sinfo->dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the input data to have same type with updates. "
"However, the given types are "
<< "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype);
}

if (indices_sinfo->IsUnknownDtype()) {
LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type.";
} else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the input indices to have integer dtype. However, "
"the given indices dtype is "
<< indices_sinfo->dtype);
}

const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
const auto* updates_shape = updates_sinfo->shape.as<ShapeExprNode>();

if (data_shape && indices_shape && updates_shape) {
const IntImmNode* k_dim = indices_shape->values[indices_sinfo->ndim - 1].as<IntImmNode>();
if (!k_dim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND needs a static shape for the last axis of indices, got "
<< indices_shape->values);
}
const size_t data_ndim = data_sinfo->ndim;
const size_t indices_ndim = indices_sinfo->ndim;
const size_t updates_ndim = updates_sinfo->ndim;
if (data_ndim + indices_ndim - k_dim->value - 1 != updates_ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the updates tensor to have the rank of "
"`data tensor + indices tensor - last axis of indices tensor - 1`. "
"However, the given shapes are "
<< "data: " << ShapeExpr(data_shape->values)
<< ", indices: " << ShapeExpr(indices_shape->values)
<< ", updates: " << ShapeExpr(updates_shape->values));
}
if (k_dim->value > static_cast<int>(data_ndim)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the last axis of indices tensor to be less than "
"or equal to the rank of data tensor. However, the given shapes are "
<< "data: " << ShapeExpr(data_shape->values)
<< ", indices: " << ShapeExpr(indices_shape->values));
}
Array<PrimExpr> expected_updates_shape;
for (size_t i = 0; i < indices_ndim - 1; i++) {
expected_updates_shape.push_back(indices_shape->values[i]);
}
for (size_t i = k_dim->value; i < data_ndim; i++) {
expected_updates_shape.push_back(data_shape->values[i]);
}
auto check_shape = [&](const Array<PrimExpr>& expected, const Array<PrimExpr>& actual) {
if (expected.size() != actual.size()) {
return false;
}
for (size_t i = 0; i < expected.size(); i++) {
if (!analyzer->CanProve(expected[i] == actual[i])) {
return false;
}
}
return true;
};
if (!check_shape(expected_updates_shape, updates_shape->values)) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "ScatterND op requires the updates tensor to have the shape with constraint: "
<< "`updates.shape = indices.shape[:-1] + data.shape[K:]`, but got "
<< "updates.shape: " << ShapeExpr(updates_shape->values) << ", indices.shape: "
<< ShapeExpr(indices_shape->values) << ", data.shape: " << ShapeExpr(data_shape->values));
}
}
if (data_shape) {
return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice);
}
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.scatter_nd")
.set_attrs_type<ScatterNDAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.add_argument("updates", "Tensor", "The input tensor of updates.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
33 changes: 33 additions & 0 deletions src/relax/op/tensor/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,39 @@ Expr tile(Expr data, Array<Integer> repeats);
*/
Expr flip(Expr data, Integer axis);

/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor.
* \param indices The index positions to update in `data`.
* \param updates The values to replace to.
* \param axis The axis along which to scatter the elements.
* \param reduction The reduction mode of the scatter elements,
* either "update", "add", "mul", "mean", "max" or "min".
* \return The computed result.
*/
Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction);

/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor to be updated.
* \param indices The index positions to update in `data`.
* \param updates The values to replace to.
* \param reduction The reduction mode of the scatter operation.
* Supported modes are:
* - "update": Replace the values at the indices with the update values.
* - "add": Add the update values to the existing values at the indices.
* - "mul": Multiply the existing values at the indices by the update values.
* - "max": Take the maximum of the existing value and the update value at each index.
* - "min": Take the minimum of the existing value and the update value at each index.
* \return The computed result tensor with the same shape as `data`.
*
* \note The shape of `indices` defines the shape of the scattered tensor.
* The last dimension of `indices` corresponds to the depth of each index vector.
* The shape of `updates` must match the shape of `indices` except for the last dimension,
* which must match the slice shape at each index.
*/
Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);

} // namespace relax
} // namespace tvm

Expand Down
33 changes: 32 additions & 1 deletion tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def check_correctness(
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
# Legalize any relax ops into tensorir.
tvm_model = relax.transform.LegalizeOps()(tvm_model)
print(tvm_model)

# Separate model from parameters.
tvm_model, params = relax.frontend.detach_params(tvm_model)
Expand Down Expand Up @@ -487,6 +486,38 @@ def test_scatter(axis: int, name: str, opset: int):
check_correctness(model, inputs={"indices": indices}, opset=opset)


@pytest.mark.parametrize("reduction", ["none", "add", "mul"])
def test_scatter_nd(reduction):
def verify_scatter_nd(data_shape, indices_shape, updates_shape):
scatter_nd_node = helper.make_node(
"ScatterND",
["data", "indices", "updates"],
["output"],
reduction=reduction,
)

graph = helper.make_graph(
[scatter_nd_node],
"scatter_nd_test",
inputs=[
helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape),
],
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)],
)

model = helper.make_model(graph, producer_name="scatter_nd_test")

indices = np.random.choice(data_shape[0], indices_shape)
check_correctness(model, inputs={"indices": indices}, opset=16)

verify_scatter_nd([8], [4, 1], [4])
verify_scatter_nd([4, 4, 4], [2, 1], [2, 4, 4])
verify_scatter_nd([4, 5, 6], [2, 3, 2], [2, 3, 6])
verify_scatter_nd([10], [5, 1], [5])


def test_size():
test_node = helper.make_node("Size", ["x"], ["y"])
graph = helper.make_graph(
Expand Down
Loading
Loading