@@ -1489,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14891489 // Extracts a 1-D iteration space from a COO tensor at level 1.
14901490 %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
14911491 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1492+ ->!sparse_tensor.iter_space<#COO, lvls = 1>
14921493 ```
14931494 }];
14941495
@@ -1501,16 +1502,17 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
15011502 return getHiLvl() - getLoLvl();
15021503 }
15031504 ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1504- return getResultSpace ().getType().getLvlTypes();
1505+ return getExtractedSpace ().getType().getLvlTypes();
15051506 }
15061507 }];
15071508
15081509 let arguments = (ins AnySparseTensor:$tensor,
15091510 Optional<AnySparseIterator>:$parentIter,
15101511 LevelAttr:$loLvl, LevelAttr:$hiLvl);
1511- let results = (outs AnySparseIterSpace:$resultSpace );
1512+ let results = (outs AnySparseIterSpace:$extractedSpace );
15121513 let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1513- " attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
1514+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
1515+ "`->` qualified(type($extractedSpace))";
15141516
15151517 let hasVerifier = 1;
15161518}
@@ -1567,12 +1569,14 @@ def IterateOp : SparseTensor_Op<"iterate",
15671569 ```mlir
15681570 func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
15691571 // Iterates over the first level of %sp
1570- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1572+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
1573+ : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
15711574 %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
15721575 : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
15731576 // Iterates over the second level of %sp
15741577 %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
15751578 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1579+ -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
15761580 %r2 = sparse_tensor.iterate %it2 in %l2 at (coord1)
15771581 : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
15781582 vector.print %coord0 : index
0 commit comments