2424#include " mlir/IR/PatternMatch.h"
2525#include " mlir/Interfaces/DestinationStyleOpInterface.h"
2626#include " mlir/Interfaces/TilingInterface.h"
27+ #include " mlir/Rewrite/FrozenRewritePatternSet.h"
28+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2729#include " llvm/ADT/TypeSwitch.h"
2830#include " llvm/Support/Debug.h"
2931#include < optional>
@@ -1315,6 +1317,104 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
13151317 return generatedSlices;
13161318}
13171319
1320+ namespace {
1321+
1322+ // ===----------------------------------------------------------------------===//
1323+ // SliceTrackingListener
1324+ // ===----------------------------------------------------------------------===//
1325+
1326+ // / This class is a listener for tracking the insertion and removal of
1327+ // / `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1328+ // / fusion algorithm to apply cleanup patterns in between fusion steps.
1329+ class SliceTrackingListener : public RewriterBase ::Listener {
1330+ public:
1331+ explicit SliceTrackingListener (
1332+ std::optional<FrozenRewritePatternSet> patterns);
1333+ SliceTrackingListener () = default ;
1334+
1335+ // / Adds the given list of operations to the worklist, and if present, applies
1336+ // / the list of `patterns` to the newly added operations. This only processes
1337+ // / the given operations and any newly inserted ones by the pattern set.
1338+ LogicalResult insertAndApplyPatterns (ArrayRef<Operation *> newOps);
1339+
1340+ // / Add to the new operation worklist if it is an extract_slice.
1341+ void notifyOperationInserted (Operation *op,
1342+ OpBuilder::InsertPoint previous) override ;
1343+
1344+ // / Shared helper for operation removal from the worklist.
1345+ void removeOp (Operation *op);
1346+
1347+ // / Remove the operation from the worklist.
1348+ void notifyOperationErased (Operation *op) override ;
1349+
1350+ // / Remove the operation from the worklist.
1351+ void notifyOperationReplaced (Operation *op, ValueRange replacement) override ;
1352+
1353+ // / The worklist for this transformation keeps track of the slices to visit
1354+ // / next for fusion.
1355+ std::deque<tensor::ExtractSliceOp> worklist;
1356+
1357+ private:
1358+ // / Optional pattern set to apply when adding new operations to the worklist.
1359+ std::optional<FrozenRewritePatternSet> patterns = std::nullopt ;
1360+ };
1361+
1362+ SliceTrackingListener::SliceTrackingListener (
1363+ std::optional<FrozenRewritePatternSet> p) {
1364+ patterns = std::move (p);
1365+ }
1366+
1367+ LogicalResult
1368+ SliceTrackingListener::insertAndApplyPatterns (ArrayRef<Operation *> ops) {
1369+ for (Operation *op : ops) {
1370+ if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1371+ worklist.push_back (slice);
1372+ }
1373+
1374+ if (!patterns)
1375+ return success ();
1376+
1377+ GreedyRewriteConfig config;
1378+ config.listener = this ;
1379+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1380+ return applyOpPatternsAndFold (ops, patterns.value (), config);
1381+ }
1382+
1383+ void SliceTrackingListener::notifyOperationInserted (
1384+ Operation *op, OpBuilder::InsertPoint previous) {
1385+ auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386+ if (!slice)
1387+ return ;
1388+ worklist.push_back (slice);
1389+ }
1390+
1391+ // Scan the worklist for the given op and remove it if present. The expectation
1392+ // is for the worklist to be small and for removal to be relatively rare.
1393+ void SliceTrackingListener::removeOp (Operation *op) {
1394+ if (!isa<tensor::ExtractSliceOp>(op))
1395+ return ;
1396+ auto iter = worklist.begin ();
1397+ while (iter != worklist.end ()) {
1398+ if (*iter == op)
1399+ break ;
1400+ iter++;
1401+ }
1402+ if (iter == worklist.end ())
1403+ return ;
1404+
1405+ worklist.erase (iter);
1406+ }
1407+
1408+ void SliceTrackingListener::notifyOperationErased (Operation *op) {
1409+ removeOp (op);
1410+ }
1411+
1412+ void SliceTrackingListener::notifyOperationReplaced (Operation *op,
1413+ ValueRange replacement) {
1414+ removeOp (op);
1415+ }
1416+ } // namespace
1417+
13181418// / Implementation of tile consumer and fuse producer greedily.
13191419FailureOr<scf::SCFTileAndFuseResult>
13201420mlir::scf::tileConsumerAndFuseProducersUsingSCF (
@@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
13701470 tensor::ExtractSliceOp candidateSlice;
13711471 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
13721472 };
1373- std::deque<WorklistItem> worklist;
1374- auto addCandidateSlices = [&worklist, &options,
1375- &loops](ArrayRef<Operation *> candidates) {
1376- for (auto candidate : candidates) {
1377- auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378- if (!sliceOp || sliceOp.use_empty ())
1379- continue ;
13801473
1381- auto [fusableProducer, destinationInitArg] =
1382- getUntiledProducerFromSliceSource (&sliceOp.getSourceMutable (), loops);
1383- if (!fusableProducer)
1384- continue ;
1385- std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386- options.fusionControlFn (sliceOp, fusableProducer,
1387- destinationInitArg.has_value ());
1388- if (!controlFnResult)
1389- continue ;
1390- worklist.emplace_back (WorklistItem{sliceOp, controlFnResult.value ()});
1391- }
1392- };
1474+ SliceTrackingListener sliceTracker =
1475+ SliceTrackingListener (options.cleanupPatterns );
13931476
1394- addCandidateSlices (tilingResult->generatedSlices );
1477+ if (failed (
1478+ sliceTracker.insertAndApplyPatterns (tilingResult->generatedSlices ))) {
1479+ return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1480+ }
13951481 OpBuilder::InsertionGuard g (rewriter);
1396- while (!worklist.empty ()) {
1397- // Traverse the slices in BFS fashion.
1398- WorklistItem worklistItem = worklist.front ();
1399- worklist.pop_front ();
1482+ while (!sliceTracker.worklist .empty ()) {
1483+ auto candidateSlice = sliceTracker.worklist .front ();
1484+ sliceTracker.worklist .pop_front ();
1485+
1486+ auto [fusableProducer, destinationInitArg] =
1487+ getUntiledProducerFromSliceSource (&candidateSlice.getSourceMutable (),
1488+ loops);
1489+ if (!fusableProducer)
1490+ continue ;
1491+
1492+ std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1493+ options.fusionControlFn (candidateSlice, fusableProducer,
1494+ destinationInitArg.has_value ());
1495+ if (!controlFnResult)
1496+ continue ;
1497+
1498+ WorklistItem worklistItem = {candidateSlice, controlFnResult.value ()};
14001499
14011500 // The operands of the fused producer might themselves be slices of
14021501 // values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14071506 if (!fusedResult)
14081507 continue ;
14091508
1509+ SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices ;
1510+
14101511 if (worklistItem.controlFnResult .yieldProducerReplacement ) {
14111512 // Reconstruct and yield all opResult of fusableProducerOp by default. The
14121513 // caller can specify which one to yield by designating optional argument
@@ -1421,20 +1522,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14211522 fusableProducerOp, " failed to replacement value for this "
14221523 " operation from within the tiled loop" );
14231524 }
1424- addCandidateSlices (newSlices.value ());
1525+ worklistCandidates. append (newSlices.value ());
14251526 for (auto [index, result] :
14261527 llvm::enumerate (fusableProducerOp->getResults ())) {
14271528 origValToResultNumber[result] = loops.front ()->getNumResults () -
14281529 fusableProducerOp->getNumResults () +
14291530 index;
14301531 }
14311532 }
1432- addCandidateSlices (fusedResult->generatedSlices );
14331533 if (Operation *tiledAndFusedOp =
14341534 fusedResult->tiledAndFusedProducer .getDefiningOp ()) {
14351535 fusedProducers.insert (fusedResult->origProducer .getDefiningOp ());
14361536 tiledAndFusedOps.insert (tiledAndFusedOp);
14371537 }
1538+
1539+ if (failed (sliceTracker.insertAndApplyPatterns (worklistCandidates))) {
1540+ return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1541+ }
14381542 }
14391543
14401544 DenseMap<Value, Value> replacements;
0 commit comments