Skip to content

Commit 0e8107b

Browse files
authored
[TIR][Arith] Implemented padded inverses in IndexMap (#11235)
* [Debug] Error logging in DetectIterMap * [Affine] Allowed PrimExpr argument to NormalizeIterMapToExpr This allows it to be used for any expression containing an `IterMapExpr`, not just expressions whose top-level node is an `IterMapExpr`. * [Affine] Implemented DetectPaddedIterMap The existing DetectIterMap tries to rewrite index expression as a linear combination of split/fused iterators, where the new iterators cover the exact same indices as the original expression. DetectPaddedIterMap relaxes this condition, allowing the new iterators to cover a superset of indices that the initial index expression covered. It uses the minimum amount of padding necessary to represent these transformations, and also a predicate that identifies any padding that has been added. This is a utility function to be used for layout transformations of buffers, in cases where the pre-transformation shape of the buffer does not evenly fit into the post-transformation shape. * [IndexMap] Implemented IndexMap::NonSurjectiveInverse Allow non-surjective transformations, with DetectIterMap used to determine the minimum padding to insert. Returns the inverse function, along with a predicate that identifies padding indices. The predicate is in terms of the transformed variables. * [IndexMap] Exposed methods to python - `IndexMap::Inverse` exposed as `IndexMap.inverse` - `IndexMap::MapShape` exposed as `IndexMap.map_shape` - `IndexMap::NonSurjectiveInverse` exposed as `IndexMap.non_surjective_inverse` * [IndexMap] Extracted _assert_equal_index_map into class method In preparation for adding additional tests for the IndexMap class, which will require this functionality. * [IndexMap] Added unit tests for new behavior * Re-enabled divisibility check in CheckMapping Initially disabled as dynamic shapes resulted in padded lengths whose divisiblity couldn't be proven. Re-enabled along with a simplification rule to resolve it. * Fixed breakage in compute_at primitive * Corrected typos/examples in docstring
1 parent 6c339ea commit 0e8107b

File tree

9 files changed

+958
-150
lines changed

9 files changed

+958
-150
lines changed

include/tvm/arith/iter_affine_map.h

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,73 @@ class IterSumExpr : public IterMapExpr {
285285
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
286286
const PrimExpr& predicate, bool require_bijective,
287287
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
288+
289+
/*! \brief A utility struct for return values from DetectPaddedIterMap
290+
*/
291+
struct PaddedIterMapResult {
292+
// Any errors that occurred while converting the input indices. If
293+
// the array is empty, the conversion was successful.
294+
Array<String> errors;
295+
296+
// The detected pattern if a match exists.
297+
Array<IterSumExpr> indices;
298+
299+
/* \brief Boolean expression indicating if padding was required
300+
*
301+
* `requires_padding` evaluates to true if the returned indices
302+
* contain padding relative to the provided expressions, and false
303+
* otherwise. If `input_iters` contains a variable extent, this
304+
* expression may be in terms of those variables.
305+
*/
306+
PrimExpr requires_padding;
307+
308+
/* \brief Boolean expression indicating if a specific value w
309+
*
310+
* `padding_predicate` evaluates to true for a set of indices that
311+
* are outside the bounds of the provided index iterators, but
312+
* inside the bounds of the returned index iterators. This
313+
* expression is in terms of the variables provided in
314+
* `input_iters`.
315+
*/
316+
PrimExpr padding_predicate;
317+
};
318+
319+
/*!
320+
* \brief Detect if indices can be written as
321+
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
322+
*
323+
* Here y = some-quasi-affine-iter-map(input_iters) and c are
324+
* symbolic constants. The y_i iterators may be padded to fit this
325+
* representation.
326+
*
327+
* We also requires that y_i and y_j to be independent for i != j.
328+
*
329+
* For returned value rv, the following is always true:
330+
* - rv.indices[i]->args.size() <=1: only one iterator per element.
331+
*
332+
* \param indices The indices to detect pattern for.
333+
*
334+
* \param input_iters Map from variable to iterator's range.
335+
*
336+
* \param predicate The predicate constraints on the input iterators
337+
*
338+
* \param require_bijective A boolean flag that indicates whether the
339+
* mapping should be bijective. If true, no padding may be
340+
* introduced.
341+
*
342+
* \param analyzer Analyzer used to get context information.
343+
*
344+
* \param simplify_trivial_iterators If true, iterators with extent of
345+
* 1 will be replaced with a constant value.
346+
*
347+
* \return An instance of PaddedIterMapResult.
348+
*/
349+
PaddedIterMapResult DetectPaddedIterMap(const Array<PrimExpr>& indices,
350+
const Map<Var, Range>& input_iters,
351+
const PrimExpr& predicate, bool require_bijective,
352+
arith::Analyzer* analyzer,
353+
bool simplify_trivial_iterators = true);
354+
288355
/*!
289356
* \brief Use IterVarMap detector to rewrite and simplify the indices
290357
*
@@ -352,11 +419,11 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
352419
bool require_bijective, arith::Analyzer* analyzer);
353420

354421
/*!
355-
* \brief Given an IterMapExpr, transform it to normal PrimExpr.
356-
* \param expr The input IterMapExpr.
422+
* \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr.
423+
* \param expr The input expression, which may contain IterMapExpr.
357424
* \return The corresponding normal PrimExpr.
358425
*/
359-
PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr);
426+
PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr);
360427

361428
} // namespace arith
362429
} // namespace tvm

include/tvm/tir/index_map.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#include <tvm/runtime/object.h>
3232
#include <tvm/tir/var.h>
3333

