1515# specific language governing permissions and limitations
1616# under the License.
1717""" Iterator (quasi)affine mapping patterns."""
18+ from enum import IntEnum
1819import tvm ._ffi
1920from tvm .runtime import Object
2021from tvm .ir import PrimExpr
@@ -88,11 +89,35 @@ def __init__(self, args, base):
8889 self .__init_handle_by_constructor__ (_ffi_api .IterSumExpr , args , base )
8990
9091
92+ class IterMapLevel (IntEnum ):
93+ """Possible kinds of iter mapping check level."""
94+
95+ Bijective = 0
96+ Surjective = 1
97+ NoCheck = 3
98+
99+ @staticmethod
100+ def from_str (name : str ):
101+ """Helper to create level enum from string"""
102+ if name is None :
103+ return IterMapLevel .NoCheck
104+ name = name .lower ()
105+ if name == "bijective" :
106+ check_level = IterMapLevel .Bijective
107+ elif name == "surjective" :
108+ check_level = IterMapLevel .Surjective
109+ elif name == "nocheck" :
110+ check_level = IterMapLevel .NoCheck
111+ else :
112+ raise ValueError (f"Unknown check level { name } " )
113+ return check_level
114+
115+
91116def detect_iter_map (
92117 indices ,
93118 input_iters ,
94119 predicate = True ,
95- require_bijective = False ,
120+ check_level = IterMapLevel . Surjective ,
96121 simplify_trivial_iterators = True ,
97122):
98123 """Detect if indices can be written as mapped iters from input iters
@@ -108,8 +133,8 @@ def detect_iter_map(
108133 predicate : PrimExpr
109134 The predicate constraints on the input iterators
110135
111- require_bijective : bool
112- A boolean flag that indicates whether the mapping should be bijective
136+ check_level : Union[str, IterMapLevel]
137+ Checking level of iteration mapping
113138
114139 simplify_trivial_iterators: bool
115140 If true, iterators with extent of 1 will be replaced with a
@@ -122,9 +147,13 @@ def detect_iter_map(
122147 The result's .indices is empty array if no match can be found.
123148
124149 """
150+ if isinstance (check_level , str ):
151+ check_level = IterMapLevel .from_str (check_level )
152+ elif check_level is None :
153+ check_level = IterMapLevel .NoCheck
125154 return _ffi_api .DetectIterMap (
126- indices , input_iters , predicate , require_bijective , simplify_trivial_iterators
127- ). indices
155+ indices , input_iters , predicate , check_level , simplify_trivial_iterators
156+ )
128157
129158
130159def normalize_iter_map_to_expr (expr ):
@@ -143,7 +172,9 @@ def normalize_iter_map_to_expr(expr):
143172 return _ffi_api .NormalizeIterMapToExpr (expr )
144173
145174
146- def subspace_divide (bindings , input_iters , sub_iters , predicate = True , require_bijective = False ):
175+ def subspace_divide (
176+ bindings , input_iters , sub_iters , predicate = True , check_level = IterMapLevel .Surjective
177+ ):
147178 """Detect if bindings can be written as
148179 [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
149180 where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
@@ -172,8 +203,8 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi
172203 predicate : PrimExpr
173204 The predicate constraints on the input iterators
174205
175- require_bijective : bool
176- A boolean flag that indicates whether the bindings should be bijective
206+ check_level : Union[str, IterMapLevel]
207+ Checking level of iteration mapping
177208
178209 Returns
179210 -------
@@ -185,7 +216,9 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi
185216 len(bindings): the predicate of outer space and inner space
186217 Empty array if no match can be found.
187218 """
188- return _ffi_api .SubspaceDivide (bindings , input_iters , sub_iters , predicate , require_bijective )
219+ if isinstance (check_level , str ):
220+ check_level = IterMapLevel .from_str (check_level )
221+ return _ffi_api .SubspaceDivide (bindings , input_iters , sub_iters , predicate , check_level )
189222
190223
191224def inverse_affine_iter_map (iter_map , outputs ):
0 commit comments