@@ -1319,95 +1319,6 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
13191319
13201320namespace {
13211321
1322- // ===----------------------------------------------------------------------===//
1323- // SliceWorklist
1324- // ===----------------------------------------------------------------------===//
1325-
1326- // / Per-operation bookkeeping for SliceWorklist: how many copies of the
1327- // / operation are currently queued and whether one of them is still valid.
1328- struct EntryCount {
1329- bool isValid = true ; // cleared by SliceWorklist::remove, set again by push
1330- unsigned count = 0 ; // all queued copies (valid + stale): +1 per push, -1 per stale pop
1331- };
1332-
1333- // / A FIFO worklist of operations with lazy removal and set semantics.
1334- // /
1335- // / This class maintains a queue of operations and a map from each queued
1336- // / operation to an EntryCount. Removing an operation only marks its map
1337- // / entry as invalid; the stale copy stays in the queue and is skipped (and
1338- // / its count decremented) when pop'ing elements.
1339- // /
1340- // / This is similar to the worklist used by the GreedyPatternRewriteDriver,
1341- // / except that it is FIFO so that slices for fusion can be processed breadth
1342- // / first.
1343- class SliceWorklist {
1344- public:
1345- SliceWorklist () = default ;
1346-
1347- // / Push an operation to the end of the worklist. Asserts that the operation
1348- // / is non-null and that no valid copy of it is already on the worklist.
1349- void push (Operation *op);
1350-
1351- // / Pop an operation from the front of the worklist, skipping stale entries.
1352- // / Returns nullptr if there are no remaining valid operations.
1353- Operation *pop ();
1354-
1355- // / Mark an operation as removed. Does nothing if the operation is not tracked.
1356- void remove (Operation *op);
1357-
1358- protected:
1359- // / The queue of operations; may contain stale copies of removed operations.
1360- std::deque<Operation *> list;
1361-
1362- // / A mapping of operations to their queued-copy count and validity flag.
1363- DenseMap<Operation *, EntryCount> map;
1364- };
1365-
1366- void SliceWorklist::push (Operation *op) {
1367- assert (op && " cannot push nullptr to worklist" );
1368- list.push_back (op); // FIFO: new entries always go to the back of the queue
1369- EntryCount newCount = map.lookup (op); // default EntryCount (count 0) if op is untracked
1370- // A tracked op may only be pushed again after it was marked invalid by
1371- // remove(); pushing while a valid copy is still queued is a caller bug.
1372- assert ((!map.contains (op) || !newCount.isValid ) &&
1373- " cannot push a duplicate operation" );
1374- map[op] = {/* isValid=*/ true , newCount.count + 1 }; // revalidate and count the new copy
1375- }
1376-
1377- Operation *SliceWorklist::pop () {
1378- // Drain the front of the queue until a valid entry is found or it is empty.
1379- while (!list.empty ()) {
1380- Operation *op = list.front ();
1381- list.pop_front ();
1382-
1383- EntryCount e = map.lookup (op);
1384- // If more than one copy is still accounted for, or the op was marked
1385- // invalid, treat this copy as stale: drop one from the count and continue.
1386- if (e.count > 1 || !e.isValid ) {
1387- int64_t newCount = e.count - 1 ; // NOTE(review): subtraction is unsigned; relies on count >= 1 here
1388- if (newCount <= 0 )
1389- map.erase (op); // last copy was stale: stop tracking the op
1390- else
1391- map[op] = {e.isValid , static_cast <unsigned int >(newCount)};
1392- continue ;
1393- }
1394-
1395- map.erase (op); // single remaining valid copy: stop tracking and hand it out
1396- return op;
1397- }
1398- return nullptr ; // queue exhausted without finding a valid operation
1399- }
1400-
1401- // Mark the operation as invalid if it is tracked; otherwise do nothing. The
1402- // queue is not touched here: stale copies and the map entry are cleaned up by pop().
1403- void SliceWorklist::remove (Operation *op) {
1404- if (!map.contains (op))
1405- return ;
1406-
1407- EntryCount e = map.lookup (op);
1408- map[op] = {/* isValid=*/ false , e.count }; // keep the count so pop() can retire each copy
1409- }
1410-
14111322// ===----------------------------------------------------------------------===//
14121323// SliceTrackingListener
14131324// ===----------------------------------------------------------------------===//
@@ -1430,15 +1341,18 @@ class SliceTrackingListener : public RewriterBase::Listener {
14301341 void notifyOperationInserted (Operation *op,
14311342 OpBuilder::InsertPoint previous) override ;
14321343
1344+ // / Shared helper for operation removal from the worklist.
1345+ void removeOp (Operation *op);
1346+
14331347 // / Remove the operation from the worklist.
14341348 void notifyOperationErased (Operation *op) override ;
14351349
14361350 // / Remove the operation from the worklist.
14371351 void notifyOperationReplaced (Operation *op, ValueRange replacement) override ;
14381352
1439- // / The worklist for this transformation keeps track of the operations that
1440- // / need to be (re)visited .
1441- SliceWorklist worklist;
1353+ // / The worklist for this transformation keeps track of the slices to visit
1354+ // / next for fusion .
1355+ std::deque<tensor::ExtractSliceOp> worklist;
14421356
14431357private:
14441358 // / Optional pattern set to apply when adding new operations to the worklist.
@@ -1453,8 +1367,8 @@ SliceTrackingListener::SliceTrackingListener(
14531367LogicalResult
14541368SliceTrackingListener::insertAndApplyPatterns (ArrayRef<Operation *> ops) {
14551369 for (Operation *op : ops) {
1456- if (isa <tensor::ExtractSliceOp>(op))
1457- worklist.push (op );
1370+ if (auto slice = dyn_cast <tensor::ExtractSliceOp>(op))
1371+ worklist.push_back (slice );
14581372 }
14591373
14601374 if (!patterns)
@@ -1468,18 +1382,36 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
14681382
14691383void SliceTrackingListener::notifyOperationInserted (
14701384 Operation *op, OpBuilder::InsertPoint previous) {
1385+ auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386+ if (!slice)
1387+ return ; // only tensor.extract_slice ops are tracked
1388+ worklist.push_back (slice); // appended at the back of the deque
1389+ }
1390+
1391+ // Linearly scan the worklist for the given op and erase the first matching entry,
1392+ // if any. The expectation is for the worklist to be small and removal to be rare.
1393+ void SliceTrackingListener::removeOp (Operation *op) {
14711394 if (!isa<tensor::ExtractSliceOp>(op))
14721395 return ;
1473- worklist.push (op);
1396+ auto iter = worklist.begin ();
1397+ while (iter != worklist.end ()) {
1398+ if (*iter == op)
1399+ break ; // stop at the first match; later duplicates, if any, are left in place
1400+ iter++;
1401+ }
1402+ if (iter == worklist.end ())
1403+ return ; // op was never queued or has already been popped
1404+
1405+ worklist.erase (iter);
14741406}
14751407
14761408void SliceTrackingListener::notifyOperationErased (Operation *op) {
1477- worklist. remove (op);
1409+ removeOp (op); // drop the erased op's worklist entry, if it has one
14781410}
14791411
14801412void SliceTrackingListener::notifyOperationReplaced (Operation *op,
14811413 ValueRange replacement) {
1482- worklist. remove (op);
1414+ removeOp (op); // the replaced op is dropped from the worklist; `replacement` is not used
14831415}
14841416} // namespace
14851417
@@ -1547,10 +1479,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15471479 return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
15481480 }
15491481 OpBuilder::InsertionGuard g (rewriter);
1550- while (Operation *next = sliceTracker.worklist .pop ()) {
1551- auto candidateSlice = dyn_cast<tensor::ExtractSliceOp>(next);
1552- if (!candidateSlice)
1553- continue ;
1482+ while (!sliceTracker.worklist .empty ()) {
1483+ auto candidateSlice = sliceTracker.worklist .front ();
1484+ sliceTracker.worklist .pop_front ();
15541485
15551486 auto [fusableProducer, destinationInitArg] =
15561487 getUntiledProducerFromSliceSource (&candidateSlice.getSourceMutable (),
0 commit comments