34+
#include <utility>
35+
3436
namespace tvm {
3537
namespace tir {
3638

@@ -141,12 +143,24 @@ class IndexMap : public ObjectRef {
141143
*
142144
* TODO(Lunderberg): Look into allowing non-bijective
143145
* transformations. If injective, the inverse mapping could still
144-
* be generated with some predicate. If non-injective, could
145-
* simplify the implementation of other optimizations (e.g. double
146-
* buffering as a map `lambda *indices: [buffer_loop%2, *indices]`).
146+
* be generated with some predicate (see NonSurjectiveInverse). If
147+
* non-injective, could simplify the implementation of other
148+
* optimizations (e.g. double buffering as a map `lambda *indices:
149+
* [buffer_loop%2, *indices]`).
147150
*/
148151
IndexMap Inverse(Array<Range> initial_ranges) const;
149152

153+
/*! \brief Generate the inverse mapping.
154+
*
155+
* Determine the inverse, where the output range may contain
156+
* addresses that do not correspond to an address in the input
157+
* range.
158+
*
159+
* \return The inverted index map, along with the predicate for
160+
* which the inverse maps to a valid range.
161+
*/
162+
std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges) const;
163+
150164
TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
151165
};
152166

python/tvm/tir/function.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
# under the License.
1717
"""Function data types."""
1818

19-
from typing import Callable, List, Mapping, Optional, Union
19+
from typing import Callable, List, Mapping, Optional, Union, Tuple
2020
import inspect
2121

22+
import tvm
2223
import tvm._ffi
2324
import tvm.runtime
2425
from tvm.runtime import Object
25-
from tvm.ir import BaseFunc
26+
from tvm.ir import BaseFunc, Range
2627
from .buffer import Buffer
2728
from .expr import Var, PrimExpr
2829
from . import _ffi_api
@@ -296,12 +297,42 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
296297
final_indices = mapping_function(*args)
297298
return IndexMap(args, final_indices)
298299

300+
def is_equivalent_to(self, other_map: "IndexMap") -> bool:
301+
"""Return if the index maps are equivalent.
302+
303+
Parameters
304+
----------
305+
other_map: IndexMap
306+
307+
The IndexMap to which the comparison should be made.
308+
309+
Returns
310+
-------
311+
is_equivalent: bool
312+
313+
True if the two mappings represent the same
314+
transformation, otherwise False
315+
"""
316+
if len(self.initial_indices) != len(other_map.initial_indices):
317+
return False
318+
if len(self.final_indices) != len(other_map.final_indices):
319+
return False
320+
321+
analyzer = tvm.arith.Analyzer()
322+
323+
mapped_other_final_indices = other_map.map_indices(self.initial_indices)
324+
for self_index, other_index in zip(self.final_indices, mapped_other_final_indices):
325+
if not analyzer.can_prove_equal(self_index, other_index):
326+
return False
327+
328+
return True
329+
299330
def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]:
300331
"""Apply the index map to a set of indices
301332
302333
Parameters
303334
----------
304-
indices : List[PriExpr]
335+
indices : List[PrimExpr]
305336
The indices to be mapped
306337
307338
Returns
@@ -310,3 +341,76 @@ def map_indices(self, indices: List[PrimExpr]) -> List[PrimExpr]:
310341
The mapped indices
311342
"""
312343
return _ffi_api.IndexMapMapIndices(self, indices)
344+
345+
def map_shape(self, shape: List[PrimExpr]) -> List[PrimExpr]:
346+
"""Apply the index map to a buffer shape
347+
348+
Parameters
349+
----------
350+
shape : List[PrimExpr]
351+
The buffer shape to be mapped
352+
353+
Returns
354+
-------
355+
result : List[PrimExpr]
356+
The mapped shape
357+
"""
358+
return _ffi_api.IndexMapMapShape(self, shape)
359+
360+
def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap":
361+
"""Return the inverse of the map
362+
363+
Throws an error if the function is not bijective.
364+
365+
Parameters
366+
----------
367+
shape: List[Union[Range,PrimExpr]]
368+
369+
The region over which the inverse should be determined.
370+
Used for validating that the mapping is bijective over
371+
this range.
372+
373+
Returns
374+
-------
375+
inverse : IndexMap
376+
377+
The inverse
378+
"""
379+
380+
shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape]
381+
return _ffi_api.IndexMapInverse(self, shape)
382+
383+
def non_surjective_inverse(
384+
self, shape: List[Union[Range, PrimExpr]]
385+
) -> Tuple["IndexMap", PrimExpr]:
386+
"""Return the inverse of the map
387+
388+
Can be applied to transformations that introduce padding.
389+
390+
Parameters
391+
----------
392+
shape: List[Union[Range,PrimExpr]]
393+
394+
The region over which the inverse should be determined.
395+
Used for determining the predicate.
396+
397+
Returns
398+
-------
399+
result : Tuple[IndexMap, PrimExpr]
400+
401+
The inverse, and a predicate for which the inverse maps to
402+
a valid index in the input range.
403+
404+
Examples
405+
--------
406+
407+
.. code-block:: python
408+
409+
index_map = IndexMap.from_func(lambda i: [i//4, i%4])
410+
inverse_map, predicate = index_map.non_surjective_inverse([14])
411+
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k])
412+
print(predicate) # Prints "(axis0==3) && (axis2 >= 2)"
413+
"""
414+
415+
shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape]
416+
return _ffi_api.IndexMapNonSurjectiveInverse(self, shape)

0 commit comments

Comments
 (0)