diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index e53ba3c36e7f..ea41488354d8 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -176,6 +176,17 @@ struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
   }
 };  // struct ScatterNDAttrs
 
+/*! \brief Attributes used in one_hot operator */
+struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
+  int depth;
+  int axis;
+
+  TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") {
+    TVM_ATTR_FIELD(depth).describe("Depth of the one hot dimension.");
+    TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
+  }
+};  // struct OneHotAttrs
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 4770b7ce5cc5..8107859e8e8d 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -287,7 +287,7 @@ class Sub(BinaryBase):
     relax_op = relax.op.subtract
 
     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def _impl_v7(cls, bb, inputs, attr, params):
         return cls.base_impl(bb, inputs, attr, params)
 
 
@@ -298,7 +298,7 @@ class Mul(BinaryBase):
     relax_op = relax.op.multiply
 
     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def _impl_v7(cls, bb, inputs, attr, params):
         return cls.base_impl(bb, inputs, attr, params)
 
 
@@ -309,7 +309,7 @@ class Div(BinaryBase):
     relax_op = relax.op.divide
 
    @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def _impl_v7(cls, bb, inputs, attr, params):
         return cls.base_impl(bb, inputs, attr, params)
 
 
@@ -320,7 +320,24 @@ class Pow(BinaryBase):
     relax_op = relax.op.power
 
     @classmethod
-    def _impl_v1(cls, bb, inputs, attr, params):
+    def _impl_v7(cls, bb, inputs, attr, params):
+        return cls.base_impl(bb, inputs, attr, params)
+
+
+class Mod(BinaryBase):
+    """Converts an onnx Mod node into an equivalent Relax expression."""
+
+    numpy_op = _np.mod
+    relax_op = relax.op.mod
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        if attr.get("fmod", 0) == 0:
+            cls.numpy_op = _np.mod
+            cls.relax_op = relax.op.floor_mod
+        else:
+            cls.numpy_op = _np.fmod
+            cls.relax_op = relax.op.mod
         return cls.base_impl(bb, inputs, attr, params)
 
 
@@ -523,6 +540,23 @@ def _impl_v13(cls, bb, inputs, attr, params):
         return relax.op.nn.log_softmax(inputs[0], axis=axis)
 
 
+class Hardmax(OnnxOpConverter):
+    """Converts an onnx Hardmax node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", -1)
+        data = inputs[0]
+        dtype = data.struct_info.dtype
+        axis_len = int(data.struct_info.shape[axis])
+        argmax = relax.op.argmax(data, axis=axis)
+        on_value = relax.PrimValue(tvm.tir.const(1.0, dtype))
+        off_value = relax.PrimValue(tvm.tir.const(0.0, dtype))
+
+        one_hot = relax.op.one_hot(argmax, on_value, off_value, axis_len, axis)
+        return one_hot
+
+
 class Transpose(OnnxOpConverter):
     """Converts an onnx Transpose node into an equivalent Relax expression."""
 
@@ -731,6 +765,20 @@ def _impl_v1(cls, bb, inputs, attr, params):
         return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
 
 
+class EyeLike(OnnxOpConverter):
+    """Convert an onnx EyeLike node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        k = attr.get("k", 0)
+        input_dtype = inputs[0].struct_info.dtype
+        if "dtype" in attr and get_type(attr["dtype"]) != input_dtype:
+            raise ValueError(
+                f"dtype mismatch between input ({input_dtype}) and attribute ({attr['dtype']})"
+            )
+        return relax.op.eye_like(inputs[0], k, input_dtype)
+
+
 class Gemm(OnnxOpConverter):
     """Convert an onnx Gemm node into an equivalent Relax expression."""
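For reference, the `Hardmax` converter above relies on the identity `hardmax(x) == one_hot(argmax(x))` along the chosen axis. A minimal NumPy sketch of the intended semantics (illustrative only; `np_hardmax` is a hypothetical reference helper, not part of this patch):

```python
import numpy as np

def np_hardmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Reference semantics of ONNX Hardmax: 1 at the argmax along `axis`, 0 elsewhere."""
    idx = np.argmax(x, axis=axis)  # ties resolve to the first maximum, as with np.argmax
    out = np.zeros_like(x)
    np.put_along_axis(out, np.expand_dims(idx, axis), 1.0, axis=axis)
    return out

x = np.array([[1.0, 3.0, 2.0], [5.0, 0.0, 5.0]])
print(np_hardmax(x))  # [[0. 1. 0.], [1. 0. 0.]]
```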
f"dtype mismatch between input ({input_dtype}) and attribute ({attr['dtype']})" + ) + return relax.op.eye_like(inputs[0], k, input_dtype) + + class Gemm(OnnxOpConverter): """Convert an onnx Gemm node into an equivalent Relax expression.""" @@ -2520,13 +2568,13 @@ def _impl_v11(cls, bb, inputs, attr, params): depth = get_constant(inputs[1], params) values = get_constant(inputs[2], params) axis = attr.get("axis", -1) - dtype = values.struct_info.dtype assert isinstance(depth, relax.Constant), "Only constant depth currently supported." depth = depth.data.numpy().tolist() assert isinstance(values, relax.Constant), "Only constant values currently supported." values = values.data.numpy().tolist() off_value, on_value = values - return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype) + off_value, on_value = relax.PrimValue(off_value), relax.PrimValue(on_value) + return relax.op.one_hot(indices, on_value, off_value, depth, axis) class Unique(OnnxOpConverter): @@ -2800,7 +2848,7 @@ def _get_convert_map(): "Sub": Sub, "Mul": Mul, "Div": Div, - # "Mod": Mod, + "Mod": Mod, "Less": Less, "LessOrEqual": LessOrEqual, "Greater": Greater, @@ -2870,7 +2918,7 @@ def _get_convert_map(): "Sigmoid": Sigmoid, "Softmax": Softmax, "LogSoftmax": LogSoftmax, - # "Hardmax": Hardmax, + "Hardmax": Hardmax, "Transpose": Transpose, "Unsqueeze": Unsqueeze, "Where": Where, @@ -2889,7 +2937,7 @@ def _get_convert_map(): "ScatterND": ScatterND, # "Compress": Compress, "Size": Size, - # "EyeLike": EyeLike, + "EyeLike": EyeLike, # Normalization "BatchNormalization": BatchNormalization, "LayerNormalization": LayerNormalization, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 84b31ccec01e..1603ea2f0f7e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -50,6 +50,7 @@ divide, equal, floor_divide, + floor_mod, greater, greater_equal, left_shift, @@ -60,6 +61,7 @@ logical_xor, maximum, minimum, + mod, multiply, not_equal, power, @@ -72,6 +74,8 @@ full_like, ones, ones_like, + eye, + eye_like, tril, triu, zeros, @@ -89,6 +93,7 @@ flatten, flip, layout_transform, + one_hot, permute_dims, repeat, reshape, diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 7632235cb32c..7a41c8b0953c 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -139,6 +139,32 @@ def subtract(x1: Expr, x2: Expr) -> Expr: return _ffi_api.subtract(x1, x2) # type: ignore +def mod(x1: Expr, x2: Expr) -> Expr: + """Modulo with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + """ + return _ffi_api.mod(x1, x2) # type: ignore + + +def floor_mod(x1: Expr, x2: Expr) -> Expr: + """Floor modulo with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. 
+ """ + return _ffi_api.floor_mod(x1, x2) # type: ignore + + ###################### Comparison operators ###################### diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py index 092d79a74dc4..c61d9521a41d 100644 --- a/python/tvm/relax/op/create.py +++ b/python/tvm/relax/op/create.py @@ -163,6 +163,74 @@ def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: return _ffi_api.zeros_like(x, dtype) # type: ignore +def eye( + n: Union[PrimExprLike, PrimValue], + m: Optional[Union[PrimExprLike, PrimValue]] = None, + k: Union[PrimExprLike, PrimValue] = 0, + dtype: Union[str, DataType] = "float32", +) -> Expr: + """Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Parameters + ---------- + n : Union[PrimExprLike, PrimValue] + Number of rows in the output. + + m : Optional[Union[PrimExprLike, PrimValue]] + Number of columns in the output. If None, defaults to n. + + k : Union[PrimExprLike, PrimValue] + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + m = n if m is None else m + n = n if isinstance(n, PrimValue) else PrimValue(n) + m = m if isinstance(m, PrimValue) else PrimValue(m) + k = k if isinstance(k, PrimValue) else PrimValue(k) + return _ffi_api.eye(n, m, k, dtype) # type: ignore + + +def eye_like( + x: Expr, + k: Union[PrimExprLike, PrimValue] = 0, + dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Return a 2-D tensor with ones on the diagonal and zeros elsewhere, + with the same shape as the input tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + k : Union[PrimExprLike, PrimValue] + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + k = k if isinstance(k, PrimValue) else PrimValue(k) + return _ffi_api.eye_like(x, k, dtype) # type: ignore + + def arange( start: Union[PrimExprLike, PrimValue], end: Optional[Union[PrimExprLike, PrimValue]] = None, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 1673a79b08c2..3210cc821689 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -550,3 +550,47 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "updat """ return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore + + +def one_hot( + indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1 +) -> Expr: + """Returns a one-hot tensor. + + Parameters + ---------- + indices : relax.Expr + The indices to set to `on_value`. + + on_value : relax.PrimValue + The value to fill at `indices`. + + off_value : relax.PrimValue + The value to fill at other locations. + + depth : int + The depth of the one-hot dimension. + + axis : int, optional + The axis to fill. Default is -1 which adds a new dimension at the end. + + Returns + ------- + result : relax.Expr + The computed result. 
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 1673a79b08c2..3210cc821689 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -550,3 +550,47 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "updat
     """
     return _ffi_api.scatter_nd(data, indices, updates, reduction)  # type: ignore
+
+
+def one_hot(
+    indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1
+) -> Expr:
+    """Returns a one-hot tensor.
+
+    Parameters
+    ----------
+    indices : relax.Expr
+        The indices to set to `on_value`.
+
+    on_value : relax.PrimValue
+        The value to fill at `indices`.
+
+    off_value : relax.PrimValue
+        The value to fill at other locations.
+
+    depth : int
+        The depth of the one-hot dimension.
+
+    axis : int, optional
+        The axis to fill. Default is -1 which adds a new dimension at the end.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        indices = [0, 1, 2]
+        depth = 3
+        on_value = 1
+        off_value = 0
+
+        one_hot(indices, on_value, off_value, depth) =
+            [[1, 0, 0],
+             [0, 1, 0],
+             [0, 0, 1]]
+    """
+    return _ffi_api.one_hot(indices, on_value, off_value, depth, axis)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py
index d28e100edb9f..41e317f1e0ef 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -48,7 +48,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
 register_legalize("relax.power", _binary(topi.power))
 register_legalize("relax.subtract", _binary(topi.subtract))
 register_legalize("relax.equal", _binary(topi.equal))
-
+register_legalize("relax.mod", _binary(topi.mod))
+register_legalize("relax.floor_mod", _binary(topi.floor_mod))
 register_legalize("relax.greater", _binary(topi.greater))
 register_legalize("relax.greater_equal", _binary(topi.greater_equal))
 register_legalize("relax.less", _binary(topi.less))
diff --git a/python/tvm/relax/transform/legalize_ops/create.py b/python/tvm/relax/transform/legalize_ops/create.py
index 1b022672d0bd..8bf85e34dee8 100644
--- a/python/tvm/relax/transform/legalize_ops/create.py
+++ b/python/tvm/relax/transform/legalize_ops/create.py
@@ -70,6 +70,36 @@ def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr:
 register_legalize("relax.tril", _tril_triu(is_upper=False, primfunc_name="tril"))
 register_legalize("relax.triu", _tril_triu(is_upper=True, primfunc_name="triu"))
 
 
+def _eye(is_like: bool, primfunc_name: str) -> LegalizeFunc:
+    def eye_call_te(bb: BlockBuilder, call: Call) -> Expr:
+        _convert_to_scalar_const = lambda x: _try_convert_to_scalar_const(x, python_native=True)
+        if is_like:
+            x = call.args[0]
+            k = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else 0
+            n, m = x.struct_info.shape
+            dtype = x.struct_info.dtype
+        else:
+            n = _convert_to_scalar_const(call.args[0])
+            m = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else n
+            k = _convert_to_scalar_const(call.args[2]) if len(call.args) > 2 else 0
+            dtype = call.attrs.dtype
+
+        return bb.call_te(
+            topi.eye,
+            n,
+            m,
+            k,
+            dtype,
+            primfunc_name_hint=primfunc_name,
+        )
+
+    return eye_call_te
+
+
+register_legalize("relax.eye", _eye(is_like=False, primfunc_name="eye"))
+register_legalize("relax.eye_like", _eye(is_like=True, primfunc_name="eye_like"))
+
+
 @register_legalize("relax.arange")
 def _arange(bb: BlockBuilder, call: Call) -> Expr:
     assert len(call.args) == 3
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 105d763403af..163085a07c34 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -185,6 +185,25 @@ def scatter_nd(data, indices, updates, reduction):
     )
 
 
+@register_legalize("relax.one_hot")
+def _one_hot(bb: BlockBuilder, call: Call) -> Expr:
+    indices, on_value, off_value = call.args
+    if not (isinstance(on_value, relax.PrimValue) and isinstance(off_value, relax.PrimValue)):
+        raise ValueError("on_value and off_value must be PrimValue")
+    on_value, off_value = on_value.value, off_value.value
+    if on_value.dtype != off_value.dtype:
+        raise ValueError("on_value and off_value must have the same dtype")
+    return bb.call_te(
+        topi.one_hot,
+        indices,
+        on_value,
+        off_value,
+        call.attrs.depth,
+        call.attrs.axis,
+        on_value.dtype,
+    )
+
+
 @register_legalize("relax.layout_transform")
 def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
     def te_layout_transform(data, name):
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index f7847e2af8ed..049345fcb10d 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -85,10 +85,13 @@
     ewise_fma,
     exp,
     expand_dims,
+    eye,
+    eye_like,
     flatten,
     flip,
     floor,
     floor_divide,
+    floor_mod,
     full,
     full_like,
     grad,
@@ -119,6 +122,7 @@
     memory,
     min,
     minimum,
+    mod,
     multinomial_from_uniform,
     multiply,
     negative,
@@ -127,6 +131,7 @@
     null_value,
     ones,
     ones_like,
+    one_hot,
     permute_dims,
     power,
     print,
@@ -753,10 +758,13 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "exp",
     "expand_dims",
     "ext_dev",
+    "eye",
+    "eye_like",
     "flatten",
     "flip",
     "floor",
     "floor_divide",
+    "floor_mod",
     "full",
     "full_like",
     "func_attr",
@@ -795,6 +803,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "metal",
     "min",
     "minimum",
+    "mod",
     "multinomial_from_uniform",
     "multiply",
     "negative",
@@ -802,6 +811,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "null_value",
     "ones",
     "ones_like",
+    "one_hot",
     "opencl",
     "output",
     "permute_dims",
diff --git a/python/tvm/topi/tensor.py b/python/tvm/topi/tensor.py
index 31ebe86760cb..449c599deaf3 100644
--- a/python/tvm/topi/tensor.py
+++ b/python/tvm/topi/tensor.py
@@ -16,7 +16,11 @@
 # under the License.
 # pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
 """Elementwise operators"""
-from __future__ import absolute_import as _abs
+
+from typing import Optional
+
+from tvm import te
+
 from . import cpp
 
 
@@ -73,3 +77,32 @@ def full_like(x, fill_value):
         The result.
     """
     return cpp.full_like(x, fill_value)
+
+
+def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: str = "float32") -> te.Tensor:
+    """Generate an identity matrix or a matrix with ones on the k-th diagonal.
+
+    Parameters
+    ----------
+    n : int
+        Number of rows
+    m : int, optional
+        Number of columns. If None, defaults to n.
+    k : int, optional
+        Index of the diagonal. 0 (default) refers to the main diagonal.
+        A positive value refers to an upper diagonal, and a negative value
+        to a lower diagonal.
+    dtype : str, optional
+        Data type of the returned array.
+
+    Returns
+    -------
+    y : tvm.te.Tensor
+        The result.
+    """
+    m = m if m is not None else n
+    return te.compute(
+        (n, m),
+        lambda i, j: te.if_then_else(i == j - k, te.const(1, dtype), te.const(0, dtype)),
+        name="eye",
+    )
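With the ir-builder registrations above, the new ops become usable directly in TVMScript. A minimal sketch, assuming a TVM build that includes this patch (the module and function names are illustrative):

```python
import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((3, 5), "float32")) -> R.Tensor((3, 5), "float32"):
        e = R.eye_like(x)   # ones on the main diagonal of a (3, 5) tensor
        y = R.add(x, e)
        return y
```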
+ """ + m = m if m is not None else n + return te.compute( + (n, m), + lambda i, j: te.if_then_else(i == j - k, te.const(1, dtype), te.const(0, dtype)), + name="eye", + ) diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc index 6ad71e0f85bf..1e7fa8172718 100644 --- a/src/relax/op/distributed/binary.cc +++ b/src/relax/op/distributed/binary.cc @@ -42,6 +42,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_divide); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(multiply); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(power); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(subtract); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(mod); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_mod); /***************** Comparison operators *****************/ diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index f1dc3d4904c8..bd4c681c7925 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -181,6 +181,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod); /***************** Comparison operators *****************/ diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 003bcb7e27cf..b66eb96f8452 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -79,6 +79,12 @@ Expr power(Expr x1, Expr x2); /*! \brief Subtraction with numpy-style broadcasting. */ Expr subtract(Expr x1, Expr x2); +/*! \brief Modulo with numpy-style broadcasting. */ +Expr mod(Expr x1, Expr x2); + +/*! \brief Floor modulo with numpy-style broadcasting. */ +Expr floor_mod(Expr x1, Expr x2); + /***************** Comparison operators *****************/ /*! \brief Broadcasted element-wise test for (lhs == rhs). 
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index 7aca1470aee4..8696d85f7756 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -228,6 +228,90 @@ TVM_REGISTER_OP("relax.zeros_like")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesLikeZerosLike)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.eye & relax.eye_like */
+Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) {
+  ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("relax.eye");
+  return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {});
+}
+
+Expr eye_like(Expr x, PrimValue k, DataType dtype) {
+  ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("relax.eye_like");
+  return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye);
+TVM_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like);
+
+StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye op should have 3 arguments: n, m, and k, but got "
+                     << call->args.size() << " arguments");
+  }
+
+  auto get_prim_value = [&ctx](const Expr& expr, std::string key) {
+    if (!expr->IsInstance<PrimValueNode>()) {
+      ctx->ReportFatal(Diagnostic::Error(expr)
+                       << "Eye expects the `" << key << "` to be a PrimValue, but got "
+                       << expr->GetTypeKey());
+    }
+    return expr.as<PrimValueNode>()->value;
+  };
+
+  PrimExpr n = get_prim_value(call->args[0], "n");
+  PrimExpr m = get_prim_value(call->args[1], "m");
+
+  DataType dtype = call->attrs.as<InitAttrs>()->dtype;
+  return TensorStructInfo(ShapeExpr({n, m}), dtype);
+}
+
+StructInfo InferStructInfoEyeLike(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like op should have 2 arguments: x and k, but got "
+                     << call->args.size() << " arguments");
+  }
+
+  const auto* x_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  if (x_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like expects the input `x` to be a Tensor, but got "
+                     << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (x_sinfo->ndim != 2 && x_sinfo->ndim != kUnknownNDim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Eye_like expects the input tensor to be 2-dimensional, but got "
+                     << x_sinfo->ndim << " dimensions");
+  }
+
+  const auto* attrs = call->attrs.as<InitAttrs>();
+  DataType out_dtype = attrs->dtype.is_void() ? x_sinfo->dtype : attrs->dtype;
+
+  return TensorStructInfo(x_sinfo->shape.value(), out_dtype, x_sinfo->vdevice);
+}
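These inference rules yield fully static struct info whenever `n`/`m` are known `PrimValue`s, and `eye_like` falls back to the input's dtype when the attribute dtype is void. A small hedged sanity sketch (it mirrors the unit tests added at the end of this diff; normalizing a call is enough to trigger `FInferStructInfo`):

```python
from tvm import relax
from tvm.script import relax as R

bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((3, 4), "float32"))
expr = bb.normalize(relax.op.eye_like(x, k=1))
print(expr.struct_info)  # R.Tensor((3, 4), "float32"): shape/dtype carried over from x

expr = bb.normalize(relax.op.eye(2, 4, dtype="int64"))
print(expr.struct_info)  # R.Tensor((2, 4), "int64"): static shape from the PrimValues
```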
+TVM_REGISTER_OP("relax.eye")
+    .set_attrs_type<InitAttrs>()
+    .set_num_inputs(3)
+    .add_argument("n", "PrimValue", "Number of rows in the output.")
+    .add_argument("m", "PrimValue", "Number of columns in the output.")
+    .add_argument("k", "PrimValue", "Index of the diagonal.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEye)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+TVM_REGISTER_OP("relax.eye_like")
+    .set_attrs_type<InitAttrs>()
+    .set_num_inputs(2)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .add_argument("k", "PrimValue", "Index of the diagonal.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEyeLike)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.arange */
 Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) {
   ObjectPtr<InitAttrs> attrs = make_object<InitAttrs>();
diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h
index 6e7c8255238a..d88336146d44 100644
--- a/src/relax/op/tensor/create.h
+++ b/src/relax/op/tensor/create.h
@@ -72,12 +72,48 @@ Expr ones(Expr shape, DataType dtype);
  */
 Expr ones_like(Expr x, DataType dtype);
 
-/*! \brief Construct a tensor of all zeros, with the input shape and dtype. */
+/*!
+ * \brief Construct a tensor of all zeros, with the input shape and dtype.
+ * \param shape The shape of the created tensor.
+ * \param dtype The data type of the created tensor.
+ * \return The result tensor.
+ */
 Expr zeros(Expr shape, DataType dtype);
 
-/*! \brief Construct a tensor with all zeros, with shape of the input tensor shape. */
+/*!
+ * \brief Construct a tensor with all zeros, with shape of the input tensor shape.
+ * \param x The input tensor, which provides the shape, and dtype
+ * when the input dtype is void.
+ * \param dtype The data type of the created tensor. If it is
+ * void, the input tensor's dtype will be used.
+ * \return The result tensor.
+ */
 Expr zeros_like(Expr x, DataType dtype);
 
+/*!
+ * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere.
+ * \param n The number of rows in the output.
+ * \param m The number of columns in the output. If not given, defaults to n.
+ * \param k The index of the diagonal. A positive value refers to an upper diagonal,
+ * a negative value to a lower diagonal, and 0 to the main diagonal.
+ * \param dtype The data type of the created tensor.
+ * \return The result tensor.
+ */
+Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype);
+
+/*!
+ * \brief Construct a tensor with ones on the diagonal and zeros elsewhere,
+ * with shape and dtype matching the input tensor.
+ * \param x The input tensor, which provides the shape, and dtype
+ * when the input dtype is void.
+ * \param k The index of the diagonal. A positive value refers to an upper diagonal,
+ * a negative value to a lower diagonal, and 0 to the main diagonal.
+ * \param dtype The data type of the created tensor. If it is
+ * void, the input tensor's dtype will be used.
+ * \return The result tensor.
+ */
+Expr eye_like(Expr x, PrimValue k, DataType dtype);
+
 /*! \brief Construct a tensor with evenly spaced elements.
  */
 Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype);
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index ca7d0a0945bc..ba443413025a 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -30,6 +30,8 @@
 #include 
 #include 
+#include "tvm/runtime/data_type.h"
+
 namespace tvm {
 namespace relax {
@@ -1665,5 +1667,78 @@ TVM_REGISTER_OP("relax.scatter_nd")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.one_hot */
+TVM_REGISTER_NODE_TYPE(OneHotAttrs);
+Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis) {
+  ObjectPtr<OneHotAttrs> attrs = make_object<OneHotAttrs>();
+  attrs->depth = depth;
+  attrs->axis = axis;
+
+  // Check that on_value and off_value have the same dtype
+  DataType on_dtype = on_value->value->dtype;
+  DataType off_dtype = off_value->value->dtype;
+  ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have the same dtype, "
+                                << "but got " << on_dtype << " and " << off_dtype;
+
+  ICHECK(depth > 0) << "one_hot: depth must be positive, but got " << depth;
+
+  static const Op& op = Op::Get("relax.one_hot");
+  return Call(op, {indices, on_value, off_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot);
+
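The struct-info rule implemented next inserts `depth` at position `axis` of the indices shape; a negative `axis` counts from the end of the *output* rank (input rank + 1). A small Python sketch of just the shape rule, with self-checks matching the unit tests later in this diff:

```python
def one_hot_out_shape(indices_shape: tuple, depth: int, axis: int = -1) -> tuple:
    """Shape rule assumed by InferStructInfoOneHot: insert `depth` at `axis`."""
    rank = len(indices_shape) + 1          # output rank is input rank + 1
    axis = axis + rank if axis < 0 else axis
    assert 0 <= axis < rank
    return indices_shape[:axis] + (depth,) + indices_shape[axis:]

assert one_hot_out_shape((3,), 5) == (3, 5)              # default axis=-1 appends
assert one_hot_out_shape((2, 2), 3, axis=1) == (2, 3, 2)
```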
+StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+  const auto* attrs = call->attrs.as<OneHotAttrs>();
+  PrimValue on_value = Downcast<PrimValue>(call->args[1]);
+  PrimValue off_value = Downcast<PrimValue>(call->args[2]);
+  // Check that on_value and off_value have the same dtype
+  ICHECK(on_value->value->dtype == off_value->value->dtype)
+      << "one_hot: on_value and off_value must have the same dtype, "
+      << "but got " << on_value->value->dtype << " and " << off_value->value->dtype;
+  DataType dtype = on_value->value->dtype;
+
+  // Check that indices has an integer dtype
+  if (indices_sinfo->IsUnknownDtype()) {
+    LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type.";
+  } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "one_hot op requires the input indices to have integer dtype. However, the "
+                        "given indices dtype is "
+                     << indices_sinfo->dtype);
+  }
+  // Check if indices has unknown dimension
+  if (indices_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(dtype, kUnknownNDim, indices_sinfo->vdevice);
+  }
+  // Get the shape of indices
+  const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+  if (indices_shape == nullptr) {
+    return TensorStructInfo(dtype, indices_sinfo->ndim + 1, indices_sinfo->vdevice);
+  }
+
+  Array<PrimExpr> output_shape = indices_shape->values;
+  int axis = attrs->axis;
+  if (axis < 0) {
+    axis += output_shape.size() + 1;
+  }
+  ICHECK(0 <= axis && axis <= static_cast<int>(output_shape.size()))
+      << "one_hot: axis must be in the range of [0, " << output_shape.size() << "], "
+      << "but got " << axis;
+  output_shape.insert(output_shape.begin() + axis, attrs->depth);
+
+  return TensorStructInfo(ShapeExpr(output_shape), dtype, indices_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.one_hot")
+    .set_attrs_type<OneHotAttrs>()
+    .set_num_inputs(3)
+    .add_argument("indices", "Tensor", "The indices tensor.")
+    .add_argument("on_value", "PrimValue", "The value to fill at specified indices.")
+    .add_argument("off_value", "PrimValue", "The value to fill at other indices.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOneHot)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index e9fa1131e803..010ceb663ef3 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -27,6 +27,7 @@
 #include 
 
 #include "../op_common.h"
+#include "tvm/relax/expr.h"
 
 namespace tvm {
 namespace relax {
@@ -206,6 +207,17 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re
  */
 Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);
 
+/*!
+ * \brief Returns a one-hot tensor.
+ * \param indices The indices to set to `on_value`.
+ * \param on_value The value to fill at `indices`.
+ * \param off_value The value to fill at other locations.
+ * \param depth The depth of the one hot dimension.
+ * \param axis The axis to fill.
+ * \return The computed result.
+ */
+Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 1b4c5d281abb..46373510b101 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -63,8 +63,11 @@ def generate_random_inputs(
     if dtype == "bool":
         # random_value = np.random.choice(a=[False, True], size=shape)
         random_value = rg.choice(a=[False, True], size=shape)
+    elif dtype.startswith("int"):
+        # Shift non-positive values down so no value is zero (e.g. a Mod divisor)
+        random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
+        random_value[random_value <= 0] -= 1
     else:
-        # random_value = np.random.normal(size=shape).astype(dtype)
         random_value = rg.standard_normal(size=shape).astype(dtype)
     input_values[i.name] = random_value
 
@@ -246,7 +249,6 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
     )
     model = helper.make_model(graph, producer_name="binary_test")
 
-    # NOTE: explicitly pass inputs to avoid numerical error
     check_correctness(model, opset=opset)
 
 
@@ -327,6 +329,16 @@ def test_binary(op_name: str):
     verify_binary_scalar(op_name)
 
 
+@pytest.mark.parametrize("int_mode", [True, False])
+def test_mod(int_mode: bool):
+    if int_mode:
+        dtype, fmod = TensorProto.INT32, 0
+    else:
+        dtype, fmod = TensorProto.FLOAT, 1
+    verify_binary("Mod", [1, 32], [1, 32], [1, 32], attrs={"fmod": fmod}, dtype=dtype)
+    verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype)
+
+
 @pytest.mark.parametrize("num_inputs", [1, 2, 4])
 @pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"])
 def test_multi_input(op_name: str, num_inputs: int):
@@ -430,6 +442,7 @@ def test_bitwise_shift(direction: str):
         "Sigmoid",
         "Softmax",
         "LogSoftmax",
+        "Hardmax",
         "Identity",
     ],
 )
@@ -445,7 +458,7 @@ def test_unary(op_name: str):
         output_dtype = TensorProto.BOOL
     else:
         output_dtype = TensorProto.FLOAT
-    verify_unary(op_name, [32, 32], input_dtype=input_dtype, output_dtype=output_dtype)
+    verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, output_dtype=output_dtype)
 
 
 @pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16])
@@ -567,6 +580,11 @@ def test_size():
     check_correctness(model)
 
 
+@pytest.mark.parametrize("k", [-1, 0, 1])
+def test_eye_like(k: int):
+    verify_unary("EyeLike", [32, 32], attrs={"k": k})
+
+
 @pytest.mark.parametrize("alpha", [None, 0.25, 1.0])
 @pytest.mark.parametrize("beta", [None, 0.35, 1.0])
 @pytest.mark.parametrize("useC", [False, True])
@@ -966,7 +984,7 @@ def test_cumsum1():
     )
     model = helper.make_model(graph, producer_name="cumsum_graph")
 
-    check_correctness(model)
+    check_correctness(model, inputs={"axis": np.array([0], dtype=np.int32)})
 
 
 @pytest.mark.parametrize("axis", [[0, 2], None])
diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py
index 1e895169f620..67f347019163 100644
--- a/tests/python/relax/test_op_create.py
+++ b/tests/python/relax/test_op_create.py
@@ -545,6 +545,64 @@ def test_ones_like_zeros_like_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.zeros_like(x1))
 
 
+def test_eye_infer_struct_info():
+    bb = relax.BlockBuilder()
+
+    _check_inference(bb, relax.op.eye(3), relax.TensorStructInfo((3, 3), "float32"))
+    _check_inference(bb, relax.op.eye(2, 4), relax.TensorStructInfo((2, 4), "float32"))
+    _check_inference(bb, relax.op.eye(3, dtype="int64"), relax.TensorStructInfo((3, 3), "int64"))
+    _check_inference(bb, relax.op.eye(3, 5, k=1), relax.TensorStructInfo((3, 5), "float32"))
+    _check_inference(bb, relax.op.eye(3, 5, k=-2), relax.TensorStructInfo((3, 5), "float32"))
+
+
+def test_eye_infer_struct_info_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    m = tir.Var("m", "int64")
+    k = tir.Var("k", "int64")
+
+    _check_inference(bb, relax.op.eye(n), relax.TensorStructInfo((n, n), "float32"))
+    _check_inference(bb, relax.op.eye(n, m), relax.TensorStructInfo((n, m), "float32"))
+    _check_inference(bb, relax.op.eye(n, k=k), relax.TensorStructInfo((n, n), "float32"))
+
+
+def test_eye_like_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 5), "int64"))
+    x2 = relax.Var("x", R.Tensor((3, 3)))
+
+    _check_inference(bb, relax.op.eye_like(x0), relax.TensorStructInfo((3, 4), "float32"))
+    _check_inference(bb, relax.op.eye_like(x1), relax.TensorStructInfo((2, 5), "int64"))
+    _check_inference(bb, relax.op.eye_like(x2), relax.TensorStructInfo((3, 3), dtype=""))
+    _check_inference(bb, relax.op.eye_like(x0, k=1), relax.TensorStructInfo((3, 4), "float32"))
+    _check_inference(
+        bb, relax.op.eye_like(x1, dtype="float32"), relax.TensorStructInfo((2, 5), "float32")
+    )
+
+
+def test_eye_like_infer_struct_info_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    m = tir.Var("m", "int64")
+    x = relax.Var("x", R.Tensor((n, m), "float32"))
+    k = tir.Var("k", "int64")
+
+    _check_inference(bb, relax.op.eye_like(x), relax.TensorStructInfo((n, m), "float32"))
+    _check_inference(bb, relax.op.eye_like(x, k=k), relax.TensorStructInfo((n, m), "float32"))
+
+
+def test_eye_like_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.eye_like(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.eye_like(x1))
+
+
 def test_arange_infer_struct_info():
     bb = relax.BlockBuilder()
 
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index e958b03e4ce6..f6aefc859114 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3377,5 +3377,57 @@ def test_scatter_nd_infer_struct_info():
     )
 
 
+def test_one_hot_infer_struct_info():
+    bb = relax.BlockBuilder()
+
+    # Test case 1: Basic usage
+    i0 = relax.Var("indices", R.Tensor((3,), "int32"))
+    _check_inference(
+        bb,
+        relax.op.one_hot(i0, relax.PrimValue(1.0), relax.PrimValue(0.0), 5),
+        relax.TensorStructInfo((3, 5), "float32"),
+    )
+
+    # Test case 2: With specified axis
+    i1 = relax.Var("indices", R.Tensor((2, 2), "int32"))
+    _check_inference(
+        bb,
+        relax.op.one_hot(i1, relax.PrimValue(1), relax.PrimValue(0), 3, axis=1),
+        relax.TensorStructInfo((2, 3, 2), "int64"),
+    )
+
+    # Test case 3: With symbolic shape
+    n = tir.Var("n", "int64")
+    i2 = relax.Var("indices", R.Tensor((n,), "int32"))
+    _check_inference(
+        bb,
+        relax.op.one_hot(i2, relax.PrimValue(1.0), relax.PrimValue(0.0), 4),
+        relax.TensorStructInfo((n, 4), "float32"),
+    )
+
+    # Test case 4: With unknown shape
+    i3 = relax.Var("indices", R.Tensor("int32"))
+    _check_inference(
+        bb,
+        relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0.0), 6),
+        relax.TensorStructInfo(dtype="float32"),
+    )
+
+    # Test case 5: With different on_value and off_value dtypes
+    i3 = relax.Var("indices", R.Tensor((2, 3), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0), 5))
+
+    # Test case 6: With invalid indices dtype
+    i4 = relax.Var("indices", R.Tensor((2, 3), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.one_hot(i4, relax.PrimValue(1.0), relax.PrimValue(0.0), 5))
+
+    # Test case 7: With invalid depth
+    i5 = relax.Var("indices", R.Tensor((2, 3), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.one_hot(i5, relax.PrimValue(1.0), relax.PrimValue(0.0), -1))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
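Taken together, the new pieces compose end to end: construct, legalize, build, run. A hedged sketch assuming an LLVM-enabled build of TVM with this patch applied (target and device choices are illustrative):

```python
import numpy as np
import tvm
from tvm import relax

# Build a tiny module around the new op.
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((3,), "int32"))
with bb.function("main", [x]):
    lv = bb.emit(relax.op.one_hot(x, relax.PrimValue(1.0), relax.PrimValue(0.0), depth=4))
    bb.emit_func_output(lv)

mod = relax.transform.LegalizeOps()(bb.get())  # lowers relax.one_hot to the topi kernel
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](tvm.nd.array(np.array([0, 2, 3], dtype="int32")))
print(out.numpy())  # rows are 4-wide one-hot vectors for indices 0, 2, 3
```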