@@ -60,19 +60,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
6060 cl::init (false ),
6161 cl::desc(" Disable Loop Idiom Vectorize Pass." ));
6262
// Hidden command-line option choosing how the loop idiom transform is
// vectorized: "masked" (masked vector intrinsics) or "predicated" (VP
// intrinsics). The option itself defaults to Masked. LoopIdiomVectorizePass::run
// only lets it override the pass's own VectorizeStyle when it was explicitly
// given on the command line (getNumOccurrences() != 0).
63+ static cl::opt<LoopIdiomVectorizeStyle>
64+ LITVecStyle (" loop-idiom-vectorize-style" , cl::Hidden,
65+ cl::desc (" The vectorization style for loop idiom transform." ),
66+ cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, " masked" ,
67+ " Use masked vector intrinsics" ),
68+ clEnumValN(LoopIdiomVectorizeStyle::Predicated,
69+ " predicated" , " Use VP intrinsics" )),
70+ cl::init(LoopIdiomVectorizeStyle::Masked));
71+
// Hidden boolean option, default false. Per its description the pass still
// runs, but byte-compare loops are not converted. The code that consults this
// flag is not visible in this excerpt.
6372static cl::opt<bool >
6473 DisableByteCmp (" disable-loop-idiom-vectorize-bytecmp" , cl::Hidden,
6574 cl::init (false ),
6675 cl::desc(" Proceed with Loop Idiom Vectorize Pass, but do "
6776 " not convert byte-compare loop(s)." ));
6877
// Hidden command-line option giving the vectorization factor for
// byte-compare patterns (default 16). LoopIdiomVectorizePass::run only lets it
// override the pass's own ByteCompareVF when it was explicitly given on the
// command line. The chosen value is used as the minimum element count of the
// scalable vector types built by the find-mismatch code generators.
// NOTE(review): no range check is visible here; a value of 0 would be passed
// straight to ScalableVectorType::get -- confirm that is rejected elsewhere.
78+ static cl::opt<unsigned >
79+ ByteCmpVF (" loop-idiom-vectorize-bytecmp-vf" , cl::Hidden,
80+ cl::desc (" The vectorization factor for byte-compare patterns." ),
81+ cl::init(16 ));
82+
// Hidden boolean option, default false. Per its description it enables
// verification of the loops this pass generates. The code that consults this
// flag is not visible in this excerpt.
6983static cl::opt<bool >
7084 VerifyLoops (" loop-idiom-vectorize-verify" , cl::Hidden, cl::init(false ),
7185 cl::desc(" Verify loops generated Loop Idiom Vectorize Pass." ));
7286
7387namespace {
74-
7588class LoopIdiomVectorize {
89+ LoopIdiomVectorizeStyle VectorizeStyle;
90+ unsigned ByteCompareVF;
7691 Loop *CurLoop = nullptr ;
7792 DominatorTree *DT;
7893 LoopInfo *LI;
@@ -87,10 +102,11 @@ class LoopIdiomVectorize {
87102 BasicBlock *VectorLoopIncBlock = nullptr ;
88103
89104public:
90- explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
91- const TargetTransformInfo *TTI,
92- const DataLayout *DL)
93- : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
105+ LoopIdiomVectorize (LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
106+ LoopInfo *LI, const TargetTransformInfo *TTI,
107+ const DataLayout *DL)
108+ : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
109+ }
94110
95111 bool run (Loop *L);
96112
@@ -111,6 +127,10 @@ class LoopIdiomVectorize {
111127 Value *createMaskedFindMismatch (IRBuilder<> &Builder, GetElementPtrInst *GEPA,
112128 GetElementPtrInst *GEPB, Value *ExtStart,
113129 Value *ExtEnd);
130+ Value *createPredicatedFindMismatch (IRBuilder<> &Builder,
131+ GetElementPtrInst *GEPA,
132+ GetElementPtrInst *GEPB, Value *ExtStart,
133+ Value *ExtEnd);
114134
115135 void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
116136 PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -128,8 +148,16 @@ PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
128148
129149 const auto *DL = &L.getHeader ()->getModule ()->getDataLayout ();
130150
131- LoopIdiomVectorize LIT (&AR.DT , &AR.LI , &AR.TTI , DL);
132- if (!LIT.run (&L))
151+ LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
152+ if (LITVecStyle.getNumOccurrences ())
153+ VecStyle = LITVecStyle;
154+
155+ unsigned BCVF = ByteCompareVF;
156+ if (ByteCmpVF.getNumOccurrences ())
157+ BCVF = ByteCmpVF;
158+
159+ LoopIdiomVectorize LIV (VecStyle, BCVF, &AR.DT , &AR.LI , &AR.TTI , DL);
160+ if (!LIV.run (&L))
133161 return PreservedAnalyses::all ();
134162
135163 return PreservedAnalyses::none ();
@@ -362,14 +390,15 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
362390 // Therefore, we know that we can use a 64-bit induction variable that
363391 // starts from 0 -> ExtMaxLen and it will not overflow.
364392 ScalableVectorType *PredVTy =
365- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
393+ ScalableVectorType::get (Builder.getInt1Ty (), ByteCompareVF );
366394
367395 Value *InitialPred = Builder.CreateIntrinsic (
368396 Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
369397
370398 Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
371- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
372- /* HasNUW=*/ true , /* HasNSW=*/ true );
399+ VecLen =
400+ Builder.CreateMul (VecLen, ConstantInt::get (I64Type, ByteCompareVF), " " ,
401+ /* HasNUW=*/ true , /* HasNSW=*/ true );
373402
374403 Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
375404 Builder.getInt1 (false ));
@@ -384,7 +413,8 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
384413 LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
385414 PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
386415 VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
387- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
416+ Type *VectorLoadType =
417+ ScalableVectorType::get (Builder.getInt8Ty (), ByteCompareVF);
388418 Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
389419
390420 Value *VectorLhsGep =
@@ -445,6 +475,112 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
445475 return Builder.CreateTrunc (VectorLoopRes64, ResType);
446476}
447477
478+ Value *LoopIdiomVectorize::createPredicatedFindMismatch (IRBuilder<> &Builder,
479+ GetElementPtrInst *GEPA,
480+ GetElementPtrInst *GEPB,
481+ Value *ExtStart,
482+ Value *ExtEnd) {
483+ Type *I64Type = Builder.getInt64Ty ();
484+ Type *I32Type = Builder.getInt32Ty ();
485+ Type *ResType = I32Type;
486+ Type *LoadType = Builder.getInt8Ty ();
487+ Value *PtrA = GEPA->getPointerOperand ();
488+ Value *PtrB = GEPB->getPointerOperand ();
489+
490+ // At this point we know two things must be true:
491+ // 1. Start <= End
492+ // 2. ExtMaxLen <= 4096 due to the page checks.
493+ // Therefore, we know that we can use a 64-bit induction variable that
494+ // starts from 0 -> ExtMaxLen and it will not overflow.
495+ auto *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
496+ Builder.Insert (JumpToVectorLoop);
497+
498+ // Set up the first Vector loop block by creating the PHIs, doing the vector
499+ // loads and comparing the vectors.
500+ Builder.SetInsertPoint (VectorLoopStartBlock);
501+ auto *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vector_index" );
502+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
503+
504+ // Calculate AVL by subtracting the vector loop index from the trip count
505+ Value *AVL = Builder.CreateSub (ExtEnd, VectorIndexPhi, " avl" , /* HasNUW=*/ true ,
506+ /* HasNSW=*/ true );
507+
508+ auto *VectorLoadType = ScalableVectorType::get (LoadType, ByteCompareVF);
509+ auto *VF = ConstantInt::get (
510+ I32Type, VectorLoadType->getElementCount ().getKnownMinValue ());
511+ auto *IsScalable = ConstantInt::getBool (
512+ Builder.getContext (), VectorLoadType->getElementCount ().isScalable ());
513+
514+ Value *VL = Builder.CreateIntrinsic (Intrinsic::experimental_get_vector_length,
515+ {I64Type}, {AVL, VF, IsScalable});
516+ Value *GepOffset = VectorIndexPhi;
517+
518+ Value *VectorLhsGep = Builder.CreateGEP (LoadType, PtrA, GepOffset);
519+ if (GEPA->isInBounds ())
520+ cast<GetElementPtrInst>(VectorLhsGep)->setIsInBounds (true );
521+ VectorType *TrueMaskTy =
522+ VectorType::get (Builder.getInt1Ty (), VectorLoadType->getElementCount ());
523+ Value *AllTrueMask = Constant::getAllOnesValue (TrueMaskTy);
524+ Value *VectorLhsLoad = Builder.CreateIntrinsic (
525+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
526+ {VectorLhsGep, AllTrueMask, VL}, nullptr , " lhs.load" );
527+
528+ Value *VectorRhsGep = Builder.CreateGEP (LoadType, PtrB, GepOffset);
529+ if (GEPB->isInBounds ())
530+ cast<GetElementPtrInst>(VectorRhsGep)->setIsInBounds (true );
531+ Value *VectorRhsLoad = Builder.CreateIntrinsic (
532+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
533+ {VectorRhsGep, AllTrueMask, VL}, nullptr , " rhs.load" );
534+
535+ StringRef PredicateStr = CmpInst::getPredicateName (CmpInst::ICMP_NE);
536+ auto *PredicateMDS = MDString::get (VectorLhsLoad->getContext (), PredicateStr);
537+ Value *Pred = MetadataAsValue::get (VectorLhsLoad->getContext (), PredicateMDS);
538+ Value *VectorMatchCmp = Builder.CreateIntrinsic (
539+ Intrinsic::vp_icmp, {VectorLhsLoad->getType ()},
540+ {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr ,
541+ " mismatch.cmp" );
542+ Value *CTZ = Builder.CreateIntrinsic (
543+ Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType ()},
544+ {VectorMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true ), AllTrueMask,
545+ VL});
546+ // RISC-V refines/lowers the poison returned by vp.cttz.elts to -1.
547+ Value *MismatchFound =
548+ Builder.CreateICmpSGE (CTZ, ConstantInt::get (ResType, 0 ));
549+ auto *VectorEarlyExit = BranchInst::Create (VectorLoopMismatchBlock,
550+ VectorLoopIncBlock, MismatchFound);
551+ Builder.Insert (VectorEarlyExit);
552+
553+ // Increment the index counter and calculate the predicate for the next
554+ // iteration of the loop. We branch back to the start of the loop if there
555+ // is at least one active lane.
556+ Builder.SetInsertPoint (VectorLoopIncBlock);
557+ Value *VL64 = Builder.CreateZExt (VL, I64Type);
558+ Value *NewVectorIndexPhi =
559+ Builder.CreateAdd (VectorIndexPhi, VL64, " " ,
560+ /* HasNUW=*/ true , /* HasNSW=*/ true );
561+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
562+ Value *ExitCond = Builder.CreateICmpNE (NewVectorIndexPhi, ExtEnd);
563+ auto *VectorLoopBranchBack =
564+ BranchInst::Create (VectorLoopStartBlock, EndBlock, ExitCond);
565+ Builder.Insert (VectorLoopBranchBack);
566+
567+ // If we found a mismatch then we need to calculate which lane in the vector
568+ // had a mismatch and add that on to the current loop index.
569+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
570+
571+ // Add LCSSA phis for CTZ and VectorIndexPhi.
572+ auto *CTZLCSSAPhi = Builder.CreatePHI (CTZ->getType (), 1 , " ctz" );
573+ CTZLCSSAPhi->addIncoming (CTZ, VectorLoopStartBlock);
574+ auto *VectorIndexLCSSAPhi =
575+ Builder.CreatePHI (VectorIndexPhi->getType (), 1 , " mismatch_vector_index" );
576+ VectorIndexLCSSAPhi->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
577+
578+ Value *CTZI64 = Builder.CreateZExt (CTZLCSSAPhi, I64Type);
579+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorIndexLCSSAPhi, CTZI64, " " ,
580+ /* HasNUW=*/ true , /* HasNSW=*/ true );
581+ return Builder.CreateTrunc (VectorLoopRes64, ResType);
582+ }
583+
448584Value *LoopIdiomVectorize::expandFindMismatch (
449585 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
450586 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -604,8 +740,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
604740 // processed in each iteration, etc.
605741 Builder.SetInsertPoint (VectorLoopPreheaderBlock);
606742
607- Value *VectorLoopRes =
608- createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
743+ Value *VectorLoopRes = nullptr ;
744+ switch (VectorizeStyle) {
745+ case LoopIdiomVectorizeStyle::Masked:
746+ VectorLoopRes =
747+ createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
748+ break ;
749+ case LoopIdiomVectorizeStyle::Predicated:
750+ VectorLoopRes =
751+ createPredicatedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
752+ break ;
753+ }
609754
610755 Builder.Insert (BranchInst::Create (EndBlock));
611756
0 commit comments