@@ -15,8 +15,6 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616include "mlir/Interfaces/InferTypeOpInterface.td"
1717include "mlir/Interfaces/SideEffectInterfaces.td"
18- include "mlir/Interfaces/ControlFlowInterfaces.td"
19- include "mlir/Interfaces/LoopLikeInterface.td"
2018
2119//===----------------------------------------------------------------------===//
2220// Base class.
@@ -1279,7 +1277,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
12791277
12801278def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
12811279 ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1282- "ForeachOp", "IterateOp" ]>]>,
1280+ "ForeachOp"]>]>,
12831281 Arguments<(ins Variadic<AnyType>:$results)> {
12841282 let summary = "Yield from sparse_tensor set-like operations";
12851283 let description = [{
@@ -1432,154 +1430,6 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
14321430 let hasVerifier = 1;
14331431}
14341432
1435- //===----------------------------------------------------------------------===//
1436- // Sparse Tensor Iteration Operations.
1437- //===----------------------------------------------------------------------===//
1438-
1439- def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1440- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
1441-
1442- let arguments = (ins AnySparseTensor:$tensor,
1443- Optional<AnySparseIterator>:$parentIter,
1444- LevelAttr:$loLvl, LevelAttr:$hiLvl);
1445-
1446- let results = (outs AnySparseIterSpace:$resultSpace);
1447-
1448- let summary = "Extract an iteration space from a sparse tensor between certain levels";
1449- let description = [{
1450- Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
1451- certian (consecutive) levels.
1452-
1453- `tensor`: the input sparse tensor that defines the iteration space.
1454- `parentIter`: the iterator for the previous level, at which the iteration space
1455- at the current levels will be extracted.
1456- `loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
1457- the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
1458- iteration space.
1459-
1460- Example:
1461- ```mlir
1462- // Extracts a 1-D iteration space from a COO tensor at level 1.
1463- %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
1464- : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1465- ```
1466- }];
1467-
1468-
1469- let extraClassDeclaration = [{
1470- std::pair<Level, Level> getLvlRange() {
1471- return std::make_pair(getLoLvl(), getHiLvl());
1472- }
1473- unsigned getSpaceDim() {
1474- return getHiLvl() - getLoLvl();
1475- }
1476- ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1477- return getResultSpace().getType().getLvlTypes();
1478- }
1479- }];
1480-
1481- let hasVerifier = 1;
1482- let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1483- " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1484- }
1485-
1486- def IterateOp : SparseTensor_Op<"iterate",
1487- [RecursiveMemoryEffects, RecursivelySpeculatable,
1488- DeclareOpInterfaceMethods<LoopLikeOpInterface,
1489- ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1490- "getYieldedValuesMutable"]>,
1491- DeclareOpInterfaceMethods<RegionBranchOpInterface,
1492- ["getEntrySuccessorOperands"]>,
1493- SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1494-
1495- let arguments = (ins AnySparseIterSpace:$iterSpace,
1496- Variadic<AnyType>:$initArgs,
1497- LevelSetAttr:$crdUsedLvls);
1498- let results = (outs Variadic<AnyType>:$results);
1499- let regions = (region SizedRegion<1>:$region);
1500-
1501- let summary = "Iterate over a sparse iteration space";
1502- let description = [{
1503- The `sparse_tensor.iterate` operations represents a loop over the
1504- provided iteration space extracted from a specific sparse tensor.
1505- The operation defines an SSA value for a sparse iterator that points
1506- to the current stored element in the sparse tensor and SSA values
1507- for coordinates of the stored element. The coordinates are always
1508- converted to `index` type despite of the underlying sparse tensor
1509- storage. When coordinates are not used, the SSA values can be skipped
1510- by `_` symbols, which usually leads to simpler generated code after
1511- sparsification. For example:
1512-
1513- ```mlir
1514- // The coordinate for level 0 is not used when iterating over a 2-D
1515- // iteration space.
1516- %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1517- : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1518- ```
1519-
1520- `sparse_tensor.iterate` can also operate on loop-carried variables
1521- and returns the final values after loop termination.
1522- The initial values of the variables are passed as additional SSA operands
1523- to the iterator SSA value and used coordinate SSA values mentioned
1524- above. The operation region has an argument for the iterator, variadic
1525- arguments for specified (used) coordiates and followed by one argument
1526- for each loop-carried variable, representing the value of the variable
1527- at the current iteration.
1528- The body region must contain exactly one block that terminates with
1529- `sparse_tensor.yield`.
1530-
1531- `sparse_tensor.iterate` results hold the final values after the last
1532- iteration. If the `sparse_tensor.iterate` defines any values, a yield
1533- must be explicitly present.
1534- The number and types of the `sparse_tensor.iterate` results must match
1535- the initial values in the iter_args binding and the yield operands.
1536-
1537-
1538- A nested `sparse_tensor.iterate` example that prints all the coordinates
1539- stored in the sparse input:
1540-
1541- ```mlir
1542- func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1543- // Iterates over the first level of %sp
1544- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1545- %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1546- : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1547- // Iterates over the second level of %sp
1548- %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1549- : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1550- %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1551- : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1552- vector.print %crd0 : index
1553- vector.print %crd1 : index
1554- }
1555- }
1556- }
1557-
1558- ```
1559- }];
1560-
1561- let extraClassDeclaration = [{
1562- unsigned getSpaceDim() {
1563- return getIterSpace().getType().getSpaceDim();
1564- }
1565- BlockArgument getIterator() {
1566- return getRegion().getArguments().front();
1567- }
1568- Block::BlockArgListType getCrds() {
1569- // The first block argument is iterator, the remaining arguments are
1570- // referenced coordinates.
1571- return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1572- }
1573- unsigned getNumRegionIterArgs() {
1574- return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1575- }
1576- }];
1577-
1578- let hasVerifier = 1;
1579- let hasRegionVerifier = 1;
1580- let hasCustomAssemblyFormat = 1;
1581- }
1582-
15831433//===----------------------------------------------------------------------===//
15841434// Sparse Tensor Debugging and Test-Only Operations.
15851435//===----------------------------------------------------------------------===//
0 commit comments