@@ -78,6 +78,13 @@ class LoopIdiomVectorize {
7878 const TargetTransformInfo *TTI;
7979 const DataLayout *DL;
8080
81+ // Blocks that will be used for inserting vectorized code.
82+ BasicBlock *EndBlock = nullptr ;
83+ BasicBlock *VectorLoopPreheaderBlock = nullptr ;
84+ BasicBlock *VectorLoopStartBlock = nullptr ;
85+ BasicBlock *VectorLoopMismatchBlock = nullptr ;
86+ BasicBlock *VectorLoopIncBlock = nullptr ;
87+
8188public:
8289 explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
8390 const TargetTransformInfo *TTI,
@@ -95,9 +102,16 @@ class LoopIdiomVectorize {
95102 SmallVectorImpl<BasicBlock *> &ExitBlocks);
96103
97104 bool recognizeByteCompare ();
105+
98106 Value *expandFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
99107 GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
100108 Instruction *Index, Value *Start, Value *MaxLen);
109+
110+ Value *createMaskedFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
111+ GetElementPtrInst *GEPA,
112+ GetElementPtrInst *GEPB, Value *ExtStart,
113+ Value *ExtEnd);
114+
101115 void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
102116 PHINode *IndPhi, Value *MaxLen, Instruction *Index,
103117 Value *Start, bool IncIdx, BasicBlock *FoundBB,
@@ -331,6 +345,115 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
331345 return true ;
332346}
333347
348+ Value *LoopIdiomVectorize::createMaskedFindMismatch (
349+ IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
350+ GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
351+ Type *I64Type = Builder.getInt64Ty ();
352+ Type *ResType = Builder.getInt32Ty ();
353+ Type *LoadType = Builder.getInt8Ty ();
354+ Value *PtrA = GEPA->getPointerOperand ();
355+ Value *PtrB = GEPB->getPointerOperand ();
356+
357+ // At this point we know two things must be true:
358+ // 1. Start <= End
359+ // 2. ExtMaxLen <= MinPageSize due to the page checks.
360+ // Therefore, we know that we can use a 64-bit induction variable that
361+ // starts from 0 -> ExtMaxLen and it will not overflow.
362+ ScalableVectorType *PredVTy =
363+ ScalableVectorType::get (Builder.getInt1Ty (), 16 );
364+
365+ Value *InitialPred = Builder.CreateIntrinsic (
366+ Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
367+
368+ Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
369+ VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
370+ /* HasNUW=*/ true , /* HasNSW=*/ true );
371+
372+ Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
373+ Builder.getInt1 (false ));
374+
375+ BranchInst *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
376+ Builder.Insert (JumpToVectorLoop);
377+
378+ DTU.applyUpdates ({{DominatorTree::Insert, VectorLoopPreheaderBlock,
379+ VectorLoopStartBlock}});
380+
381+ // Set up the first vector loop block by creating the PHIs, doing the vector
382+ // loads and comparing the vectors.
383+ Builder.SetInsertPoint (VectorLoopStartBlock);
384+ PHINode *LoopPred = Builder.CreatePHI (PredVTy, 2 , " mismatch_vec_loop_pred" );
385+ LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
386+ PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
387+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
388+ Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
389+ Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
390+
391+ Value *VectorLhsGep =
392+ Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi, " " , GEPA->isInBounds ());
393+ Value *VectorLhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorLhsGep,
394+ Align (1 ), LoopPred, Passthru);
395+
396+ Value *VectorRhsGep =
397+ Builder.CreateGEP (LoadType, PtrB, VectorIndexPhi, " " , GEPB->isInBounds ());
398+ Value *VectorRhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorRhsGep,
399+ Align (1 ), LoopPred, Passthru);
400+
401+ Value *VectorMatchCmp = Builder.CreateICmpNE (VectorLhsLoad, VectorRhsLoad);
402+ VectorMatchCmp = Builder.CreateSelect (LoopPred, VectorMatchCmp, PFalse);
403+ Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce (VectorMatchCmp);
404+ BranchInst *VectorEarlyExit = BranchInst::Create (
405+ VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
406+ Builder.Insert (VectorEarlyExit);
407+
408+ DTU.applyUpdates (
409+ {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
410+ {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
411+
412+ // Increment the index counter and calculate the predicate for the next
413+ // iteration of the loop. We branch back to the start of the loop if there
414+ // is at least one active lane.
415+ Builder.SetInsertPoint (VectorLoopIncBlock);
416+ Value *NewVectorIndexPhi =
417+ Builder.CreateAdd (VectorIndexPhi, VecLen, " " ,
418+ /* HasNUW=*/ true , /* HasNSW=*/ true );
419+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
420+ Value *NewPred =
421+ Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask,
422+ {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
423+ LoopPred->addIncoming (NewPred, VectorLoopIncBlock);
424+
425+ Value *PredHasActiveLanes =
426+ Builder.CreateExtractElement (NewPred, uint64_t (0 ));
427+ BranchInst *VectorLoopBranchBack =
428+ BranchInst::Create (VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
429+ Builder.Insert (VectorLoopBranchBack);
430+
431+ DTU.applyUpdates (
432+ {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
433+ {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
434+
435+ // If we found a mismatch then we need to calculate which lane in the vector
436+ // had a mismatch and add that on to the current loop index.
437+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
438+ PHINode *FoundPred = Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_found_pred" );
439+ FoundPred->addIncoming (VectorMatchCmp, VectorLoopStartBlock);
440+ PHINode *LastLoopPred =
441+ Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_last_loop_pred" );
442+ LastLoopPred->addIncoming (LoopPred, VectorLoopStartBlock);
443+ PHINode *VectorFoundIndex =
444+ Builder.CreatePHI (I64Type, 1 , " mismatch_vec_found_index" );
445+ VectorFoundIndex->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
446+
447+ Value *PredMatchCmp = Builder.CreateAnd (LastLoopPred, FoundPred);
448+ Value *Ctz = Builder.CreateIntrinsic (
449+ Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType ()},
450+ {PredMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true )});
451+ Ctz = Builder.CreateZExt (Ctz, I64Type);
452+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorFoundIndex, Ctz, " " ,
453+ /* HasNUW=*/ true , /* HasNSW=*/ true );
454+ return Builder.CreateTrunc (VectorLoopRes64, ResType);
455+ }
456+
334457Value *LoopIdiomVectorize::expandFindMismatch (
335458 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
336459 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -345,8 +468,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
345468 Type *ResType = Builder.getInt32Ty ();
346469
347470 // Split block in the original loop preheader.
348- BasicBlock *EndBlock =
349- SplitBlock (Preheader, PHBranch, DT, LI, nullptr , " mismatch_end" );
471+ EndBlock = SplitBlock (Preheader, PHBranch, DT, LI, nullptr , " mismatch_end" );
350472
351473 // Create the blocks that we're going to need:
352474 // 1. A block for checking the zero-extended length exceeds 0
@@ -370,17 +492,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
370492 BasicBlock *MemCheckBlock = BasicBlock::Create (
371493 Ctx, " mismatch_mem_check" , EndBlock->getParent (), EndBlock);
372494
373- BasicBlock * VectorLoopPreheaderBlock = BasicBlock::Create (
495+ VectorLoopPreheaderBlock = BasicBlock::Create (
374496 Ctx, " mismatch_vec_loop_preheader" , EndBlock->getParent (), EndBlock);
375497
376- BasicBlock * VectorLoopStartBlock = BasicBlock::Create (
377- Ctx, " mismatch_vec_loop " , EndBlock->getParent (), EndBlock);
498+ VectorLoopStartBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop " ,
499+ EndBlock->getParent (), EndBlock);
378500
379- BasicBlock * VectorLoopIncBlock = BasicBlock::Create (
380- Ctx, " mismatch_vec_loop_inc " , EndBlock->getParent (), EndBlock);
501+ VectorLoopIncBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop_inc " ,
502+ EndBlock->getParent (), EndBlock);
381503
382- BasicBlock * VectorLoopMismatchBlock = BasicBlock::Create (
383- Ctx, " mismatch_vec_loop_found " , EndBlock->getParent (), EndBlock);
504+ VectorLoopMismatchBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop_found " ,
505+ EndBlock->getParent (), EndBlock);
384506
385507 BasicBlock *LoopPreHeaderBlock = BasicBlock::Create (
386508 Ctx, " mismatch_loop_pre" , EndBlock->getParent (), EndBlock);
@@ -491,104 +613,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
491613 // processed in each iteration, etc.
492614 Builder.SetInsertPoint (VectorLoopPreheaderBlock);
493615
494- // At this point we know two things must be true:
495- // 1. Start <= End
496- // 2. ExtMaxLen <= MinPageSize due to the page checks.
497- // Therefore, we know that we can use a 64-bit induction variable that
498- // starts from 0 -> ExtMaxLen and it will not overflow.
499- ScalableVectorType *PredVTy =
500- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
501-
502- Value *InitialPred = Builder.CreateIntrinsic (
503- Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
504-
505- Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
506- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
507- /* HasNUW=*/ true , /* HasNSW=*/ true );
508-
509- Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
510- Builder.getInt1 (false ));
511-
512- BranchInst *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
513- Builder.Insert (JumpToVectorLoop);
514-
515- DTU.applyUpdates ({{DominatorTree::Insert, VectorLoopPreheaderBlock,
516- VectorLoopStartBlock}});
517-
518- // Set up the first vector loop block by creating the PHIs, doing the vector
519- // loads and comparing the vectors.
520- Builder.SetInsertPoint (VectorLoopStartBlock);
521- PHINode *LoopPred = Builder.CreatePHI (PredVTy, 2 , " mismatch_vec_loop_pred" );
522- LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
523- PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
524- VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
525- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
526- Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
527-
528- Value *VectorLhsGep =
529- Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi, " " , GEPA->isInBounds ());
530- Value *VectorLhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorLhsGep,
531- Align (1 ), LoopPred, Passthru);
532-
533- Value *VectorRhsGep =
534- Builder.CreateGEP (LoadType, PtrB, VectorIndexPhi, " " , GEPB->isInBounds ());
535- Value *VectorRhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorRhsGep,
536- Align (1 ), LoopPred, Passthru);
537-
538- Value *VectorMatchCmp = Builder.CreateICmpNE (VectorLhsLoad, VectorRhsLoad);
539- VectorMatchCmp = Builder.CreateSelect (LoopPred, VectorMatchCmp, PFalse);
540- Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce (VectorMatchCmp);
541- BranchInst *VectorEarlyExit = BranchInst::Create (
542- VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
543- Builder.Insert (VectorEarlyExit);
544-
545- DTU.applyUpdates (
546- {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
547- {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
548-
549- // Increment the index counter and calculate the predicate for the next
550- // iteration of the loop. We branch back to the start of the loop if there
551- // is at least one active lane.
552- Builder.SetInsertPoint (VectorLoopIncBlock);
553- Value *NewVectorIndexPhi =
554- Builder.CreateAdd (VectorIndexPhi, VecLen, " " ,
555- /* HasNUW=*/ true , /* HasNSW=*/ true );
556- VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
557- Value *NewPred =
558- Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask,
559- {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
560- LoopPred->addIncoming (NewPred, VectorLoopIncBlock);
561-
562- Value *PredHasActiveLanes =
563- Builder.CreateExtractElement (NewPred, uint64_t (0 ));
564- BranchInst *VectorLoopBranchBack =
565- BranchInst::Create (VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
566- Builder.Insert (VectorLoopBranchBack);
567-
568- DTU.applyUpdates (
569- {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
570- {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
571-
572- // If we found a mismatch then we need to calculate which lane in the vector
573- // had a mismatch and add that on to the current loop index.
574- Builder.SetInsertPoint (VectorLoopMismatchBlock);
575- PHINode *FoundPred = Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_found_pred" );
576- FoundPred->addIncoming (VectorMatchCmp, VectorLoopStartBlock);
577- PHINode *LastLoopPred =
578- Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_last_loop_pred" );
579- LastLoopPred->addIncoming (LoopPred, VectorLoopStartBlock);
580- PHINode *VectorFoundIndex =
581- Builder.CreatePHI (I64Type, 1 , " mismatch_vec_found_index" );
582- VectorFoundIndex->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
583-
584- Value *PredMatchCmp = Builder.CreateAnd (LastLoopPred, FoundPred);
585- Value *Ctz = Builder.CreateIntrinsic (
586- Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType ()},
587- {PredMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true )});
588- Ctz = Builder.CreateZExt (Ctz, I64Type);
589- Value *VectorLoopRes64 = Builder.CreateAdd (VectorFoundIndex, Ctz, " " ,
590- /* HasNUW=*/ true , /* HasNSW=*/ true );
591- Value *VectorLoopRes = Builder.CreateTrunc (VectorLoopRes64, ResType);
616+ Value *VectorLoopRes =
617+ createMaskedFindMismatch (Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
592618
593619 Builder.Insert (BranchInst::Create (EndBlock));
594620
0 commit comments