@@ -67,6 +67,9 @@ FunctionPass *llvm::createX86FixupVectorConstants() {
6767static std::optional<APInt> extractConstantBits (const Constant *C) {
6868 unsigned NumBits = C->getType ()->getPrimitiveSizeInBits ();
6969
70+ if (auto *CUndef = dyn_cast<UndefValue>(C))
71+ return APInt::getZero (NumBits);
72+
7073 if (auto *CInt = dyn_cast<ConstantInt>(C))
7174 return CInt->getValue ();
7275
@@ -80,6 +83,18 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
8083 return APInt::getSplat (NumBits, *Bits);
8184 }
8285 }
86+
87+ APInt Bits = APInt::getZero (NumBits);
88+ for (unsigned I = 0 , E = CV->getNumOperands (); I != E; ++I) {
89+ Constant *Elt = CV->getOperand (I);
90+ std::optional<APInt> SubBits = extractConstantBits (Elt);
91+ if (!SubBits)
92+ return std::nullopt ;
93+ assert (NumBits == (E * SubBits->getBitWidth ()) &&
94+ " Illegal vector element size" );
95+ Bits.insertBits (*SubBits, I * SubBits->getBitWidth ());
96+ }
97+ return Bits;
8398 }
8499
85100 if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
@@ -223,6 +238,35 @@ static Constant *rebuildSplatableConstant(const Constant *C,
223238 return rebuildConstant (OriginalType->getContext (), SclTy, *Splat, NumSclBits);
224239}
225240
241+ static Constant *rebuildZeroUpperConstant (const Constant *C,
242+ unsigned ScalarBitWidth) {
243+ Type *Ty = C->getType ();
244+ Type *SclTy = Ty->getScalarType ();
245+ unsigned NumBits = Ty->getPrimitiveSizeInBits ();
246+ unsigned NumSclBits = SclTy->getPrimitiveSizeInBits ();
247+ LLVMContext &Ctx = C->getContext ();
248+
249+ if (NumBits > ScalarBitWidth) {
250+ // Determine if the upper bits are all zero.
251+ if (std::optional<APInt> Bits = extractConstantBits (C)) {
252+ if (Bits->countLeadingZeros () >= (NumBits - ScalarBitWidth)) {
253+ // If the original constant was made of smaller elements, try to retain
254+ // those types.
255+ if (ScalarBitWidth > NumSclBits && (ScalarBitWidth % NumSclBits) == 0 )
256+ return rebuildConstant (Ctx, SclTy, *Bits, NumSclBits);
257+
258+ // Fallback to raw integer bits.
259+ APInt RawBits = Bits->zextOrTrunc (ScalarBitWidth);
260+ return ConstantInt::get (Ctx, RawBits);
261+ }
262+ }
263+ }
264+
265+ return nullptr ;
266+ }
267+
268+ typedef std::function<Constant *(const Constant *, unsigned )> RebuildFn;
269+
226270bool X86FixupVectorConstantsPass::processInstruction (MachineFunction &MF,
227271 MachineBasicBlock &MBB,
228272 MachineInstr &MI) {
@@ -233,117 +277,128 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
233277 bool HasBWI = ST->hasBWI ();
234278 bool HasVLX = ST->hasVLX ();
235279
236- auto ConvertToBroadcast = [&](unsigned OpBcst256, unsigned OpBcst128,
237- unsigned OpBcst64, unsigned OpBcst32,
238- unsigned OpBcst16, unsigned OpBcst8,
239- unsigned OperandNo) {
240- assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
241- " Unexpected number of operands!" );
242-
243- if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
244- // Attempt to detect a suitable splat from increasing splat widths.
245- std::pair<unsigned , unsigned > Broadcasts[] = {
246- {8 , OpBcst8}, {16 , OpBcst16}, {32 , OpBcst32},
247- {64 , OpBcst64}, {128 , OpBcst128}, {256 , OpBcst256},
248- };
249- for (auto [BitWidth, OpBcst] : Broadcasts) {
250- if (OpBcst) {
251- // Construct a suitable splat constant and adjust the MI to
252- // use the new constant pool entry.
253- if (Constant *NewCst = rebuildSplatableConstant (C, BitWidth)) {
254- unsigned NewCPI =
255- CP->getConstantPoolIndex (NewCst, Align (BitWidth / 8 ));
256- MI.setDesc (TII->get (OpBcst));
257- MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
258- return true ;
280+ auto FixupConstant =
281+ [&](unsigned OpBcst256, unsigned OpBcst128, unsigned OpBcst64,
282+ unsigned OpBcst32, unsigned OpBcst16, unsigned OpBcst8,
283+ unsigned OpUpper64, unsigned OpUpper32, unsigned OperandNo) {
284+ assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
285+ " Unexpected number of operands!" );
286+
287+ if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
288+ // Attempt to detect a suitable splat/vzload from increasing constant
289+ // bitwidths.
290+ // Prefer vzload vs broadcast for same bitwidth to avoid domain flips.
291+ std::tuple<unsigned , unsigned , RebuildFn> FixupLoad[] = {
292+ {8 , OpBcst8, rebuildSplatableConstant},
293+ {16 , OpBcst16, rebuildSplatableConstant},
294+ {32 , OpUpper32, rebuildZeroUpperConstant},
295+ {32 , OpBcst32, rebuildSplatableConstant},
296+ {64 , OpUpper64, rebuildZeroUpperConstant},
297+ {64 , OpBcst64, rebuildSplatableConstant},
298+ {128 , OpBcst128, rebuildSplatableConstant},
299+ {256 , OpBcst256, rebuildSplatableConstant},
300+ };
301+ for (auto [BitWidth, Op, RebuildConstant] : FixupLoad) {
302+ if (Op) {
303+ // Construct a suitable constant and adjust the MI to use the new
304+ // constant pool entry.
305+ if (Constant *NewCst = RebuildConstant (C, BitWidth)) {
306+ unsigned NewCPI =
307+ CP->getConstantPoolIndex (NewCst, Align (BitWidth / 8 ));
308+ MI.setDesc (TII->get (Op));
309+ MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
310+ return true ;
311+ }
312+ }
259313 }
260314 }
261- }
262- }
263- return false ;
264- };
315+ return false ;
316+ };
265317
266- // Attempt to convert full width vector loads into broadcast loads.
318+ // Attempt to convert full width vector loads into broadcast/vzload loads.
267319 switch (Opc) {
268320 /* FP Loads */
269321 case X86::MOVAPDrm:
270322 case X86::MOVAPSrm:
271323 case X86::MOVUPDrm:
272324 case X86::MOVUPSrm:
273325 // TODO: SSE3 MOVDDUP Handling
274- return false ;
326+ return FixupConstant ( 0 , 0 , 0 , 0 , 0 , 0 , X86::MOVSDrm, X86::MOVSSrm, 1 ) ;
275327 case X86::VMOVAPDrm:
276328 case X86::VMOVAPSrm:
277329 case X86::VMOVUPDrm:
278330 case X86::VMOVUPSrm:
279- return ConvertToBroadcast (0 , 0 , X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0 , 0 ,
280- 1 );
331+ return FixupConstant (0 , 0 , X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0 , 0 ,
332+ X86::VMOVSDrm, X86::VMOVSSrm, 1 );
281333 case X86::VMOVAPDYrm:
282334 case X86::VMOVAPSYrm:
283335 case X86::VMOVUPDYrm:
284336 case X86::VMOVUPSYrm:
285- return ConvertToBroadcast (0 , X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
286- X86::VBROADCASTSSYrm, 0 , 0 , 1 );
337+ return FixupConstant (0 , X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
338+ X86::VBROADCASTSSYrm, 0 , 0 , 0 , 0 , 1 );
287339 case X86::VMOVAPDZ128rm:
288340 case X86::VMOVAPSZ128rm:
289341 case X86::VMOVUPDZ128rm:
290342 case X86::VMOVUPSZ128rm:
291- return ConvertToBroadcast (0 , 0 , X86::VMOVDDUPZ128rm,
292- X86::VBROADCASTSSZ128rm, 0 , 0 , 1 );
343+ return FixupConstant (0 , 0 , X86::VMOVDDUPZ128rm, X86::VBROADCASTSSZ128rm, 0 ,
344+ 0 , X86::VMOVSDZrm, X86::VMOVSSZrm , 1 );
293345 case X86::VMOVAPDZ256rm:
294346 case X86::VMOVAPSZ256rm:
295347 case X86::VMOVUPDZ256rm:
296348 case X86::VMOVUPSZ256rm:
297- return ConvertToBroadcast (0 , X86::VBROADCASTF32X4Z256rm,
298- X86::VBROADCASTSDZ256rm, X86::VBROADCASTSSZ256rm,
299- 0 , 0 , 1 );
349+ return FixupConstant (0 , X86::VBROADCASTF32X4Z256rm, X86::VBROADCASTSDZ256rm,
350+ X86::VBROADCASTSSZ256rm, 0 , 0 , 0 , 0 , 1 );
300351 case X86::VMOVAPDZrm:
301352 case X86::VMOVAPSZrm:
302353 case X86::VMOVUPDZrm:
303354 case X86::VMOVUPSZrm:
304- return ConvertToBroadcast (X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
305- X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0 , 0 ,
306- 1 );
355+ return FixupConstant (X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
356+ X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0 , 0 , 0 , 0 ,
357+ 1 );
307358 /* Integer Loads */
359+ case X86::MOVDQArm:
360+ case X86::MOVDQUrm:
361+ return FixupConstant (0 , 0 , 0 , 0 , 0 , 0 , X86::MOVQI2PQIrm, X86::MOVDI2PDIrm,
362+ 1 );
308363 case X86::VMOVDQArm:
309364 case X86::VMOVDQUrm:
310- return ConvertToBroadcast (
311- 0 , 0 , HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm ,
312- HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm ,
313- HasAVX2 ? X86::VPBROADCASTWrm : 0 , HasAVX2 ? X86::VPBROADCASTBrm : 0 ,
314- 1 );
365+ return FixupConstant ( 0 , 0 , HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
366+ HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm ,
367+ HasAVX2 ? X86::VPBROADCASTWrm : 0 ,
368+ HasAVX2 ? X86::VPBROADCASTBrm : 0 , X86::VMOVQI2PQIrm ,
369+ X86::VMOVDI2PDIrm, 1 );
315370 case X86::VMOVDQAYrm:
316371 case X86::VMOVDQUYrm:
317- return ConvertToBroadcast (
372+ return FixupConstant (
318373 0 , HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
319374 HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
320375 HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
321376 HasAVX2 ? X86::VPBROADCASTWYrm : 0 , HasAVX2 ? X86::VPBROADCASTBYrm : 0 ,
322- 1 );
377+ 0 , 0 , 1 );
323378 case X86::VMOVDQA32Z128rm:
324379 case X86::VMOVDQA64Z128rm:
325380 case X86::VMOVDQU32Z128rm:
326381 case X86::VMOVDQU64Z128rm:
327- return ConvertToBroadcast (0 , 0 , X86::VPBROADCASTQZ128rm,
328- X86::VPBROADCASTDZ128rm ,
329- HasBWI ? X86::VPBROADCASTWZ128rm : 0 ,
330- HasBWI ? X86::VPBROADCASTBZ128rm : 0 , 1 );
382+ return FixupConstant (0 , 0 , X86::VPBROADCASTQZ128rm, X86::VPBROADCASTDZ128rm ,
383+ HasBWI ? X86::VPBROADCASTWZ128rm : 0 ,
384+ HasBWI ? X86::VPBROADCASTBZ128rm : 0 ,
385+ X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm , 1 );
331386 case X86::VMOVDQA32Z256rm:
332387 case X86::VMOVDQA64Z256rm:
333388 case X86::VMOVDQU32Z256rm:
334389 case X86::VMOVDQU64Z256rm:
335- return ConvertToBroadcast (0 , X86::VBROADCASTI32X4Z256rm,
336- X86::VPBROADCASTQZ256rm, X86::VPBROADCASTDZ256rm,
337- HasBWI ? X86::VPBROADCASTWZ256rm : 0 ,
338- HasBWI ? X86::VPBROADCASTBZ256rm : 0 , 1 );
390+ return FixupConstant (0 , X86::VBROADCASTI32X4Z256rm, X86::VPBROADCASTQZ256rm ,
391+ X86::VPBROADCASTDZ256rm,
392+ HasBWI ? X86::VPBROADCASTWZ256rm : 0 ,
393+ HasBWI ? X86::VPBROADCASTBZ256rm : 0 , 0 , 0 , 1 );
339394 case X86::VMOVDQA32Zrm:
340395 case X86::VMOVDQA64Zrm:
341396 case X86::VMOVDQU32Zrm:
342397 case X86::VMOVDQU64Zrm:
343- return ConvertToBroadcast (X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
344- X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
345- HasBWI ? X86::VPBROADCASTWZrm : 0 ,
346- HasBWI ? X86::VPBROADCASTBZrm : 0 , 1 );
398+ return FixupConstant (X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
399+ X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
400+ HasBWI ? X86::VPBROADCASTWZrm : 0 ,
401+ HasBWI ? X86::VPBROADCASTBZrm : 0 , 0 , 0 , 1 );
347402 }
348403
349404 auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
@@ -368,7 +423,7 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
368423
369424 if (OpBcst32 || OpBcst64) {
370425 unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
371- return ConvertToBroadcast (0 , 0 , OpBcst64, OpBcst32, 0 , 0 , OpNo);
426+ return FixupConstant (0 , 0 , OpBcst64, OpBcst32, 0 , 0 , 0 , 0 , OpNo);
372427 }
373428 return false ;
374429 };
0 commit comments