11 changes: 11 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -176,6 +176,17 @@ struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
}
}; // struct ScatterNDAttrs

/*! \brief Attributes used in one_hot operator */
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
int depth;
int axis;

TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") {
TVM_ATTR_FIELD(depth).describe("Depth of the one-hot dimension.");
TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill.");
}
}; // struct OneHotAttrs

} // namespace relax
} // namespace tvm
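
On the Python side these attributes surface as fields of the call's attrs; a minimal sketch, assuming `call` is a relax.Call produced by relax.op.one_hot (the variable name is illustrative, not part of this patch):

    depth = call.attrs.depth  # depth of the one-hot dimension
    axis = call.attrs.axis    # axis to fill, defaults to -1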

66 changes: 57 additions & 9 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -287,7 +287,7 @@ class Sub(BinaryBase):
relax_op = relax.op.subtract

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
def _impl_v7(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


@@ -298,7 +298,7 @@ class Mul(BinaryBase):
relax_op = relax.op.multiply

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
def _impl_v7(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


@@ -309,7 +309,7 @@ class Div(BinaryBase):
relax_op = relax.op.divide

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
def _impl_v7(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


@@ -320,7 +320,24 @@ class Pow(BinaryBase):
relax_op = relax.op.power

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
def _impl_v7(cls, bb, inputs, attr, params):
return cls.base_impl(bb, inputs, attr, params)


class Mod(BinaryBase):
"""Converts an onnx Mod node into an equivalent Relax expression."""

numpy_op = _np.mod
relax_op = relax.op.mod

@classmethod
def _impl_v10(cls, bb, inputs, attr, params):
if attr.get("fmod", 0) == 0:
# fmod == 0 (default): integer mod, sign follows the divisor (numpy.mod / Python %)
cls.numpy_op = _np.mod
cls.relax_op = relax.op.floor_mod
else:
# fmod == 1: fmod semantics, sign follows the dividend (numpy.fmod)
cls.numpy_op = _np.fmod
cls.relax_op = relax.op.mod
return cls.base_impl(bb, inputs, attr, params)
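
For reference, the two conventions selected by `fmod` differ only in the sign of the result; a minimal numpy sketch (illustrative, not part of the patch):

    import numpy as np
    assert np.mod(-5, 3) == 1    # fmod=0: sign follows the divisor, like Python's %
    assert np.fmod(-5, 3) == -2  # fmod=1: sign follows the dividend, like C's fmod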


@@ -523,6 +540,23 @@ def _impl_v13(cls, bb, inputs, attr, params):
return relax.op.nn.log_softmax(inputs[0], axis=axis)


class Hardmax(OnnxOpConverter):
"""Converts an onnx Hardmax node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
axis = attr.get("axis", -1)
data = inputs[0]
dtype = data.struct_info.dtype
axis_len = int(data.struct_info.shape[axis])
# Hardmax is the one-hot encoding of the argmax along the chosen axis.
argmax = relax.op.argmax(data, axis=axis)
on_value = relax.PrimValue(tvm.tir.const(1.0, dtype))
off_value = relax.PrimValue(tvm.tir.const(0.0, dtype))

return relax.op.one_hot(argmax, on_value, off_value, axis_len, axis)
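
A small numpy sketch of that equivalence (illustrative only; ties resolve to the first maximum, as in ONNX):

    import numpy as np
    x = np.array([[1.0, 3.0, 2.0], [4.0, 0.0, 4.0]])
    hardmax = np.eye(x.shape[-1], dtype=x.dtype)[x.argmax(axis=-1)]
    # hardmax == [[0., 1., 0.], [1., 0., 0.]]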


class Transpose(OnnxOpConverter):
"""Converts an onnx Transpose node into an equivalent Relax expression."""

@@ -731,6 +765,20 @@ def _impl_v1(cls, bb, inputs, attr, params):
return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))


class EyeLike(OnnxOpConverter):
"""Convert an onnx EyeLike node into an equivalent Relax expression."""

@classmethod
def _impl_v9(cls, bb, inputs, attr, params):
k = attr.get("k", 0)
input_dtype = inputs[0].struct_info.dtype
if "dtype" in attr and get_type(attr["dtype"]) != input_dtype:
raise ValueError(
f"dtype mismatch between input ({input_dtype}) and attribute ({attr['dtype']})"
)
return relax.op.eye_like(inputs[0], k, input_dtype)


class Gemm(OnnxOpConverter):
"""Convert an onnx Gemm node into an equivalent Relax expression."""

@@ -2520,13 +2568,13 @@ def _impl_v11(cls, bb, inputs, attr, params):
depth = get_constant(inputs[1], params)
values = get_constant(inputs[2], params)
axis = attr.get("axis", -1)
dtype = values.struct_info.dtype
assert isinstance(depth, relax.Constant), "Only constant depth currently supported."
depth = depth.data.numpy().tolist()
assert isinstance(values, relax.Constant), "Only constant values currently supported."
values = values.data.numpy().tolist()
off_value, on_value = values
return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype)
off_value, on_value = relax.PrimValue(off_value), relax.PrimValue(on_value)
return relax.op.one_hot(indices, on_value, off_value, depth, axis)


class Unique(OnnxOpConverter):
@@ -2800,7 +2848,7 @@ def _get_convert_map():
"Sub": Sub,
"Mul": Mul,
"Div": Div,
# "Mod": Mod,
"Mod": Mod,
"Less": Less,
"LessOrEqual": LessOrEqual,
"Greater": Greater,
@@ -2870,7 +2918,7 @@ def _get_convert_map():
"Sigmoid": Sigmoid,
"Softmax": Softmax,
"LogSoftmax": LogSoftmax,
# "Hardmax": Hardmax,
"Hardmax": Hardmax,
"Transpose": Transpose,
"Unsqueeze": Unsqueeze,
"Where": Where,
@@ -2889,7 +2937,7 @@ def _get_convert_map():
"ScatterND": ScatterND,
# "Compress": Compress,
"Size": Size,
# "EyeLike": EyeLike,
"EyeLike": EyeLike,
# Normalization
"BatchNormalization": BatchNormalization,
"LayerNormalization": LayerNormalization,
5 changes: 5 additions & 0 deletions python/tvm/relax/op/__init__.py
@@ -50,6 +50,7 @@
divide,
equal,
floor_divide,
floor_mod,
greater,
greater_equal,
left_shift,
@@ -60,6 +61,7 @@
logical_xor,
maximum,
minimum,
mod,
multiply,
not_equal,
power,
@@ -72,6 +74,8 @@
full_like,
ones,
ones_like,
eye,
eye_like,
tril,
triu,
zeros,
@@ -89,6 +93,7 @@
flatten,
flip,
layout_transform,
one_hot,
permute_dims,
repeat,
reshape,
26 changes: 26 additions & 0 deletions python/tvm/relax/op/binary.py
@@ -139,6 +139,32 @@ def subtract(x1: Expr, x2: Expr) -> Expr:
return _ffi_api.subtract(x1, x2) # type: ignore


def mod(x1: Expr, x2: Expr) -> Expr:
"""Modulo with numpy-style broadcasting.

Parameters
----------
x1 : Expr
The first input tensor.
x2 : Expr
The second input tensor.
"""
return _ffi_api.mod(x1, x2) # type: ignore


def floor_mod(x1: Expr, x2: Expr) -> Expr:
"""Floor modulo with numpy-style broadcasting.

Parameters
----------
x1 : Expr
The first input tensor.
x2 : Expr
The second input tensor.
"""
return _ffi_api.floor_mod(x1, x2) # type: ignore
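
A minimal usage sketch for the two new operators (variable names are illustrative; both calls are legalized further down in this diff):

    from tvm import relax

    x = relax.Var("x", relax.TensorStructInfo((4,), "int32"))
    y = relax.Var("y", relax.TensorStructInfo((4,), "int32"))
    trunc_rem = relax.op.mod(x, y)        # lowered to topi.mod
    floor_rem = relax.op.floor_mod(x, y)  # lowered to topi.floor_mod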


###################### Comparison operators ######################


68 changes: 68 additions & 0 deletions python/tvm/relax/op/create.py
@@ -163,6 +163,74 @@ def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr:
return _ffi_api.zeros_like(x, dtype) # type: ignore


def eye(
n: Union[PrimExprLike, PrimValue],
m: Optional[Union[PrimExprLike, PrimValue]] = None,
k: Union[PrimExprLike, PrimValue] = 0,
dtype: Union[str, DataType] = "float32",
) -> Expr:
"""Construct a 2-D tensor with ones on the diagonal and zeros elsewhere.

Parameters
----------
n : Union[PrimExprLike, PrimValue]
Number of rows in the output.

m : Optional[Union[PrimExprLike, PrimValue]]
Number of columns in the output. If None, defaults to n.

k : Union[PrimExprLike, PrimValue]
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal, and a negative value
to a lower diagonal.

dtype : Union[str, DataType]
The data type of the created tensor.

Returns
-------
result : relax.Expr
The result tensor.
"""
m = n if m is None else m
n = n if isinstance(n, PrimValue) else PrimValue(n)
m = m if isinstance(m, PrimValue) else PrimValue(m)
k = k if isinstance(k, PrimValue) else PrimValue(k)
return _ffi_api.eye(n, m, k, dtype) # type: ignore


def eye_like(
x: Expr,
k: Union[PrimExprLike, PrimValue] = 0,
dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
"""Return a 2-D tensor with ones on the diagonal and zeros elsewhere,
with the same shape as the input tensor.

Parameters
----------
x : relax.Expr
The input tensor, which provides the output shape, and the output dtype
when the `dtype` argument is not specified.

k : Union[PrimExprLike, PrimValue]
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal, and a negative value
to a lower diagonal.

dtype : Optional[Union[str, DataType]]
The data type of the created tensor.
If dtype is not given, it defaults to the dtype of the input tensor.

Returns
-------
result : relax.Expr
The result tensor.
"""
k = k if isinstance(k, PrimValue) else PrimValue(k)
return _ffi_api.eye_like(x, k, dtype) # type: ignore
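
A short usage sketch for the two constructors (illustrative; the numpy comparison only shows the expected pattern):

    from tvm import relax

    a = relax.op.eye(3)                          # 3x3 identity, float32
    b = relax.op.eye(2, 4, k=1, dtype="int32")   # like np.eye(2, 4, k=1, dtype="int32")
    # np.eye(2, 4, k=1) == [[0, 1, 0, 0],
    #                       [0, 0, 1, 0]]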


def arange(
start: Union[PrimExprLike, PrimValue],
end: Optional[Union[PrimExprLike, PrimValue]] = None,
44 changes: 44 additions & 0 deletions python/tvm/relax/op/manipulate.py
@@ -550,3 +550,47 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "updat

"""
return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore


def one_hot(
indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1
) -> Expr:
"""Returns a one-hot tensor.

Parameters
----------
indices : relax.Expr
The indices to set to `on_value`.

on_value : relax.PrimValue
The value to fill at `indices`.

off_value : relax.PrimValue
The value to fill at other locations.

depth : int
The depth of the one-hot dimension.

axis : int, optional
The axis to fill. Defaults to -1, which adds a new dimension at the end.

Returns
-------
result : relax.Expr
The computed result.

Examples
--------
.. code-block:: python

indices = [0, 1, 2]
depth = 3
on_value = 1
off_value = 0

one_hot(indices, on_value, off_value, depth) =
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
"""
return _ffi_api.one_hot(indices, on_value, off_value, depth, axis) # type: ignore
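
A short usage sketch from the Relax side (illustrative; on/off values must be wrapped in PrimValue per the signature above):

    import numpy as np
    from tvm import relax

    indices = relax.const(np.array([0, 1, 2], dtype="int32"))
    out = relax.op.one_hot(indices, relax.PrimValue(1), relax.PrimValue(0), depth=3)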
3 changes: 2 additions & 1 deletion python/tvm/relax/transform/legalize_ops/binary.py
@@ -48,7 +48,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
register_legalize("relax.power", _binary(topi.power))
register_legalize("relax.subtract", _binary(topi.subtract))
register_legalize("relax.equal", _binary(topi.equal))

register_legalize("relax.mod", _binary(topi.mod))
register_legalize("relax.floor_mod", _binary(topi.floor_mod))
register_legalize("relax.greater", _binary(topi.greater))
register_legalize("relax.greater_equal", _binary(topi.greater_equal))
register_legalize("relax.less", _binary(topi.less))
30 changes: 30 additions & 0 deletions python/tvm/relax/transform/legalize_ops/create.py
@@ -70,6 +70,36 @@ def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr:
register_legalize("relax.triu", _tril_triu(is_upper=True, primfunc_name="triu"))


def _eye(is_like: bool, primfunc_name: str) -> LegalizeFunc:
def eye_call_te(bb: BlockBuilder, call: Call) -> Expr:
_convert_to_scalar_const = lambda x: _try_convert_to_scalar_const(x, python_native=True)
if is_like:
x = call.args[0]
k = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else 0
n, m = x.struct_info.shape
dtype = x.struct_info.dtype
else:
n = _convert_to_scalar_const(call.args[0])
m = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else n
k = _convert_to_scalar_const(call.args[2]) if len(call.args) > 2 else 0
dtype = call.attrs.dtype

return bb.call_te(
topi.eye,
n,
m,
k,
dtype,
primfunc_name_hint=primfunc_name,
)

return eye_call_te


register_legalize("relax.eye", _eye(is_like=False, primfunc_name="eye"))
register_legalize("relax.eye_like", _eye(is_like=True, primfunc_name="eye_like"))


@register_legalize("relax.arange")
def _arange(bb: BlockBuilder, call: Call) -> Expr:
assert len(call.args) == 3