Skip to content

Commit d200c79

Browse files
committed
Address comments
Change-Id: I506bb622065e516b4e7111d461e3dac6ee821889
1 parent aef55bf commit d200c79

File tree

4 files changed

+32
-7
lines changed

4 files changed

+32
-7
lines changed

src/arith/scalable_expression.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,15 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
6666
}
6767
}
6868

69+
bool IsComparison(const PrimExpr& expr) {
70+
return expr->IsInstance<tir::LENode>() || expr->IsInstance<tir::LTNode>() ||
71+
expr->IsInstance<tir::GENode>() || expr->IsInstance<tir::GTNode>() ||
72+
expr->IsInstance<tir::EQNode>() || expr->IsInstance<tir::NENode>();
73+
}
74+
6975
bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
7076
const std::vector<unsigned int>& vscale_values) {
77+
ICHECK(IsComparison(expr)) << "Expected comparison but got: " << expr;
7178
bool can_prove_expr = true;
7279
for (const unsigned int vscale_value : vscale_values) {
7380
PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value);

src/arith/scalable_expression.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);
6565
of vscale.
6666
* \param analyzer An analyzer instance.
6767
* \param expr The expression to try to prove.
68+
* \param vscale_values A list of values to substitute vscale with.
6869
* \return Whether or not the expression can be proven with this technique.
6970
*/
7071
bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,

tests/python/arith/test_arith_simplify.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tvm
2121
import tvm.testing
2222
from tvm import tir
23+
from tvm.script import tir as T
2324

2425

2526
def test_simplify_reshape_flattened_index():
@@ -56,14 +57,20 @@ def test_simplify_symbolic_comparison():
5657
assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, PS.SYMBOLIC_BOUND)
5758

5859

59-
def test_simplify_vscale_comparison_with_sve_target():
60+
@pytest.mark.parametrize(
61+
"expression",
62+
[
63+
T.vscale() * 32 < T.vscale() * 64,
64+
T.vscale() * 2 * (T.vscale() * 2) >= T.vscale() * 4,
65+
(T.vscale() * 4 + 114) // (T.vscale() * 4) * (T.vscale() * 4) >= 115,
66+
64 % T.vscale() <= T.vscale(),
67+
],
68+
)
69+
def test_simplify_vscale_comparison_with_sve_target(expression):
6070
ana = tvm.arith.Analyzer()
61-
vs = tvm.tir.vscale()
6271

6372
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
64-
assert ana.can_prove(vs * 32 < vs * 64)
65-
assert ana.can_prove(vs * 2 * (vs * 2) >= vs * 4)
66-
assert ana.can_prove((vs * 4 + 114) // (vs * 4) * (vs * 4) >= 115)
73+
assert ana.can_prove(expression)
6774

6875

6976
def test_simplify_vscale_comparison_without_sve_target(capfd):
@@ -83,6 +90,16 @@ def test_simplify_vscale_comparison_without_sve_target(capfd):
8390
assert warning_msg in capture
8491

8592

93+
def test_simplify_vscale_non_comparison():
94+
ana = tvm.arith.Analyzer()
95+
vs = tvm.tir.vscale()
96+
97+
err_msg = r".*Expected comparison but got: T.vscale\(\) \* 4"
98+
with pytest.raises(tvm.TVMError, match=err_msg):
99+
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
100+
ana.can_prove(vs * 4)
101+
102+
86103
def test_regression_simplify_inf_recursion():
87104
ana = tvm.arith.Analyzer()
88105
cond = tir.Var("cond", "int32")

tests/python/tir-schedule/test_tir_schedule_split_fuse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ def test_sve_scalable_split_assume_exact_multiple():
697697
If the schedule writer knows the extent of the loop to be split will always
698698
be a multiple of vscale, they may use `disable_predication=True` to ensure
699699
a predicate is not created. This can be used to ensure predication is not
700-
inserted where current analysis is not powerful enough to recognise this.
700+
inserted.
701701
"""
702702

703703
@T.prim_func
@@ -761,7 +761,7 @@ def after(a: T.handle):
761761
tvm.ir.assert_structural_equal(sch.mod["main"], after)
762762

763763

764-
def test_default_scalable_split(capfd):
764+
def test_unsupported_target_scalable_split(capfd):
765765
@T.prim_func
766766
def before(a: T.handle):
767767
A = T.match_buffer(a, (128,), "float32")

0 commit comments

Comments
 (0)