@@ -56,18 +56,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-transform-all", cl::Hidden,
5656 cl::init (false ),
5757 cl::desc(" Disable Loop Idiom Transform Pass." ));
5858
59+ static cl::opt<LoopIdiomTransformStyle>
60+ LITVecStyle (" loop-idiom-transform-style" , cl::Hidden,
61+ cl::desc (" The vectorization style for loop idiom transform." ),
62+ cl::values(clEnumValN(LoopIdiomTransformStyle::Masked, " masked" ,
63+ " Use masked vector intrinsics" ),
64+ clEnumValN(LoopIdiomTransformStyle::Predicated,
65+ " predicated" , " Use VP intrinsics" )),
66+ cl::init(LoopIdiomTransformStyle::Masked));
67+
5968static cl::opt<bool >
6069 DisableByteCmp (" disable-loop-idiom-transform-bytecmp" , cl::Hidden,
6170 cl::init (false ),
6271 cl::desc(" Proceed with Loop Idiom Transform Pass, but do "
6372 " not convert byte-compare loop(s)." ));
6473
74+ static cl::opt<unsigned >
75+ ByteCmpVF (" loop-idiom-transform-bytecmp-vf" , cl::Hidden,
76+ cl::desc (" The vectorization factor for byte-compare patterns." ),
77+ cl::init(16 ));
78+
6579static cl::opt<bool >
6680 VerifyLoops (" verify-loop-idiom-transform" , cl::Hidden, cl::init(false ),
6781 cl::desc(" Verify loops generated Loop Idiom Transform Pass." ));
6882
6983namespace {
7084class LoopIdiomTransform {
85+ LoopIdiomTransformStyle VectorizeStyle;
86+ unsigned ByteCompareVF;
7187 Loop *CurLoop = nullptr ;
7288 DominatorTree *DT;
7389 LoopInfo *LI;
@@ -82,10 +98,11 @@ class LoopIdiomTransform {
8298 BasicBlock *VectorLoopIncBlock = nullptr ;
8399
84100public:
85- explicit LoopIdiomTransform (DominatorTree *DT, LoopInfo *LI,
86- const TargetTransformInfo *TTI,
87- const DataLayout *DL)
88- : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
101+ LoopIdiomTransform (LoopIdiomTransformStyle S, unsigned VF, DominatorTree *DT,
102+ LoopInfo *LI, const TargetTransformInfo *TTI,
103+ const DataLayout *DL)
104+ : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
105+ }
89106
90107 bool run (Loop *L);
91108
@@ -106,6 +123,10 @@ class LoopIdiomTransform {
106123 Value *createMaskedFindMismatch (IRBuilder<> &Builder, GetElementPtrInst *GEPA,
107124 GetElementPtrInst *GEPB, Value *ExtStart,
108125 Value *ExtEnd);
126+ Value *createPredicatedFindMismatch (IRBuilder<> &Builder,
127+ GetElementPtrInst *GEPA,
128+ GetElementPtrInst *GEPB, Value *ExtStart,
129+ Value *ExtEnd);
109130
110131 void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
111132 PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -123,7 +144,15 @@ PreservedAnalyses LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
123144
124145 const auto *DL = &L.getHeader ()->getModule ()->getDataLayout ();
125146
126- LoopIdiomTransform LIT (&AR.DT , &AR.LI , &AR.TTI , DL);
147+ LoopIdiomTransformStyle VecStyle = VectorizeStyle;
148+ if (LITVecStyle.getNumOccurrences ())
149+ VecStyle = LITVecStyle;
150+
151+ unsigned BCVF = ByteCompareVF;
152+ if (ByteCmpVF.getNumOccurrences ())
153+ BCVF = ByteCmpVF;
154+
155+ LoopIdiomTransform LIT (VecStyle, BCVF, &AR.DT , &AR.LI , &AR.TTI , DL);
127156 if (!LIT.run (&L))
128157 return PreservedAnalyses::all ();
129158
@@ -357,14 +386,15 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
357386 // Therefore, we know that we can use a 64-bit induction variable that
358387 // starts from 0 -> ExtMaxLen and it will not overflow.
359388 ScalableVectorType *PredVTy =
360- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
389+ ScalableVectorType::get (Builder.getInt1Ty (), ByteCompareVF );
361390
362391 Value *InitialPred = Builder.CreateIntrinsic (
363392 Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
364393
365394 Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
366- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
367- /* HasNUW=*/ true , /* HasNSW=*/ true );
395+ VecLen =
396+ Builder.CreateMul (VecLen, ConstantInt::get (I64Type, ByteCompareVF), " " ,
397+ /* HasNUW=*/ true , /* HasNSW=*/ true );
368398
369399 Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
370400 Builder.getInt1 (false ));
@@ -379,7 +409,8 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
379409 LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
380410 PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
381411 VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
382- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
412+ Type *VectorLoadType =
413+ ScalableVectorType::get (Builder.getInt8Ty (), ByteCompareVF);
383414 Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
384415
385416 Value *VectorLhsGep = Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi);
@@ -442,6 +473,112 @@ Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
442473 return Builder.CreateTrunc (VectorLoopRes64, ResType);
443474}
444475
476+ Value *LoopIdiomTransform::createPredicatedFindMismatch (IRBuilder<> &Builder,
477+ GetElementPtrInst *GEPA,
478+ GetElementPtrInst *GEPB,
479+ Value *ExtStart,
480+ Value *ExtEnd) {
481+ Type *I64Type = Builder.getInt64Ty ();
482+ Type *I32Type = Builder.getInt32Ty ();
483+ Type *ResType = I32Type;
484+ Type *LoadType = Builder.getInt8Ty ();
485+ Value *PtrA = GEPA->getPointerOperand ();
486+ Value *PtrB = GEPB->getPointerOperand ();
487+
488+ // At this point we know two things must be true:
489+ // 1. Start <= End
490+ // 2. ExtMaxLen <= 4096 due to the page checks.
491+ // Therefore, we know that we can use a 64-bit induction variable that
492+ // starts from 0 -> ExtMaxLen and it will not overflow.
493+ auto *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
494+ Builder.Insert (JumpToVectorLoop);
495+
496+ // Set up the first Vector loop block by creating the PHIs, doing the vector
497+ // loads and comparing the vectors.
498+ Builder.SetInsertPoint (VectorLoopStartBlock);
499+ auto *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vector_index" );
500+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
501+
502+ // Calculate AVL by subtracting the vector loop index from the trip count
503+ Value *AVL = Builder.CreateSub (ExtEnd, VectorIndexPhi, " avl" , /* HasNUW=*/ true ,
504+ /* HasNSW=*/ true );
505+
506+ auto *VectorLoadType = ScalableVectorType::get (LoadType, ByteCompareVF);
507+ auto *VF = ConstantInt::get (
508+ I32Type, VectorLoadType->getElementCount ().getKnownMinValue ());
509+ auto *IsScalable = ConstantInt::getBool (
510+ Builder.getContext (), VectorLoadType->getElementCount ().isScalable ());
511+
512+ Value *VL = Builder.CreateIntrinsic (Intrinsic::experimental_get_vector_length,
513+ {I64Type}, {AVL, VF, IsScalable});
514+ Value *GepOffset = VectorIndexPhi;
515+
516+ Value *VectorLhsGep = Builder.CreateGEP (LoadType, PtrA, GepOffset);
517+ if (GEPA->isInBounds ())
518+ cast<GetElementPtrInst>(VectorLhsGep)->setIsInBounds (true );
519+ VectorType *TrueMaskTy =
520+ VectorType::get (Builder.getInt1Ty (), VectorLoadType->getElementCount ());
521+ Value *AllTrueMask = Constant::getAllOnesValue (TrueMaskTy);
522+ Value *VectorLhsLoad = Builder.CreateIntrinsic (
523+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
524+ {VectorLhsGep, AllTrueMask, VL}, nullptr , " lhs.load" );
525+
526+ Value *VectorRhsGep = Builder.CreateGEP (LoadType, PtrB, GepOffset);
527+ if (GEPB->isInBounds ())
528+ cast<GetElementPtrInst>(VectorRhsGep)->setIsInBounds (true );
529+ Value *VectorRhsLoad = Builder.CreateIntrinsic (
530+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
531+ {VectorRhsGep, AllTrueMask, VL}, nullptr , " rhs.load" );
532+
533+ StringRef PredicateStr = CmpInst::getPredicateName (CmpInst::ICMP_NE);
534+ auto *PredicateMDS = MDString::get (VectorLhsLoad->getContext (), PredicateStr);
535+ Value *Pred = MetadataAsValue::get (VectorLhsLoad->getContext (), PredicateMDS);
536+ Value *VectorMatchCmp = Builder.CreateIntrinsic (
537+ Intrinsic::vp_icmp, {VectorLhsLoad->getType ()},
538+ {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr ,
539+ " mismatch.cmp" );
540+ Value *CTZ = Builder.CreateIntrinsic (
541+ Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType ()},
542+ {VectorMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true ), AllTrueMask,
543+ VL});
544+ // RISC-V refines/lowers the poison returned by vp.cttz.elts to -1.
545+ Value *MismatchFound =
546+ Builder.CreateICmpSGE (CTZ, ConstantInt::get (ResType, 0 ));
547+ auto *VectorEarlyExit = BranchInst::Create (VectorLoopMismatchBlock,
548+ VectorLoopIncBlock, MismatchFound);
549+ Builder.Insert (VectorEarlyExit);
550+
551+ // Increment the index counter and calculate the predicate for the next
552+ // iteration of the loop. We branch back to the start of the loop if there
553+ // is at least one active lane.
554+ Builder.SetInsertPoint (VectorLoopIncBlock);
555+ Value *VL64 = Builder.CreateZExt (VL, I64Type);
556+ Value *NewVectorIndexPhi =
557+ Builder.CreateAdd (VectorIndexPhi, VL64, " " ,
558+ /* HasNUW=*/ true , /* HasNSW=*/ true );
559+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
560+ Value *ExitCond = Builder.CreateICmpNE (NewVectorIndexPhi, ExtEnd);
561+ auto *VectorLoopBranchBack =
562+ BranchInst::Create (VectorLoopStartBlock, EndBlock, ExitCond);
563+ Builder.Insert (VectorLoopBranchBack);
564+
565+ // If we found a mismatch then we need to calculate which lane in the vector
566+ // had a mismatch and add that on to the current loop index.
567+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
568+
569+ // Add LCSSA phis for CTZ and VectorIndexPhi.
570+ auto *CTZLCSSAPhi = Builder.CreatePHI (CTZ->getType (), 1 , " ctz" );
571+ CTZLCSSAPhi->addIncoming (CTZ, VectorLoopStartBlock);
572+ auto *VectorIndexLCSSAPhi =
573+ Builder.CreatePHI (VectorIndexPhi->getType (), 1 , " mismatch_vector_index" );
574+ VectorIndexLCSSAPhi->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
575+
576+ Value *CTZI64 = Builder.CreateZExt (CTZLCSSAPhi, I64Type);
577+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorIndexLCSSAPhi, CTZI64, " " ,
578+ /* HasNUW=*/ true , /* HasNSW=*/ true );
579+ return Builder.CreateTrunc (VectorLoopRes64, ResType);
580+ }
581+
445582Value *LoopIdiomTransform::expandFindMismatch (
446583 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
447584 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -593,8 +730,17 @@ Value *LoopIdiomTransform::expandFindMismatch(
593730 // processed in each iteration, etc.
594731 Builder.SetInsertPoint (VectorLoopPreheaderBlock);
595732
596- Value *VectorLoopRes =
597- createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
733+ Value *VectorLoopRes = nullptr ;
734+ switch (VectorizeStyle) {
735+ case LoopIdiomTransformStyle::Masked:
736+ VectorLoopRes =
737+ createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
738+ break ;
739+ case LoopIdiomTransformStyle::Predicated:
740+ VectorLoopRes =
741+ createPredicatedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
742+ break ;
743+ }
598744
599745 Builder.Insert (BranchInst::Create (EndBlock));
600746
0 commit comments