Skip to content

Commit 57a0ad5

Browse files
Lunderbergyongwww
authored andcommitted
[TVMScript][Relax] Use tir.SizeVar for shape variables
As reported in apache#16877, shape inference performed during a Relax transformation may produce different results than shape inference performed during TVMScript parsing. While Relax transformations call `Analyzer::MarkGlobalNonNegValue` for each shape expression, this is not provided during TVMScript parsing. As a result, output shapes that are conditional on the sign of a variable may produce different results when simplifying. This commit provides a partial resolution for this issue. Where prior to this commit, the TVMScript parser generated a `tir.Var` for each symbolic variable, the TVMScript parser now provides a `tir.SizeVar` for symbolic variables that are defined in contexts that require a non-negative size.
1 parent 046c1ba commit 57a0ad5

File tree

3 files changed

+167
-4
lines changed

3 files changed

+167
-4
lines changed

python/tvm/script/parser/relax/entry.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Struc
154154
def get_symbolic_vars(self) -> Set[str]:
155155
return {}
156156

157+
def get_symbolic_size_vars(self) -> Set[str]:
158+
return self.get_symbolic_vars()
159+
157160
def asobject(self):
158161
return self.as_struct_info(None)
159162

@@ -176,9 +179,6 @@ class ObjectProxy(StructInfoProxy):
176179
def __init__(self) -> None:
177180
pass
178181

179-
def get_symbolic_vars(self) -> Set[str]:
180-
return set()
181-
182182
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
183183
return ObjectStructInfo()
184184

@@ -341,6 +341,12 @@ def get_symbolic_vars(self) -> Set[str]:
341341
else:
342342
return set().union(*[p.get_symbolic_vars() for p in self.params])
343343

