Skip to content

Commit cdca84a

Browse files
author
Yuanjing Shi
authored
[TVMScript][Fix] Add type hints for more uncovered cases (#9505)
* add support for prevously uncovered cases * remove PrimExpr import * add exp test and mypy ignore * disable ling too long * resolve long line * nit * add dtype to unary ops
1 parent d061d7f commit cdca84a

File tree

2 files changed

+139
-29
lines changed

2 files changed

+139
-29
lines changed

python/tvm/script/tir/__init__.pyi

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,15 @@ class IterVar(Var): ...
6969

7070
class Buffer:
7171
@overload
72-
def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ...
72+
def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ...
7373
@overload
74-
def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
74+
def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
7575
@overload
76-
def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
76+
def __setitem__(
77+
self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr
78+
) -> None: ...
7779
@overload
78-
def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ...
80+
def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
7981
@property
8082
def data(self: Buffer) -> Ptr: ...
8183

@@ -124,35 +126,47 @@ def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
124126
def store(
125127
var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
126128
) -> None: ...
127-
def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
129+
def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ...
130+
131+
"""
132+
Intrinsics - tvm builtin
133+
"""
134+
135+
def tvm_thread_allreduce(
136+
*freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str
137+
) -> PrimExpr: ...
128138

129139
"""
130140
Unary operator
141+
Note that any intrinsics not registered in script.tir.intrin
142+
should add "dtype" as an argument. This is different from their
143+
definition but intentional.
131144
"""
132145

133-
def exp2(x: PrimExpr) -> PrimExpr: ...
134-
def exp10(x: PrimExpr) -> PrimExpr: ...
135-
def erf(x: PrimExpr) -> PrimExpr: ...
136-
def tanh(x: PrimExpr) -> PrimExpr: ...
137-
def sigmoid(x: PrimExpr) -> PrimExpr: ...
138-
def log(x: PrimExpr) -> PrimExpr: ...
139-
def log2(x: PrimExpr) -> PrimExpr: ...
140-
def log10(x: PrimExpr) -> PrimExpr: ...
141-
def log1p(x: PrimExpr) -> PrimExpr: ...
142-
def tan(x: PrimExpr) -> PrimExpr: ...
143-
def cos(x: PrimExpr) -> PrimExpr: ...
144-
def cosh(x: PrimExpr) -> PrimExpr: ...
145-
def acos(x: PrimExpr) -> PrimExpr: ...
146-
def acosh(x: PrimExpr) -> PrimExpr: ...
147-
def sin(x: PrimExpr) -> PrimExpr: ...
148-
def sinh(x: PrimExpr) -> PrimExpr: ...
149-
def asin(x: PrimExpr) -> PrimExpr: ...
150-
def asinh(x: PrimExpr) -> PrimExpr: ...
151-
def atan(x: PrimExpr) -> PrimExpr: ...
152-
def atanh(x: PrimExpr) -> PrimExpr: ...
153-
def atan2(x: PrimExpr) -> PrimExpr: ...
154-
def sqrt(x: PrimExpr) -> PrimExpr: ...
155-
def rsqrt(x: PrimExpr) -> PrimExpr: ...
146+
def exp(x: PrimExpr, dtype: str) -> PrimExpr: ...
147+
def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ...
148+
def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ...
149+
def erf(x: PrimExpr, dtype: str) -> PrimExpr: ...
150+
def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
151+
def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ...
152+
def log(x: PrimExpr, dtype: str) -> PrimExpr: ...
153+
def log2(x: PrimExpr, dtype: str) -> PrimExpr: ...
154+
def log10(x: PrimExpr, dtype: str) -> PrimExpr: ...
155+
def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ...
156+
def tan(x: PrimExpr, dtype: str) -> PrimExpr: ...
157+
def cos(x: PrimExpr, dtype: str) -> PrimExpr: ...
158+
def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
159+
def acos(x: PrimExpr, dtype: str) -> PrimExpr: ...
160+
def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
161+
def sin(x: PrimExpr, dtype: str) -> PrimExpr: ...
162+
def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
163+
def asin(x: PrimExpr, dtype: str) -> PrimExpr: ...
164+
def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
165+
def atan(x: PrimExpr, dtype: str) -> PrimExpr: ...
166+
def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
167+
def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ...
168+
def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
169+
def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
156170

157171
"""
158172
special_stmt - Buffers
@@ -334,7 +348,7 @@ def for_range(
334348
end: Union[PrimExpr, int] = None,
335349
annotations: Optional[Mapping[str, Object]] = None,
336350
) -> Iterable[IterVar]: ...
337-
def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
351+
def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ...
338352

339353
"""
340354
ty - redefine types

tests/python/unittest/test_tvmscript_type.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,102 @@ def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
8181
)
8282

