Skip to content

Commit 1c15f4d

Browse files
adjust related interfaces for IterMapLevel
1 parent 4edbb14 commit 1c15f4d

File tree

8 files changed

+397
-460
lines changed

8 files changed

+397
-460
lines changed

include/tvm/arith/iter_affine_map.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,10 @@ class IterSumExpr : public IterMapExpr {
263263
enum IterMapLevel {
264264
// Require the mapping to be bijective.
265265
Bijective = 0,
266-
// Require the mapping to be subjective.
266+
// Require the mapping to be surjective.
267267
Surjective = 1,
268-
// Require the mapping to be injective.
269-
Injective = 2
268+
// No mapping safety check.
269+
NoCheck = 3
270270
};
271271

272272
/*!
@@ -327,7 +327,7 @@ class IterMapResult : public ObjectRef {
327327
* \param indices The indices to detect pattern for.
328328
* \param input_iters Map from variable to iterator's range.
329329
* \param predicate The predicate constraints on the input iterators
330-
* \param check_level The iter mapping check level.
330+
* \param check_level The iter mapping checking level.
331331
* \param analyzer Analyzer used to get context information.
332332
* \param simplify_trivial_iterators If true, iterators with extent of
333333
* 1 will be replaced with a constant value.
@@ -345,12 +345,12 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
345345
* \param indices The indices to detect pattern for.
346346
* \param input_iters Map from variable to iterator's range.
347347
* \param input_pred The predicate constraints on the input iterators
348-
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
348+
* \param check_level The iter mapping checking level.
349349
*
350350
* \return The indices after rewrite
351351
*/
352352
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
353-
const PrimExpr& input_pred, bool require_bijective);
353+
const PrimExpr& input_pred, IterMapLevel check_level);
354354

355355
/*!
356356
* \brief Apply the inverse of the affine transformation to the outputs.
@@ -390,7 +390,7 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
390390
* \param input_iters Map from variable to iterator's range.
391391
* \param sub_iters Iterators of subspace.
392392
* \param predicate The predicate constraints on the input iterators
393-
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
393+
* \param check_level The iter mapping checking level.
394394
* \param analyzer Analyzer used to get context information.
395395
*
396396
* \return The result list has length len(bindings) + 1
@@ -403,7 +403,7 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
403403
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
404404
const Map<Var, Range>& input_iters,
405405
const Array<Var>& sub_iters, const PrimExpr& predicate,
406-
bool require_bijective, arith::Analyzer* analyzer);
406+
IterMapLevel check_level, arith::Analyzer* analyzer);
407407

408408
/*!
409409
* \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr.

python/tvm/arith/iter_affine_map.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
""" Iterator (quasi)affine mapping patterns."""
18+
from enum import IntEnum
1819
import tvm._ffi
1920
from tvm.runtime import Object
2021
from tvm.ir import PrimExpr
@@ -88,11 +89,35 @@ def __init__(self, args, base):
8889
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)
8990

9091

92+
class IterMapLevel(IntEnum):
93+
"""Possible kinds of iter mapping check level."""
94+
95+
Bijective = 0
96+
Surjective = 1
97+
NoCheck = 3
98+
99+
@staticmethod
100+
def from_str(name: str):
101+
"""Helper to create level enum from string"""
102+
if name is None:
103+
return IterMapLevel.NoCheck
104+
name = name.lower()
105+
if name == "bijective":
106+
check_level = IterMapLevel.Bijective
107+
elif name == "surjective":
108+
check_level = IterMapLevel.Surjective
109+
elif name == "nocheck":
110+
check_level = IterMapLevel.NoCheck
111+
else:
112+
raise ValueError(f"Unknown check level {name}")
113+
return check_level
114+
115+
91116
def detect_iter_map(
92117
indices,
93118
input_iters,
94119
predicate=True,
95-
require_bijective=False,
120+
check_level=IterMapLevel.Surjective,
96121
simplify_trivial_iterators=True,
97122
):
98123
"""Detect if indices can be written as mapped iters from input iters
@@ -108,8 +133,8 @@ def detect_iter_map(
108133
predicate : PrimExpr
109134
The predicate constraints on the input iterators
110135
111-
require_bijective : bool
112-
A boolean flag that indicates whether the mapping should be bijective
136+
check_level : Union[str, IterMapLevel]
137+
Checking level of iteration mapping
113138
114139
simplify_trivial_iterators: bool
115140
If true, iterators with extent of 1 will be replaced with a
@@ -122,9 +147,13 @@ def detect_iter_map(
122147
The result's .indices is empty array if no match can be found.
123148
124149
"""
150+
if isinstance(check_level, str):
151+
check_level = IterMapLevel.from_str(check_level)
152+
elif check_level is None:
153+
check_level = IterMapLevel.NoCheck
125154
return _ffi_api.DetectIterMap(
126-
indices, input_iters, predicate, require_bijective, simplify_trivial_iterators
127-
).indices
155+
indices, input_iters, predicate, check_level, simplify_trivial_iterators
156+
)
128157

129158

130159
def normalize_iter_map_to_expr(expr):
@@ -143,7 +172,9 @@ def normalize_iter_map_to_expr(expr):
143172
return _ffi_api.NormalizeIterMapToExpr(expr)
144173

145174

146-
def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False):
175+
def subspace_divide(
176+
bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective
177+
):
147178
"""Detect if bindings can be written as
148179
[a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
149180
where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
@@ -172,8 +203,8 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi
172203
predicate : PrimExpr
173204
The predicate constraints on the input iterators
174205
175-
require_bijective : bool
176-
A boolean flag that indicates whether the bindings should be bijective
206+
check_level : Union[str, IterMapLevel]
207+
Checking level of iteration mapping
177208
178209
Returns
179210
-------
@@ -185,7 +216,9 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi
185216
len(bindings): the predicate of outer space and inner space
186217
Empty array if no match can be found.
187218
"""
188-
return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective)
219+
if isinstance(check_level, str):
220+
check_level = IterMapLevel.from_str(check_level)
221+
return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, check_level)
189222

190223

191224
def inverse_affine_iter_map(iter_map, outputs):

0 commit comments

Comments
 (0)