344+
def get_symbolic_size_vars(self) -> Set[str]:
345+
if self.params is None:
346+
return set()
347+
else:
348+
return set().union(*[p.get_symbolic_size_vars() for p in self.params])
349+
344350
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo:
345351
if self.ret is None:
346352
ret = None
@@ -394,6 +400,9 @@ def __init__(
394400
def get_symbolic_vars(self) -> Set[str]:
395401
return set().union(*[f.get_symbolic_vars() for f in self.fields])
396402

403+
def get_symbolic_size_vars(self) -> Set[str]:
404+
return set().union(*[f.get_symbolic_size_vars() for f in self.fields])
405+
397406
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TupleStructInfo:
398407
fields = [field.as_struct_info(dict_globals) for field in self.fields]
399408
return TupleStructInfo(fields)
@@ -480,6 +489,13 @@ def get_symbolic_vars(self) -> Set[str]:
480489
else:
481490
return set()
482491

492+
def get_symbolic_size_vars(self) -> Set[str]:
493+
# While variables defined by R.Shape and R.Tensor arguments
494+
# are known to be non-negative, R.Prim arguments may be
495+
# negative. Overriding the default implementation of
496+
# `get_symbolic_size_vars()`
497+
return set()
498+
483499
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
484500
if self.value is None:
485501
return PrimStructInfo(dtype=self.dtype)

python/tvm/script/parser/relax/parser.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def collect_symbolic_var_from_prelude(
177177
def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None:
178178
# Collect symbolic vars from parameters
179179
symbolic_vars = {}
180+
symbolic_size_vars = set()
180181
for arg in node.args.args:
181182
if arg.annotation is None:
182183
self.report_error(arg, "Type annotation is required for function parameters.")
@@ -186,12 +187,20 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non
186187
if var_name not in symbolic_vars:
187188
symbolic_vars[var_name] = tir.Var(var_name, "int64")
188189

190+
symbolic_size_vars.update(param_sinfo_proxy.get_symbolic_size_vars())
191+
192+
assert len(symbolic_size_vars - symbolic_vars) == 0, (
193+
"Internal error: "
194+
"All collected tir.SizeVar names must also appear in the list of tir.Var names"
195+
)
196+
189197
# Update symbolic vars based on
190198
symbolic_vars = collect_symbolic_var_from_prelude(self, node, symbolic_vars)
191199

192200
# Define symbolic vars to the current var_table frame
193201
for var_name, var in symbolic_vars.items():
194-
self.var_table.add(var_name, var, allow_shadowing=False)
202+
var_cls = tir.SizeVar if var_name in symbolic_size_vars else tir.Var
203+
self.var_table.add(var_name, var_cls(var_name, "int64"), allow_shadowing=False)
195204

196205

197206
@dispatch.register(token="relax", type_name="FunctionDef")

tests/python/relax/test_tvmscript_parser.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2440,6 +2440,144 @@ def return_inside_dataflow(A: R.Tensor([16], "float16")):
24402440

24412441
tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow)
24422442

2443+
def test_symbolic_shape_variables_are_size_var():
2444+
"""Symbolic variables inferred from shapes are SizeVar
2445+
The indices in `R.strided_slice` follow Python's conventions for
2446+
negative indices. Absent any additional information, a slice
2447+
`arr[0:i]` would either have length `i` when `i >= 0`, or length
2448+
`len(arr) + i` when `i < 0`.
2449+
In this case, though, the dynamic `extent` variable is known to be
2450+
non-negative, because negative values may not be used as the
2451+
dimensions of `R.Tensor` or `R.Shape`. Because Relax struct
2452+
inference is performed while TVMScript is being parsed, this
2453+
constraint must be exposed during TVMScript parsing in order to
2454+
correctly infer the resulting StructInfo.
2455+
"""
2456+
2457+
@R.function(private=True)
2458+
def inferred_sinfo(A: R.Tensor(["extent"])):
2459+
extent = T.int64()
2460+
output = R.strided_slice(A, [0], [0], [extent])
2461+
return output
2462+
2463+
@R.function(private=True)
2464+
def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent"]):
2465+
extent = T.int64()
2466+
output: R.Tensor([extent]) = R.strided_slice(A, [0], [0], [extent])
2467+
return output
2468+
2469+
tvm.ir.assert_structural_equal(inferred_sinfo, expected)
2470+
2471+
assert isinstance(inferred_sinfo.params[0].struct_info.shape[0], tir.SizeVar)
2472+
2473+
2474+
def test_symbolic_variables_from_prim_value_may_be_negative():
2475+
"""Symbolic variables inferred from R.Prim are Var
2476+
Not all symbolic variables represent shapes. While a
2477+
`relax::PrimValue` can be the source of definition for a TIR
2478+
variable, a `relax::PrimValue` may not represent a shape, and may
2479+
be negative.
2480+
This test is similar to
2481+
`test_symbolic_shape_variables_are_size_var`, except that the
2482+
`extent` variable is defined by a `R.Prim` argument, and not by a
2483+
`R.Tensor` argument. As a result, we do not know whether `extent`
2484+
is negative, and cannot simplify expressions that depend on
2485+
`extent<0`.
2486+
"""
2487+
2488+
@R.function(private=True)
2489+
def inferred_sinfo(A: R.Tensor([16]), _: R.Prim(value="extent")):
2490+
extent = T.int64()
2491+
output = R.strided_slice(A, [0], [0], [extent])
2492+
return output
2493+
2494+
@R.function(private=True)
2495+
def expected(A: R.Tensor([16]), _: R.Prim(value="extent")):
2496+
extent = T.int64()
2497+
output: R.Tensor(
2498+
[T.min(T.max(T.if_then_else(extent < 0, extent + 16, extent), 0), 16)]
2499+
) = R.strided_slice(A, [0], [0], [extent])
2500+
return output
2501+
2502+
tvm.ir.assert_structural_equal(inferred_sinfo, expected)
2503+
2504+
assert not isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar)
2505+
2506+
2507+
def test_other_arguments_may_cause_prim_value_to_define_size_var():
2508+
"""Other arguments may cause R.Prim to hold SizeVar
2509+
This test is similar to
2510+
`test_symbolic_variables_from_prim_value_may_be_negative`, except
2511+
that `extent` also appears in a `R.Shape`. While the
2512+
`R.Prim(value="extent")` occurs first in the parameter list, and
2513+
is the source of definition, the presence of `extent` in `R.Shape`
2514+
parameter shows that it is a `SizeVar`.
2515+
"""
2516+
2517+
@R.function(private=True)
2518+
def inferred_sinfo(
2519+
A: R.Tensor([16]),
2520+
_prim: R.Prim(value="extent"),
2521+
_shape: R.Shape(
2522+
["extent"],
2523+
),
2524+
):
2525+
extent = T.int64()
2526+
output = R.strided_slice(A, [0], [0], [extent])
2527+
return output
2528+
2529+
@R.function(private=True)
2530+
def expected(
2531+
A: R.Tensor([16]),
2532+
_prim: R.Prim(value="extent"),
2533+
_shape: R.Shape(["extent"]),
2534+
):
2535+
extent = T.int64()
2536+
output: R.Tensor([T.min(extent, 16)]) = R.strided_slice(A, [0], [0], [extent])
2537+
return output
2538+
2539+
tvm.ir.assert_structural_equal(inferred_sinfo, expected)
2540+
2541+
assert isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar)
2542+
2543+
2544+
@pytest.mark.xfail(reason="Bug: Implicit bounds not provided when parsing")
2545+
def test_known_positive_expressions():
2546+
"""Expressions may be known as non-negative
2547+
The variable `N` is not defined as a shape variable, and may be
2548+
either positive or negative. However, the expression `N+16` is
2549+
used as the shape of a tensor, and is therefore known not to be
2550+
negative. Later use of the expression `N+16 < 0` may therefore be
2551+
simplified.
2552+
This test is currently marked as failing. When using
2553+
`relax::BlockBuilder::VisitWithNewScope` is provided with
2554+
parameters, it can mark shape expressions as non-negative, in
2555+
addition to individual variables. However, this is not currently
2556+
used for TVMScript parsing.
2557+
"""
2558+
2559+
@R.function(private=True)
2560+
def inferred_sinfo(
2561+
A: R.Tensor(["N + 16"]),
2562+
_: R.Prim(value="N"),
2563+
):
2564+
N = T.int64()
2565+
output = R.strided_slice(A, [0], [0], [N + 16])
2566+
return output
2567+
2568+
@R.function(private=True)
2569+
def expected(
2570+
A: R.Tensor(["N + 16"]),
2571+
_: R.Prim(value="N"),
2572+
):
2573+
N = T.int64()
2574+
output: R.Tensor([N + 16]) = R.strided_slice(A, [0], [0], [N + 16])
2575+
return output
2576+
2577+
tvm.ir.assert_structural_equal(inferred_sinfo, expected)
2578+
2579+
assert not isinstance(inferred_sinfo.params[1].struct_info.value, tir.SizeVar)
2580+
24432581

24442582
if __name__ == "__main__":
24452583
tvm.testing.main()

0 commit comments

Comments
 (0)