1 parent 95d1268 commit bcb68b1
python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -409,6 +409,32 @@ def call_binary_op(op, lhs, rhs):

         return convert

+    def _div(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        inp_1 = args[0]
+        inp_2 = args[1]
+
+        # Handle scalar cases
+        if isinstance(inp_2, (int, float)):
+            inp_2 = relax.const(inp_2)
+
+        # Get rounding_mode from the third positional arg or node kwargs
+        rounding_mode = args[2] if len(node.args) > 2 else node.kwargs.get("rounding_mode", None)
+
+        # Perform division based on rounding mode
+        if rounding_mode is None:
+            # True division (normal float division)
+            return self.block_builder.emit(relax.op.divide(inp_1, inp_2))
+        elif rounding_mode == "floor":
+            # Floor division
+            return self.block_builder.emit(relax.op.floor_divide(inp_1, inp_2))
+        elif rounding_mode == "trunc":
+            # Trunc division: perform true division, then round toward zero
+            true_div = self.block_builder.emit(relax.op.divide(inp_1, inp_2))
+            return self.block_builder.emit(relax.op.trunc(true_div))
+        else:
+            raise ValueError(f"Unsupported rounding_mode: {rounding_mode}")
+
     def _fmod(self, node: fx.Node):
         args = self.retrieve_args(node)
         lhs = args[0]
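For context, a minimal sketch of the PyTorch code this handler covers; the module, shapes, and tracing call below are illustrative and not part of the commit:

import torch
from torch import fx

class DivModes(torch.nn.Module):
    def forward(self, a, b):
        t = torch.div(a, b)                          # rounding_mode=None -> relax.op.divide
        f = torch.div(a, b, rounding_mode="floor")   # -> relax.op.floor_divide
        r = torch.div(a, b, rounding_mode="trunc")   # -> relax.op.divide + relax.op.trunc
        return t, f, r

# Tracing yields fx "div" call nodes that the translator dispatches to _div.
graph_module = fx.symbolic_trace(DivModes())
print(graph_module.graph)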
python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -336,6 +336,7 @@ def create_convert_map(
             "tanh.default": self._unary_op(relax.op.tanh),
             "tril.default": self._tril_triu(relax.op.tril),
             "triu.default": self._tril_triu(relax.op.triu),
+            "trunc.default": self._unary_op(relax.op.trunc),
             # binary
             "add.Tensor": self._binary_op(relax.op.add, operator.add),
             "add_.Tensor": self._binary_op(relax.op.add, operator.add),
@@ -344,6 +345,7 @@
             "bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
             "bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
             "div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
+            "div.Tensor_mode": self._div,
             "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
             "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
             "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
python/tvm/relax/frontend/torch/fx_translator.py
@@ -725,11 +725,13 @@ def create_convert_map(
             "tril": self._tril_triu(relax.op.tril),
             "triu_": self._inplace_tril_triu(relax.op.triu),
             "triu": self._tril_triu(relax.op.triu),
+            "trunc": self._unary_op(relax.op.trunc),

             "add": self._binary_op(relax.op.add, operator.add),
             "and_": self._binary_op(relax.op.bitwise_and, operator.and_),
             "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_),
             "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
+            "div": self._div,
             "eq": self._binary_op(relax.op.equal, operator.eq),
             "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
             "fmod": self._fmod,
python/tvm/relax/op/__init__.py
@@ -150,6 +150,7 @@
     square,
     tan,
     tanh,
+    trunc,
 )

python/tvm/relax/op/unary.py
@@ -511,6 +511,20 @@ def tanh(x: Expr) -> Expr:
    return _ffi_api.tanh(x)  # type: ignore


+def trunc(x: Expr) -> Expr:
+    """Take trunc of input data.
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.trunc(x)  # type: ignore
+
+
@args_converter.auto
def clip(x: Expr, min: Expr, max: Expr) -> Expr:
    """Clips tensor values to a specified min and max.
python/tvm/relax/transform/legalize_ops/unary.py
@@ -50,6 +50,7 @@
register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan"))
register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
+register_legalize("relax.trunc", _call_topi_without_attr(topi.trunc, "tir_trunc"))
register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip"))

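After this registration, `relax.transform.LegalizeOps` can lower `relax.trunc` into a `call_tir` to a PrimFunc generated from `topi.trunc`. A rough, self-contained sketch (it rebuilds the same one-op module as the snippet above):

from tvm import relax

# Build a one-op module containing relax.trunc.
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((4,), "float32"))
with bb.function("main", [x]):
    gv = bb.emit(relax.op.trunc(x))
    bb.emit_func_output(gv)

# LegalizeOps rewrites the high-level op into a call_tir to the
# generated "tir_trunc" PrimFunc built from topi.trunc.
legalized = relax.transform.LegalizeOps()(bb.get())
legalized.show()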
python/tvm/script/ir_builder/relax/ir.py
@@ -174,6 +174,7 @@
    topk,
    tril,
    triu,
+    trunc,
    unique,
    variance,
    vm,
@@ -870,6 +871,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
    "to_vdevice",
    "tril",
    "triu",
+    "trunc",
    "tuple",
    "unique",
    "variance",
src/relax/op/tensor/unary.cc
@@ -62,6 +62,7 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(trunc, /*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(erf, /*require_float_dtype=*/true);

// relax.clip
src/relax/op/tensor/unary.h
@@ -133,6 +133,9 @@ Expr tan(Expr x);
/*! \brief Compute element-wise tanh of data. */
Expr tanh(Expr x);

+/*! \brief Take trunc of input data (round towards zero). */
+Expr trunc(Expr x);
+
/*! \brief Clips tensor values to a specified min and max. */
Expr clip(Expr x, Expr min, Expr max);
src/target/intrin_rule.cc
@@ -55,6 +55,9 @@ TVM_REGISTER_OP("tir.tanh")
TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("default.FLowerIntrinsic",
                                                     DispatchPureExtern<FloatSuffix>);

+TVM_REGISTER_OP("tir.trunc")
+    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
+
TVM_REGISTER_OP("tir.atan")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);