Skip to content

Commit ec4ade0

Browse files
Lunderbergyongwww
authored andcommitted
[TVMScript][Relax] Use tir.SizeVar for shape variables
As reported in apache#16877, shape inference performed during a Relax transformation may produce different results than shape inference performed during TVMScript parsing. While Relax transformations call `Analyzer::MarkGlobalNonNegValue` for each shape expression, this is not provided during TVMScript parsing. As a result, output shapes that are conditional on the sign of a variable may produce different results when simplifying. This commit provides a partial resolution for this issue. Where prior to this commit, the TVMScript parser generated a `tir.Var` for each symbolic variable, the TVMScript parser now provides a `tir.SizeVar` for symbolic variables that are defined in contexts that require a non-negative size.
1 parent 046c1ba commit ec4ade0

File tree

3 files changed

+2688
-388
lines changed

3 files changed

+2688
-388
lines changed

python/tvm/script/parser/relax/entry.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Struc
154154
def get_symbolic_vars(self) -> Set[str]:
155155
return {}
156156

157+
def get_symbolic_size_vars(self) -> Set[str]:
158+
return self.get_symbolic_vars()
159+
157160
def asobject(self):
158161
return self.as_struct_info(None)
159162

@@ -176,9 +179,6 @@ class ObjectProxy(StructInfoProxy):
176179
def __init__(self) -> None:
177180
pass
178181

179-
def get_symbolic_vars(self) -> Set[str]:
180-
return set()
181-
182182
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
183183
return ObjectStructInfo()
184184

@@ -341,6 +341,12 @@ def get_symbolic_vars(self) -> Set[str]:
341341
else:
342342
return set().union(*[p.get_symbolic_vars() for p in self.params])
343343

344+
def get_symbolic_size_vars(self) -> Set[str]:
345+
if self.params is None:
346+
return set()
347+
else:
348+
return set().union(*[p.get_symbolic_size_vars() for p in self.params])
349+
344350
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncStructInfo:
345351
if self.ret is None:
346352
ret = None
@@ -394,6 +400,9 @@ def __init__(
394400
def get_symbolic_vars(self) -> Set[str]:
395401
return set().union(*[f.get_symbolic_vars() for f in self.fields])
396402

403+
def get_symbolic_size_vars(self) -> Set[str]:
404+
return set().union(*[f.get_symbolic_size_vars() for f in self.fields])
405+
397406
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TupleStructInfo:
398407
fields = [field.as_struct_info(dict_globals) for field in self.fields]
399408
return TupleStructInfo(fields)
@@ -480,6 +489,13 @@ def get_symbolic_vars(self) -> Set[str]:
480489
else:
481490
return set()
482491

492+
def get_symbolic_size_vars(self) -> Set[str]:
493+
# While variables defined by R.Shape and R.Tensor arguments
494+
# are known to be non-negative, R.Prim arguments may be
495+
# negative. Overriding the default implementation of
496+
# `get_symbolic_size_vars()`
497+
return set()
498+
483499
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
484500
if self.value is None:
485501
return PrimStructInfo(dtype=self.dtype)

0 commit comments

Comments
 (0)