Skip to content

Commit 02e8bbf

Browse files
authored
[Bugfix][TIR] Fix version conflict with typing for Python 3.8.0 (#13744)
I came across this bug under python3.8.0 with error from `typing.get_args()` while trying to run testcases like `tests/python/unittest/test_tir_schedule_set_axis_separator.py::test_set_axis_separator[transform_layout_named]` ``` > if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: E IndexError: tuple index out of range ``` And the root cause here is a difference between python3.8.0 and later version: ```diff get_args(Callable[[], T][int]) == ([], int) """ - if isinstance(tp, _GenericAlias): // python3.8.0 + if isinstance(tp, _GenericAlias) and not tp._special: // python3.8.15 res = tp.__args__ if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: } ``` So I added it back to `python/tvm/tir/schedule/_type_checker.py`
1 parent 60358a1 commit 02e8bbf

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/tvm/tir/schedule/_type_checker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def union(type_: Any) -> Optional[List[type]]: # pylint: disable=missing-functi
9898
@staticmethod
9999
def callable(type_: Any) -> Optional[List[type]]:
100100
if _Subtype._origin(type_) is collections.abc.Callable:
101-
if hasattr(typing, "get_args"):
101+
if hasattr(typing, "get_args") and not type_._special:
102102
subtypes = typing.get_args(type_) # type: ignore
103103
else:
104104
subtypes = type_.__args__

0 commit comments

Comments
 (0)