8383

84+
"""
85+
This test case is added to test T.grid
86+
"""
87+
88+
89+
@T.prim_func
90+
def loop_split(a: T.handle, b: T.handle) -> None:
91+
A = T.match_buffer(a, [128, 128], dtype="float32")
92+
B = T.match_buffer(b, [128], dtype="float32")
93+
for i, ko in T.grid(128, 4):
94+
for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
95+
with T.block("B"):
96+
vi = T.axis.S(128, i)
97+
vk = T.axis.R(128, ko * 32 + ki)
98+
T.reads([B[vi], A[vi, vk]])
99+
T.writes([B[vi]])
100+
with T.init():
101+
B[vi] = T.float32(0)
102+
B[vi] = B[vi] + A[vi, vk]
103+
104+
105+
"""
106+
This test case is added to test T.comm_reducer, T.reinterpret, T.tvm_thread_allreduce
107+
"""
108+
109+
110+
@T.prim_func
111+
def lowered_loop_split(a: T.handle, b: T.handle) -> None:
112+
A = T.match_buffer(a, [128, 128], dtype="float32")
113+
B = T.match_buffer(b, [128], dtype="float32")
114+
reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
115+
normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
116+
for i in T.serial(0, 128):
117+
for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
118+
normal_reduce_temp0[0] = T.float32(0)
119+
for ko in T.serial(0, 4):
120+
with T.block("B_normal_reduction"):
121+
vi = T.axis.S(128, i)
122+
vk = T.axis.R(128, ko * 32 + ki)
123+
T.reads([A[vi, vk], normal_reduce_temp0[0]])
124+
T.writes([normal_reduce_temp0[0]])
125+
normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
126+
with T.block("B_cross_thread_reduction"):
127+
T.reads([normal_reduce_temp0[0]])
128+
T.writes([reduce_temp0[0]])
129+
T.attr(
130+
T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
131+
"reduce_scope",
132+
T.reinterpret(T.uint64(0), dtype="handle"),
133+
)
134+
T.evaluate(
135+
T.tvm_thread_allreduce(
136+
T.uint32(1),
137+
normal_reduce_temp0[0],
138+
True,
139+
reduce_temp0.data,
140+
ki,
141+
dtype="handle",
142+
)
143+
)
144+
with T.block("B_write_back"):
145+
vi = T.axis.S(128, i)
146+
T.reads([reduce_temp0[0]])
147+
T.writes([B[vi]])
148+
B[vi] = reduce_temp0[0]
149+
150+
151+
"""
152+
This test case is added to test T.Buffer with slice as argument and T.exp
153+
"""
154+
155+
156+
@T.prim_func
157+
def different_access_indices(a: T.handle, b: T.handle) -> None:
158+
A = T.match_buffer(a, [128, 128, 128], dtype="float32")
159+
B = T.match_buffer(b, [128, 128], dtype="float32")
160+
for i, j in T.grid(128, 128):
161+
for k in T.thread_binding(0, 128, thread="threadIdx.x"):
162+
with T.block("B"):
163+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
164+
T.reads([B[vi, vj], A[vi, vj, vk]])
165+
T.writes(
166+
[
167+
B[
168+
T.min(vj, vi) : T.min(vj, vi) # type: ignore[misc]
169+
+ (T.max(vj, vi) + 1 - T.min(vj, vi)),
170+
T.min(vi, vj) : T.min(vi, vj) # type: ignore[misc]
171+
+ (T.max(vi, vj) + 1 - T.min(vi, vj)),
172+
]
173+
]
174+
)
175+
with T.init():
176+
B[vj, vi] = T.exp(B[vj, vi], dtype="float32")
177+
B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
178+
179+
84180
# Not running any test as we only want to type-check here
85181
if __name__ == "__main__":
86182
pass

0 commit comments

Comments
 (0)