@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616include "mlir/Interfaces/InferTypeOpInterface.td"
1717include "mlir/Interfaces/SideEffectInterfaces.td"
18+ include "mlir/Interfaces/ControlFlowInterfaces.td"
19+ include "mlir/Interfaces/LoopLikeInterface.td"
1820
1921//===----------------------------------------------------------------------===//
2022// Base class.
@@ -1277,7 +1279,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
12771279
12781280def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
12791281 ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1280- "ForeachOp"]>]>,
1282+ "ForeachOp", "IterateOp" ]>]>,
12811283 Arguments<(ins Variadic<AnyType>:$results)> {
12821284 let summary = "Yield from sparse_tensor set-like operations";
12831285 let description = [{
@@ -1490,6 +1492,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14901492 " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
14911493}
14921494
1495+ def IterateOp : SparseTensor_Op<"iterate",
1496+ [RecursiveMemoryEffects, RecursivelySpeculatable,
1497+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
1498+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1499+ "getYieldedValuesMutable"]>,
1500+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
1501+ ["getEntrySuccessorOperands"]>,
1502+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1503+
1504+ let arguments = (ins AnySparseIterSpace:$iterSpace,
1505+ Variadic<AnyType>:$initArgs,
1506+ LevelSetAttr:$crdUsedLvls);
1507+ let results = (outs Variadic<AnyType>:$results);
1508+ let regions = (region SizedRegion<1>:$region);
1509+
1510+ let summary = "Iterate over a sparse iteration space";
1511+ let description = [{
1512+ The `sparse_tensor.iterate` operations represents a loop over the
1513+ provided iteration space extracted from a specific sparse tensor.
1514+ The operation defines an SSA value for a sparse iterator that points
1515+ to the current stored element in the sparse tensor and SSA values
1516+ for coordinates of the stored element. The coordinates are always
1517+ converted to `index` type despite of the underlying sparse tensor
1518+ storage. When coordinates are not used, the SSA values can be skipped
1519+ by `_` symbols, which usually leads to simpler generated code after
1520+ sparsification. For example:
1521+
1522+ ```mlir
1523+ // The coordinate for level 0 is not used when iterating over a 2-D
1524+ // iteration space.
1525+ %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1526+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1527+ ```
1528+
1529+ `sparse_tensor.iterate` can also operate on loop-carried variables
1530+ and returns the final values after loop termination.
1531+ The initial values of the variables are passed as additional SSA operands
1532+ to the iterator SSA value and used coordinate SSA values mentioned
1533+ above. The operation region has an argument for the iterator, variadic
1534+ arguments for specified (used) coordiates and followed by one argument
1535+ for each loop-carried variable, representing the value of the variable
1536+ at the current iteration.
1537+ The body region must contain exactly one block that terminates with
1538+ `sparse_tensor.yield`.
1539+
1540+ `sparse_tensor.iterate` results hold the final values after the last
1541+ iteration. If the `sparse_tensor.iterate` defines any values, a yield
1542+ must be explicitly present.
1543+ The number and types of the `sparse_tensor.iterate` results must match
1544+ the initial values in the iter_args binding and the yield operands.
1545+
1546+
1547+ A nested `sparse_tensor.iterate` example that prints all the coordinates
1548+ stored in the sparse input:
1549+
1550+ ```mlir
1551+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1552+ // Iterates over the first level of %sp
1553+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1554+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1555+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1556+ // Iterates over the second level of %sp
1557+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1558+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1559+ %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1560+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1561+ vector.print %crd0 : index
1562+ vector.print %crd1 : index
1563+ }
1564+ }
1565+ }
1566+
1567+ ```
1568+ }];
1569+
1570+ let extraClassDeclaration = [{
1571+ unsigned getSpaceDim() {
1572+ return getIterSpace().getType().getSpaceDim();
1573+ }
1574+ BlockArgument getIterator() {
1575+ return getRegion().getArguments().front();
1576+ }
1577+ Block::BlockArgListType getCrds() {
1578+ // The first block argument is iterator, the remaining arguments are
1579+ // referenced coordinates.
1580+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1581+ }
1582+ unsigned getNumRegionIterArgs() {
1583+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1584+ }
1585+ }];
1586+
1587+ let hasVerifier = 1;
1588+ let hasRegionVerifier = 1;
1589+ let hasCustomAssemblyFormat = 1;
1590+ }
1591+
14931592//===----------------------------------------------------------------------===//
14941593// Sparse Tensor Debugging and Test-Only Operations.
14951594//===----------------------------------------------------------------------===//
0 commit comments