Skip to content

Commit bcb68b1

Browse files
Deivanayaki-S and deivanayakisankaralingam authored
[Relax][PyTorch] Add div.Tensor_mode and trunc Op Support for Exported Program and FX graph (#17924)
* add div.Tensor_mode and trunc op support * rename the func name into _div --------- Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
1 parent 95d1268 commit bcb68b1

File tree

12 files changed

+199
-0
lines changed

12 files changed

+199
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,32 @@ def call_binary_op(op, lhs, rhs):
409409

410410
return convert
411411

412+
def _div(self, node: fx.Node) -> relax.Var:
    """Convert torch division (``div`` / ``div.Tensor_mode``) to Relax.

    Supports the three PyTorch ``rounding_mode`` variants:
    ``None`` (true division), ``"floor"`` and ``"trunc"``.

    Parameters
    ----------
    node : fx.Node
        The FX node. ``args[0]``/``args[1]`` are the operands; the
        rounding mode may arrive positionally (``args[2]``) or as the
        ``rounding_mode`` keyword.

    Returns
    -------
    relax.Var
        The emitted division result.

    Raises
    ------
    ValueError
        If ``rounding_mode`` is not one of ``None``, ``"floor"``, ``"trunc"``.
    """
    args = self.retrieve_args(node)
    inp_1 = args[0]
    inp_2 = args[1]

    # Promote Python scalars to Relax constants. Note the FIRST operand
    # can also be a scalar (e.g. ``torch.div(2.0, x)`` / ``2.0 / x``),
    # so handle both sides symmetrically.
    if isinstance(inp_1, (int, float)):
        inp_1 = relax.const(inp_1)
    if isinstance(inp_2, (int, float)):
        inp_2 = relax.const(inp_2)

    # rounding_mode may be positional (div.Tensor_mode) or a kwarg (FX "div").
    rounding_mode = args[2] if len(node.args) > 2 else node.kwargs.get("rounding_mode", None)

    if rounding_mode is None:
        # True division (normal float division).
        return self.block_builder.emit(relax.op.divide(inp_1, inp_2))
    elif rounding_mode == "floor":
        # Floor division (round toward negative infinity).
        return self.block_builder.emit(relax.op.floor_divide(inp_1, inp_2))
    elif rounding_mode == "trunc":
        # Trunc division: true-divide, then round toward zero.
        true_div = self.block_builder.emit(relax.op.divide(inp_1, inp_2))
        return self.block_builder.emit(relax.op.trunc(true_div))
    else:
        raise ValueError(f"Unsupported rounding_mode: {rounding_mode}")
437+
412438
def _fmod(self, node: fx.Node):
413439
args = self.retrieve_args(node)
414440
lhs = args[0]

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def create_convert_map(
336336
"tanh.default": self._unary_op(relax.op.tanh),
337337
"tril.default": self._tril_triu(relax.op.tril),
338338
"triu.default": self._tril_triu(relax.op.triu),
339+
"trunc.default": self._unary_op(relax.op.trunc),
339340
# binary
340341
"add.Tensor": self._binary_op(relax.op.add, operator.add),
341342
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
@@ -344,6 +345,7 @@ def create_convert_map(
344345
"bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
345346
"bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
346347
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
348+
"div.Tensor_mode": self._div,
347349
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
348350
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
349351
"floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,11 +725,13 @@ def create_convert_map(
725725
"tril": self._tril_triu(relax.op.tril),
726726
"triu_": self._inplace_tril_triu(relax.op.triu),
727727
"triu": self._tril_triu(relax.op.triu),
728+
"trunc": self._unary_op(relax.op.trunc),
728729
# binary
729730
"add": self._binary_op(relax.op.add, operator.add),
730731
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
731732
"bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_),
732733
"bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
734+
"div": self._div,
733735
"eq": self._binary_op(relax.op.equal, operator.eq),
734736
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
735737
"fmod": self._fmod,

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@
150150
square,
151151
tan,
152152
tanh,
153+
trunc,
153154
)
154155

155156

python/tvm/relax/op/unary.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,20 @@ def tanh(x: Expr) -> Expr:
511511
return _ffi_api.tanh(x) # type: ignore
512512

513513

514+
def trunc(x: Expr) -> Expr:
    """Take trunc of input data (element-wise rounding toward zero).

    Parameters
    ----------
    x : relax.Expr
        The input data

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    return _ffi_api.trunc(x)  # type: ignore
526+
527+
514528
@args_converter.auto
515529
def clip(x: Expr, min: Expr, max: Expr) -> Expr:
516530
"""Clips tensor values to a specified min and max.

python/tvm/relax/transform/legalize_ops/unary.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
5151
register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan"))
5252
register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
53+
register_legalize("relax.trunc", _call_topi_without_attr(topi.trunc, "tir_trunc"))
5354
register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip"))
5455

5556

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@
174174
topk,
175175
tril,
176176
triu,
177+
trunc,
177178
unique,
178179
variance,
179180
vm,
@@ -870,6 +871,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
870871
"to_vdevice",
871872
"tril",
872873
"triu",
874+
"trunc",
873875
"tuple",
874876
"unique",
875877
"variance",

src/relax/op/tensor/unary.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false);
6262
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true);
6363
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true);
6464
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true);
65+
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(trunc, /*require_float_dtype=*/false);
6566
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(erf, /*require_float_dtype=*/true);
6667

6768
// relax.clip

src/relax/op/tensor/unary.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ Expr tan(Expr x);
133133
/*! \brief Compute element-wise tanh of data. */
134134
Expr tanh(Expr x);
135135

136+
/*! \brief Take trunc of input data (round towards zero). */
137+
Expr trunc(Expr x);
138+
136139
/*! \brief Clips tensor values to a specified min and max. */
137140
Expr clip(Expr x, Expr min, Expr max);
138141

src/target/intrin_rule.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ TVM_REGISTER_OP("tir.tanh")
5555
TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
5656
DispatchPureExtern<FloatSuffix>);
5757

58+
TVM_REGISTER_OP("tir.trunc")
59+
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
60+
5861
TVM_REGISTER_OP("tir.atan")
5962
.set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
6063

0 commit comments

Comments
 (0)