Skip to content

Commit 6ca2341

Browse files
authored
[Relax] Remove the legalization of cumsum/cumprob (#16676)
* [Relax] Remove the legalization of cumsum/cumprob * remove related tests
1 parent d284cf4 commit 6ca2341

File tree

3 files changed

+0
-84
lines changed

3 files changed

+0
-84
lines changed

python/tvm/relax/transform/legalize_ops/statistical.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,3 @@ def _variance(bb: BlockBuilder, call: Call) -> Expr:
8585
register_legalize("relax.min", _statistical(topi.min))
8686
register_legalize("relax.prod", _statistical(topi.prod))
8787
register_legalize("relax.sum", _statistical(topi.sum))
88-
89-
90-
@register_legalize("relax.cumsum")
91-
def _cumsum(bb: BlockBuilder, call: Call) -> Expr:
92-
return bb.call_te(
93-
topi.cumsum, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive
94-
)
95-
96-
97-
@register_legalize("relax.cumprod")
98-
def _cumprod(bb: BlockBuilder, call: Call) -> Expr:
99-
return bb.call_te(
100-
topi.cumprod, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive
101-
)

tests/python/relax/test_frontend_nn_op.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,6 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), d
11611161

11621162
target = tvm.target.Target("cuda -libs=thrust", host="llvm")
11631163
with target:
1164-
mod = relax.backend.DispatchSortScan()(mod)
11651164
mod = relax.transform.LegalizeOps()(mod)
11661165
mod = tir.transform.DefaultGPUSchedule()(mod)
11671166

tests/python/relax/test_transform_legalize_ops_search_statistical.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,74 +1066,5 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype="
10661066
tvm.ir.assert_structural_equal(mod, Expected)
10671067

10681068

1069-
def test_cumsum():
1070-
# fmt: off
1071-
@I.ir_module
1072-
class Cumsum:
1073-
@R.function
1074-
def main(x: R.Tensor((3, 2, 3), "float32")):
1075-
gv = R.cumsum(x, axis=1, dtype="int32")
1076-
return gv
1077-
1078-
@I.ir_module
1079-
class Expected:
1080-
@T.prim_func(private=True)
1081-
def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "int32")):
1082-
T.func_attr({"tir.noalias": True})
1083-
rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(3)), offset_factor=1)
1084-
with T.block("cumsum_generic"):
1085-
for fused in T.parallel(T.int64(9)):
1086-
out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)] = T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)])
1087-
for _k in range(T.int64(1)):
1088-
out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)] = out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) % T.int64(3)] + T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)])
1089-
1090-
@R.function
1091-
def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((3, 2, 3), dtype="int32"):
1092-
cls = Expected
1093-
gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((3, 2, 3), dtype="int32"))
1094-
return gv
1095-
# fmt: on
1096-
1097-
mod = LegalizeOps()(Cumsum)
1098-
tvm.ir.assert_structural_equal(mod, Expected)
1099-
1100-
1101-
def test_cumsum_symbolic():
1102-
# fmt: off
1103-
@I.ir_module
1104-
class Cumsum:
1105-
@R.function
1106-
def main(x: R.Tensor(("a", "b", "c"), "float32")):
1107-
gv = R.cumsum(x, axis=1, dtype="int32")
1108-
return gv
1109-
1110-
@I.ir_module
1111-
class Expected:
1112-
@T.prim_func(private=True)
1113-
def cumsum(var_rxplaceholder: T.handle, var_cumsum_generic: T.handle):
1114-
T.func_attr({"tir.noalias": True})
1115-
a, b, c = T.int64(), T.int64(), T.int64()
1116-
rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c), offset_factor=1)
1117-
out_buf = T.match_buffer(var_cumsum_generic, (a, b, c), "int32")
1118-
with T.block("cumsum_generic"):
1119-
for fused in T.parallel(a * c):
1120-
out_buf[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c] = T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c])
1121-
for _k in range(b - T.int64(1)):
1122-
out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c] = out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) % c] + T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c])
1123-
1124-
@R.function
1125-
def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="int32"):
1126-
a = T.int64()
1127-
b = T.int64()
1128-
c = T.int64()
1129-
cls = Expected
1130-
gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((a, b, c), dtype="int32"))
1131-
return gv
1132-
# fmt: on
1133-
1134-
mod = LegalizeOps()(Cumsum)
1135-
tvm.ir.assert_structural_equal(mod, Expected)
1136-
1137-
11381069
if __name__ == "__main__":
11391070
tvm.testing.main()

0 commit comments

Comments
 (0)