Skip to content

Commit 76d4a50

Browse files
committed
[AArch64] Return Invalid partial reduction cost for i128 accumulator.
PR #158641 introduced an issue where i128 accumulator types resulted in a valid cost, because for a <2 x i128> type the code that checks for unsupported type legalization would see a type action of 'TypeSplitVector' which is supported, even though the legalised type of <1 x i128> would require further scalarization.
1 parent 00099bf commit 76d4a50

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5666,18 +5666,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
56665666
VectorType *AccumVectorType =
56675667
VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
56685668
// We don't yet support all kinds of legalization.
5669-
auto TA = TLI->getTypeAction(AccumVectorType->getContext(),
5670-
EVT::getEVT(AccumVectorType));
5671-
switch (TA) {
5669+
auto TC = TLI->getTypeConversion(AccumVectorType->getContext(),
5670+
EVT::getEVT(AccumVectorType));
5671+
switch (TC.first) {
56725672
default:
56735673
return Invalid;
56745674
case TargetLowering::TypeLegal:
56755675
case TargetLowering::TypePromoteInteger:
56765676
case TargetLowering::TypeSplitVector:
5677+
// The legalised type (e.g. after splitting) must be legal too.
5678+
if (TLI->getTypeAction(AccumVectorType->getContext(), TC.second) !=
5679+
TargetLowering::TypeLegal)
5680+
return Invalid;
56775681
break;
56785682
}
56795683

5680-
// Check what kind of type-legalisation happens.
56815684
std::pair<InstructionCost, MVT> AccumLT =
56825685
getTypeLegalizationCost(AccumVectorType);
56835686
std::pair<InstructionCost, MVT> InputLT =

llvm/test/Transforms/LoopVectorize/AArch64/pr162009.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ define i128 @add_reduc_i32_i128_unsupported(ptr %a, ptr %b) "target-features"="+
1212
; CHECK-NO-PARTIAL-REDUCTION-NEXT: br label %[[VECTOR_BODY:.*]]
1313
; CHECK-NO-PARTIAL-REDUCTION: [[VECTOR_BODY]]:
1414
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
15-
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[VEC_PHI:%.*]] = phi <2 x i128> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], %[[VECTOR_BODY]] ]
15+
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[VEC_PHI:%.*]] = phi <4 x i128> [ zeroinitializer, %[[VECTOR_PH]] ], [ [[TMP7:%.*]], %[[VECTOR_BODY]] ]
1616
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP0:%.*]] = getelementptr i32, ptr [[A]], i64 [[INDEX]]
1717
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP0]], align 1
1818
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP1:%.*]] = zext <4 x i32> [[WIDE_LOAD]] to <4 x i64>
@@ -21,18 +21,18 @@ define i128 @add_reduc_i32_i128_unsupported(ptr %a, ptr %b) "target-features"="+
2121
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP3:%.*]] = zext <4 x i32> [[WIDE_LOAD1]] to <4 x i64>
2222
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP4:%.*]] = mul nuw <4 x i64> [[TMP1]], [[TMP3]]
2323
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP5:%.*]] = zext <4 x i64> [[TMP4]] to <4 x i128>
24-
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[PARTIAL_REDUCE]] = call <2 x i128> @llvm.vector.partial.reduce.add.v2i128.v4i128(<2 x i128> [[VEC_PHI]], <4 x i128> [[TMP5]])
24+
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP7]] = add <4 x i128> [[VEC_PHI]], [[TMP5]]
2525
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
2626
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 4024
2727
; CHECK-NO-PARTIAL-REDUCTION-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
2828
; CHECK-NO-PARTIAL-REDUCTION: [[MIDDLE_BLOCK]]:
29-
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP7:%.*]] = call i128 @llvm.vector.reduce.add.v2i128(<2 x i128> [[PARTIAL_REDUCE]])
29+
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[TMP8:%.*]] = call i128 @llvm.vector.reduce.add.v4i128(<4 x i128> [[TMP7]])
3030
; CHECK-NO-PARTIAL-REDUCTION-NEXT: br label %[[SCALAR_PH:.*]]
3131
; CHECK-NO-PARTIAL-REDUCTION: [[SCALAR_PH]]:
3232
; CHECK-NO-PARTIAL-REDUCTION-NEXT: br label %[[FOR_BODY:.*]]
3333
; CHECK-NO-PARTIAL-REDUCTION: [[FOR_BODY]]:
3434
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[IV:%.*]] = phi i64 [ 4024, %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[FOR_BODY]] ]
35-
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[ACCUM:%.*]] = phi i128 [ [[TMP7]], %[[SCALAR_PH]] ], [ [[ADD:%.*]], %[[FOR_BODY]] ]
35+
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[ACCUM:%.*]] = phi i128 [ [[TMP8]], %[[SCALAR_PH]] ], [ [[ADD:%.*]], %[[FOR_BODY]] ]
3636
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[GEP_A:%.*]] = getelementptr i32, ptr [[A]], i64 [[IV]]
3737
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[LOAD_A:%.*]] = load i32, ptr [[GEP_A]], align 1
3838
; CHECK-NO-PARTIAL-REDUCTION-NEXT: [[EXT_A:%.*]] = zext i32 [[LOAD_A]] to i64

0 commit comments

Comments
 (0)