@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616include "mlir/Interfaces/InferTypeOpInterface.td"
1717include "mlir/Interfaces/SideEffectInterfaces.td"
18+ include "mlir/Interfaces/ControlFlowInterfaces.td"
19+ include "mlir/Interfaces/LoopLikeInterface.td"
1820
1921//===----------------------------------------------------------------------===//
2022// Base class.
@@ -1277,7 +1279,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
12771279
12781280def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
12791281 ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1280- "ForeachOp"]>]>,
1282+ "ForeachOp", "IterateOp" ]>]>,
12811283 Arguments<(ins Variadic<AnyType>:$results)> {
12821284 let summary = "Yield from sparse_tensor set-like operations";
12831285 let description = [{
@@ -1430,6 +1432,154 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
14301432 let hasVerifier = 1;
14311433}
14321434
1435+ //===----------------------------------------------------------------------===//
1436+ // Sparse Tensor Iteration Operations.
1437+ //===----------------------------------------------------------------------===//
1438+
1439+ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1440+ [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
1441+
1442+ let arguments = (ins AnySparseTensor:$tensor,
1443+ Optional<AnySparseIterator>:$parentIter,
1444+ LevelAttr:$loLvl, LevelAttr:$hiLvl);
1445+
1446+ let results = (outs AnySparseIterSpace:$resultSpace);
1447+
1448+ let summary = "Extract an iteration space from a sparse tensor between certain levels";
1449+ let description = [{
1450+ Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
1451+ certian (consecutive) levels.
1452+
1453+ `tensor`: the input sparse tensor that defines the iteration space.
1454+ `parentIter`: the iterator for the previous level, at which the iteration space
1455+ at the current levels will be extracted.
1456+ `loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
1457+ the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
1458+ iteration space.
1459+
1460+ Example:
1461+ ```mlir
1462+ // Extracts a 1-D iteration space from a COO tensor at level 1.
1463+ %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
1464+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1465+ ```
1466+ }];
1467+
1468+
1469+ let extraClassDeclaration = [{
1470+ std::pair<Level, Level> getLvlRange() {
1471+ return std::make_pair(getLoLvl(), getHiLvl());
1472+ }
1473+ unsigned getSpaceDim() {
1474+ return getHiLvl() - getLoLvl();
1475+ }
1476+ ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1477+ return getResultSpace().getType().getLvlTypes();
1478+ }
1479+ }];
1480+
1481+ let hasVerifier = 1;
1482+ let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1483+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1484+ }
1485+
1486+ def IterateOp : SparseTensor_Op<"iterate",
1487+ [RecursiveMemoryEffects, RecursivelySpeculatable,
1488+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
1489+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1490+ "getYieldedValuesMutable"]>,
1491+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
1492+ ["getEntrySuccessorOperands"]>,
1493+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1494+
1495+ let arguments = (ins AnySparseIterSpace:$iterSpace,
1496+ Variadic<AnyType>:$initArgs,
1497+ LevelSetAttr:$crdUsedLvls);
1498+ let results = (outs Variadic<AnyType>:$results);
1499+ let regions = (region SizedRegion<1>:$region);
1500+
1501+ let summary = "Iterate over a sparse iteration space";
1502+ let description = [{
1503+ The `sparse_tensor.iterate` operations represents a loop over the
1504+ provided iteration space extracted from a specific sparse tensor.
1505+ The operation defines an SSA value for a sparse iterator that points
1506+ to the current stored element in the sparse tensor and SSA values
1507+ for coordinates of the stored element. The coordinates are always
1508+ converted to `index` type despite of the underlying sparse tensor
1509+ storage. When coordinates are not used, the SSA values can be skipped
1510+ by `_` symbols, which usually leads to simpler generated code after
1511+ sparsification. For example:
1512+
1513+ ```mlir
1514+ // The coordinate for level 0 is not used when iterating over a 2-D
1515+ // iteration space.
1516+ %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1517+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1518+ ```
1519+
1520+ `sparse_tensor.iterate` can also operate on loop-carried variables
1521+ and returns the final values after loop termination.
1522+ The initial values of the variables are passed as additional SSA operands
1523+ to the iterator SSA value and used coordinate SSA values mentioned
1524+ above. The operation region has an argument for the iterator, variadic
1525+ arguments for specified (used) coordiates and followed by one argument
1526+ for each loop-carried variable, representing the value of the variable
1527+ at the current iteration.
1528+ The body region must contain exactly one block that terminates with
1529+ `sparse_tensor.yield`.
1530+
1531+ `sparse_tensor.iterate` results hold the final values after the last
1532+ iteration. If the `sparse_tensor.iterate` defines any values, a yield
1533+ must be explicitly present.
1534+ The number and types of the `sparse_tensor.iterate` results must match
1535+ the initial values in the iter_args binding and the yield operands.
1536+
1537+
1538+ A nested `sparse_tensor.iterate` example that prints all the coordinates
1539+ stored in the sparse input:
1540+
1541+ ```mlir
1542+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1543+ // Iterates over the first level of %sp
1544+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1545+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1546+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1547+ // Iterates over the second level of %sp
1548+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1549+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1550+ %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1551+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1552+ vector.print %crd0 : index
1553+ vector.print %crd1 : index
1554+ }
1555+ }
1556+ }
1557+
1558+ ```
1559+ }];
1560+
1561+ let extraClassDeclaration = [{
1562+ unsigned getSpaceDim() {
1563+ return getIterSpace().getType().getSpaceDim();
1564+ }
1565+ BlockArgument getIterator() {
1566+ return getRegion().getArguments().front();
1567+ }
1568+ Block::BlockArgListType getCrds() {
1569+ // The first block argument is iterator, the remaining arguments are
1570+ // referenced coordinates.
1571+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1572+ }
1573+ unsigned getNumRegionIterArgs() {
1574+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1575+ }
1576+ }];
1577+
1578+ let hasVerifier = 1;
1579+ let hasRegionVerifier = 1;
1580+ let hasCustomAssemblyFormat = 1;
1581+ }
1582+
14331583//===----------------------------------------------------------------------===//
14341584// Sparse Tensor Debugging and Test-Only Operations.
14351585//===----------------------------------------------------------------------===//
0 commit comments