@@ -69,13 +69,15 @@ class IterVar(Var): ...
6969
7070class Buffer :
7171 @overload
72- def __getitem__ (self : Buffer , pos : Sequence [Union [PrimExpr , int ]]) -> PrimExpr : ...
72+ def __getitem__ (self : Buffer , pos : Sequence [Union [PrimExpr , int , slice ]]) -> PrimExpr : ...
7373 @overload
74- def __getitem__ (self : Buffer , pos : Union [PrimExpr , int ]) -> PrimExpr : ...
74+ def __getitem__ (self : Buffer , pos : Union [PrimExpr , int , slice ]) -> PrimExpr : ...
7575 @overload
76- def __setitem__ (self : Buffer , pos : Sequence [Union [PrimExpr , int ]], value : PrimExpr ) -> None : ...
76+ def __setitem__ (
77+ self : Buffer , pos : Sequence [Union [PrimExpr , int , slice ]], value : PrimExpr
78+ ) -> None : ...
7779 @overload
78- def __setitem__ (self : Buffer , pos : Union [PrimExpr , int ], value : PrimExpr ) -> None : ...
80+ def __setitem__ (self : Buffer , pos : Union [PrimExpr , int , slice ], value : PrimExpr ) -> None : ...
7981 @property
8082 def data (self : Buffer ) -> Ptr : ...
8183
@@ -124,35 +126,47 @@ def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
124126def store (
125127 var : Var , index : PrimExpr , value : PrimExpr , predicate : Union [PrimExpr , builtins .bool ] = True
126128) -> None : ...
127- def comm_reducer (lambda_io : Tuple [List [PrimExpr ]], identities : List [PrimExpr ]) -> PrimExpr : ...
129+ def comm_reducer (lambda_io : Callable [[Any , Any ], Any ], identities : List [PrimExpr ]) -> PrimExpr : ...
130+
131+ """
132+ Intrinsics - tvm builtin
133+ """
134+
135+ def tvm_thread_allreduce (
136+ * freduceargs : Union [PrimExpr , builtins .bool , Ptr ], dtype : str
137+ ) -> PrimExpr : ...
128138
129139"""
130140Unary operator
141+ Note that any intrinsics not registered in script.tir.intrin
142+ should add "dtype" as an argument. This is different from their
143+ definition but intentional.
131144"""
132145
133- def exp2 (x : PrimExpr ) -> PrimExpr : ...
134- def exp10 (x : PrimExpr ) -> PrimExpr : ...
135- def erf (x : PrimExpr ) -> PrimExpr : ...
136- def tanh (x : PrimExpr ) -> PrimExpr : ...
137- def sigmoid (x : PrimExpr ) -> PrimExpr : ...
138- def log (x : PrimExpr ) -> PrimExpr : ...
139- def log2 (x : PrimExpr ) -> PrimExpr : ...
140- def log10 (x : PrimExpr ) -> PrimExpr : ...
141- def log1p (x : PrimExpr ) -> PrimExpr : ...
142- def tan (x : PrimExpr ) -> PrimExpr : ...
143- def cos (x : PrimExpr ) -> PrimExpr : ...
144- def cosh (x : PrimExpr ) -> PrimExpr : ...
145- def acos (x : PrimExpr ) -> PrimExpr : ...
146- def acosh (x : PrimExpr ) -> PrimExpr : ...
147- def sin (x : PrimExpr ) -> PrimExpr : ...
148- def sinh (x : PrimExpr ) -> PrimExpr : ...
149- def asin (x : PrimExpr ) -> PrimExpr : ...
150- def asinh (x : PrimExpr ) -> PrimExpr : ...
151- def atan (x : PrimExpr ) -> PrimExpr : ...
152- def atanh (x : PrimExpr ) -> PrimExpr : ...
153- def atan2 (x : PrimExpr ) -> PrimExpr : ...
154- def sqrt (x : PrimExpr ) -> PrimExpr : ...
155- def rsqrt (x : PrimExpr ) -> PrimExpr : ...
146+ def exp (x : PrimExpr , dtype : str ) -> PrimExpr : ...
147+ def exp2 (x : PrimExpr , dtype : str ) -> PrimExpr : ...
148+ def exp10 (x : PrimExpr , dtype : str ) -> PrimExpr : ...
149+ def erf (x : PrimExpr , dtype : str ) -> PrimExpr : ...
150+ def tanh (x : PrimExpr , dtype : str ) -> PrimExpr : ...
151+ def sigmoid (x : PrimExpr , dtype : str ) -> PrimExpr : ...
152+ def log (x : PrimExpr , dtype : str ) -> PrimExpr : ...
153+ def log2 (x : PrimExpr , dtype : str ) -> PrimExpr : ...
154+ def log10 (x : PrimExpr , dtype : str ) -> PrimExpr : ...
155+ def log1p (x : PrimExpr , dtype : str ) -> PrimExpr : ...
156+ def tan (x : PrimExpr , dtype : str ) -> PrimExpr : ...
157+ def cos (x : PrimExpr , dtype : str ) -> PrimExpr : ...
158+ def cosh (x : PrimExpr , dtype : str ) -> PrimExpr : ...
159+ def acos (x : PrimExpr , dtype : str ) -> PrimExpr : ...
160+ def acosh (x : PrimExpr , dtype : str ) -> PrimExpr : ...
161+ def sin (x : PrimExpr , dtype : str ) -> PrimExpr : ...
162+ def sinh (x : PrimExpr , dtype : str ) -> PrimExpr : ...
163+ def asin (x : PrimExpr , dtype : str ) -> PrimExpr : ...
164+ def asinh (x : PrimExpr , dtype : str ) -> PrimExpr : ...
165+ def atan (x : PrimExpr , dtype : str ) -> PrimExpr : ...
166+ def atanh (x : PrimExpr , dtype : str ) -> PrimExpr : ...
167+ def atan2 (x : PrimExpr , dtype : str ) -> PrimExpr : ...
168+ def sqrt (x : PrimExpr , dtype : str ) -> PrimExpr : ...
169+ def rsqrt (x : PrimExpr , dtype : str ) -> PrimExpr : ...
156170
157171"""
158172special_stmt - Buffers
@@ -334,7 +348,7 @@ def for_range(
334348 end : Union [PrimExpr , int ] = None ,
335349 annotations : Optional [Mapping [str , Object ]] = None ,
336350) -> Iterable [IterVar ]: ...
337- def grid (* extents : Union [PrimExpr , int ]) -> Iterable [Tuple [IterVar ]]: ...
351+ def grid (* extents : Union [PrimExpr , int ]) -> Iterable [Sequence [IterVar ]]: ...
338352
339353"""
340354ty - redefine types
0 commit comments