@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616include "mlir/Interfaces/InferTypeOpInterface.td"
1717include "mlir/Interfaces/SideEffectInterfaces.td"
18+ include "mlir/Interfaces/ControlFlowInterfaces.td"
19+ include "mlir/Interfaces/LoopLikeInterface.td"
1820
1921//===----------------------------------------------------------------------===//
2022// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
13041306
13051307def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
13061308 ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1307- "ForeachOp"]>]> {
1309+ "ForeachOp", "IterateOp" ]>]> {
13081310 let summary = "Yield from sparse_tensor set-like operations";
13091311 let description = [{
13101312 Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1476,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14761478 the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
14771479 iteration space.
14781480
1479- The type of returned the value is automatically inferred to
1481+ The type of returned the value is must be
14801482 `!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
14811483 The returned iteration space can then be iterated over by
14821484 `sparse_tensor.iterate` operations to visit every stored element
@@ -1487,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14871489 // Extracts a 1-D iteration space from a COO tensor at level 1.
14881490 %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
14891491 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1492+ ->!sparse_tensor.iter_space<#COO, lvls = 1>
14901493 ```
14911494 }];
14921495
@@ -1499,20 +1502,120 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14991502 return getHiLvl() - getLoLvl();
15001503 }
15011504 ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1502- return getResultSpace ().getType().getLvlTypes();
1505+ return getExtractedSpace ().getType().getLvlTypes();
15031506 }
15041507 }];
15051508
15061509 let arguments = (ins AnySparseTensor:$tensor,
15071510 Optional<AnySparseIterator>:$parentIter,
15081511 LevelAttr:$loLvl, LevelAttr:$hiLvl);
1509- let results = (outs AnySparseIterSpace:$resultSpace );
1512+ let results = (outs AnySparseIterSpace:$extractedSpace );
15101513 let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1511- " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1514+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
1515+ "`->` qualified(type($extractedSpace))";
15121516
15131517 let hasVerifier = 1;
15141518}
15151519
1520+ def IterateOp : SparseTensor_Op<"iterate",
1521+ [RecursiveMemoryEffects, RecursivelySpeculatable,
1522+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
1523+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1524+ "getYieldedValuesMutable"]>,
1525+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
1526+ ["getEntrySuccessorOperands"]>,
1527+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1528+
1529+ let summary = "Iterates over a sparse iteration space";
1530+ let description = [{
1531+ The `sparse_tensor.iterate` operation represents a loop (nest) over
1532+ the provided iteration space extracted from a specific sparse tensor.
1533+ The operation defines an SSA value for a sparse iterator that points
1534+ to the current stored element in the sparse tensor and SSA values
1535+ for coordinates of the stored element. The coordinates are always
1536+ converted to `index` type despite of the underlying sparse tensor
1537+ storage. When coordinates are not used, the SSA values can be skipped
1538+ by `_` symbols, which usually leads to simpler generated code after
1539+ sparsification. For example:
1540+
1541+ ```mlir
1542+ // The coordinate for level 0 is not used when iterating over a 2-D
1543+ // iteration space.
1544+ %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1545+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1546+ ```
1547+
1548+ `sparse_tensor.iterate` can also operate on loop-carried variables.
1549+ It returns the final values after loop termination.
1550+ The initial values of the variables are passed as additional SSA operands
1551+ to the iterator SSA value and used coordinate SSA values mentioned
1552+ above. The operation region has an argument for the iterator, variadic
1553+ arguments for specified (used) coordiates and followed by one argument
1554+ for each loop-carried variable, representing the value of the variable
1555+ at the current iteration.
1556+ The body region must contain exactly one block that terminates with
1557+ `sparse_tensor.yield`.
1558+
1559+ The results of an `sparse_tensor.iterate` hold the final values after
1560+ the last iteration. If the `sparse_tensor.iterate` defines any values,
1561+ a yield must be explicitly present.
1562+ The number and types of the `sparse_tensor.iterate` results must match
1563+ the initial values in the iter_args binding and the yield operands.
1564+
1565+
1566+ A nested `sparse_tensor.iterate` example that prints all the coordinates
1567+ stored in the sparse input:
1568+
1569+ ```mlir
1570+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1571+ // Iterates over the first level of %sp
1572+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
1573+ : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
1574+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
1575+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1576+ // Iterates over the second level of %sp
1577+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1578+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1579+ -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
1580+ %r2 = sparse_tensor.iterate %it2 in %l2 at (coord1)
1581+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1582+ vector.print %coord0 : index
1583+ vector.print %coord1 : index
1584+ }
1585+ }
1586+ }
1587+
1588+ ```
1589+ }];
1590+
1591+ let arguments = (ins AnySparseIterSpace:$iterSpace,
1592+ Variadic<AnyType>:$initArgs,
1593+ LevelSetAttr:$crdUsedLvls);
1594+ let results = (outs Variadic<AnyType>:$results);
1595+ let regions = (region SizedRegion<1>:$region);
1596+
1597+ let extraClassDeclaration = [{
1598+ unsigned getSpaceDim() {
1599+ return getIterSpace().getType().getSpaceDim();
1600+ }
1601+ BlockArgument getIterator() {
1602+ return getRegion().getArguments().front();
1603+ }
1604+ Block::BlockArgListType getCrds() {
1605+ // The first block argument is iterator, the remaining arguments are
1606+ // referenced coordinates.
1607+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1608+ }
1609+ unsigned getNumRegionIterArgs() {
1610+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1611+ }
1612+ }];
1613+
1614+ let hasVerifier = 1;
1615+ let hasRegionVerifier = 1;
1616+ let hasCustomAssemblyFormat = 1;
1617+ }
1618+
15161619//===----------------------------------------------------------------------===//
15171620// Sparse Tensor Debugging and Test-Only Operations.
15181621//===----------------------------------------------------------------------===//
0 commit comments