From 7d2123ebd93237962945597b17b2aa7ea7c59d35 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Sat, 30 Sep 2023 21:02:14 +0800 Subject: [PATCH 1/3] [SimplifyCFG] Improve range reducing for switches --- llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 74 +++++++++- .../Transforms/SimplifyCFG/rangereduce.ll | 139 ++++++++++++++++++ 2 files changed, 210 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 998677af3411e..ef95cdfe60ebc 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -82,6 +82,7 @@ #include #include #include +#include #include #include #include @@ -7158,6 +7159,71 @@ static bool switchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, return true; } +/// Try to reduce the range of cases with an unreachable default. +static bool +ReduceSwitchRangeWithUnreachableDefault(SwitchInst *SI, + const SmallVectorImpl &Values, + uint64_t Base, IRBuilder<> &Builder) { + bool HasDefault = + !isa(SI->getDefaultDest()->getFirstNonPHIOrDbg()); + if (HasDefault) + return false; + + // Try reducing the range to (idx + offset) & mask + // Mask out common high bits + uint64_t CommonOnes = std::numeric_limits::max(); + uint64_t CommonZeros = std::numeric_limits::max(); + for (auto &V : Values) { + CommonOnes &= (uint64_t)V; + CommonZeros &= ~(uint64_t)V; + } + uint64_t CommonBits = countl_one(CommonOnes | CommonZeros); + unsigned LowBits = 64 - CommonBits; + uint64_t Mask = (1ULL << LowBits) - 1; + if (Mask == std::numeric_limits::max()) + return false; + // Now we have some case values in the additive group Z/(2**k)Z. + // Find the largest hole in the group and move it to back. + uint64_t MaxHole = 0; + uint64_t BestOffset = 0; + for (unsigned I = 0; I < Values.size(); ++I) { + uint64_t Hole = ((uint64_t)Values[I] - + (uint64_t)(I == 0 ? 
Values.back() : Values[I - 1])) & + Mask; + if (Hole > MaxHole) { + MaxHole = Hole; + BestOffset = Mask - (uint64_t)Values[I] + 1; + } + } + + SmallVector NewValues; + for (auto &V : Values) + NewValues.push_back( + (((int64_t)(((uint64_t)V + BestOffset) & Mask)) << CommonBits) >> + CommonBits); + + llvm::sort(NewValues); + if (!isSwitchDense(NewValues)) + // Transform didn't create a dense switch. + return false; + + auto *Ty = cast(SI->getCondition()->getType()); + APInt Offset(Ty->getBitWidth(), BestOffset - Base); + auto *Index = Builder.CreateAnd( + Builder.CreateAdd(SI->getCondition(), ConstantInt::get(Ty, Offset)), + Mask); + SI->replaceUsesOfWith(SI->getCondition(), Index); + + for (auto Case : SI->cases()) { + auto *Orig = Case.getCaseValue(); + auto CaseVal = + (Orig->getValue() + Offset).trunc(LowBits).sext(Ty->getBitWidth()); + Case.setValue(cast(ConstantInt::get(Ty, CaseVal))); + } + + return true; +} + /// Try to transform a switch that has "holes" in it to a contiguous sequence /// of cases. /// @@ -7173,9 +7239,8 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, if (CondTy->getIntegerBitWidth() > 64 || !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) return false; - // Only bother with this optimization if there are more than 3 switch cases; - // SDAG will only bother creating jump tables for 4 or more cases. - if (SI->getNumCases() < 4) + // Ignore switches with fewer than three cases. + if (SI->getNumCases() < 3) return false; // This transform is agnostic to the signedness of the input or case values. We @@ -7196,6 +7261,9 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, for (auto &V : Values) V -= (uint64_t)(Base); + if (ReduceSwitchRangeWithUnreachableDefault(SI, Values, Base, Builder)) + return true; + // Now we have signed numbers that have been shifted so that, given enough // precision, there are no negative values. 
Since the rest of the transform // is bitwise only, we switch now to an unsigned representation. diff --git a/llvm/test/Transforms/SimplifyCFG/rangereduce.ll b/llvm/test/Transforms/SimplifyCFG/rangereduce.ll index 467ede9b75c33..94ecd6101f741 100644 --- a/llvm/test/Transforms/SimplifyCFG/rangereduce.ll +++ b/llvm/test/Transforms/SimplifyCFG/rangereduce.ll @@ -305,3 +305,142 @@ three: ret i32 99783 } +define i8 @pr67842(i32 %0) { +; CHECK-LABEL: @pr67842( +; CHECK-NEXT: start: +; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[TMP0:%.*]], 1 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 255 +; CHECK-NEXT: [[SWITCH_IDX_CAST:%.*]] = trunc i32 [[TMP2]] to i8 +; CHECK-NEXT: [[SWITCH_OFFSET:%.*]] = add nsw i8 [[SWITCH_IDX_CAST]], -1 +; CHECK-NEXT: ret i8 [[SWITCH_OFFSET]] +; +start: + switch i32 %0, label %bb2 [ + i32 0, label %bb5 + i32 1, label %bb4 + i32 255, label %bb1 + ] + +bb2: ; preds = %start + unreachable + +bb4: ; preds = %start + br label %bb5 + +bb1: ; preds = %start + br label %bb5 + +bb5: ; preds = %start, %bb1, %bb4 + %.0 = phi i8 [ -1, %bb1 ], [ 1, %bb4 ], [ 0, %start ] + ret i8 %.0 +} + +define i8 @reduce_masked_common_high_bits(i32 %0) { +; CHECK-LABEL: @reduce_masked_common_high_bits( +; CHECK-NEXT: start: +; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[TMP0:%.*]], -127 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 127 +; CHECK-NEXT: [[SWITCH_IDX_CAST:%.*]] = trunc i32 [[TMP2]] to i8 +; CHECK-NEXT: [[SWITCH_OFFSET:%.*]] = add nsw i8 [[SWITCH_IDX_CAST]], -1 +; CHECK-NEXT: ret i8 [[SWITCH_OFFSET]] +; +start: + switch i32 %0, label %bb2 [ + i32 128, label %bb5 + i32 129, label %bb4 + i32 255, label %bb1 + ] + +bb2: ; preds = %start + unreachable + +bb4: ; preds = %start + br label %bb5 + +bb1: ; preds = %start + br label %bb5 + +bb5: ; preds = %start, %bb1, %bb4 + %.0 = phi i8 [ -1, %bb1 ], [ 1, %bb4 ], [ 0, %start ] + ret i8 %.0 +} + +define i8 @reduce_masked_common_high_bits_fail(i32 %0) { +; CHECK-LABEL: @reduce_masked_common_high_bits_fail( +; CHECK-NEXT: start: +; 
CHECK-NEXT: switch i32 [[TMP0:%.*]], label [[BB2:%.*]] [ +; CHECK-NEXT: i32 128, label [[BB5:%.*]] +; CHECK-NEXT: i32 129, label [[BB4:%.*]] +; CHECK-NEXT: i32 511, label [[BB1:%.*]] +; CHECK-NEXT: ] +; CHECK: bb2: +; CHECK-NEXT: unreachable +; CHECK: bb4: +; CHECK-NEXT: br label [[BB5]] +; CHECK: bb1: +; CHECK-NEXT: br label [[BB5]] +; CHECK: bb5: +; CHECK-NEXT: [[DOT0:%.*]] = phi i8 [ -1, [[BB1]] ], [ 1, [[BB4]] ], [ 0, [[START:%.*]] ] +; CHECK-NEXT: ret i8 [[DOT0]] +; +start: + switch i32 %0, label %bb2 [ + i32 128, label %bb5 + i32 129, label %bb4 + i32 511, label %bb1 + ] + +bb2: ; preds = %start + unreachable + +bb4: ; preds = %start + br label %bb5 + +bb1: ; preds = %start + br label %bb5 + +bb5: ; preds = %start, %bb1, %bb4 + %.0 = phi i8 [ -1, %bb1 ], [ 1, %bb4 ], [ 0, %start ] + ret i8 %.0 +} + +; Optimization shouldn't trigger; The default block is reachable. +define i8 @reduce_masked_default_reachable(i32 %0) { +; CHECK-LABEL: @reduce_masked_default_reachable( +; CHECK-NEXT: start: +; CHECK-NEXT: switch i32 [[TMP0:%.*]], label [[COMMON_RET:%.*]] [ +; CHECK-NEXT: i32 0, label [[BB5:%.*]] +; CHECK-NEXT: i32 1, label [[BB4:%.*]] +; CHECK-NEXT: i32 255, label [[BB1:%.*]] +; CHECK-NEXT: ] +; CHECK: common.ret: +; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i8 [ [[DOT0:%.*]], [[BB5]] ], [ 24, [[START:%.*]] ] +; CHECK-NEXT: ret i8 [[COMMON_RET_OP]] +; CHECK: bb4: +; CHECK-NEXT: br label [[BB5]] +; CHECK: bb1: +; CHECK-NEXT: br label [[BB5]] +; CHECK: bb5: +; CHECK-NEXT: [[DOT0]] = phi i8 [ -1, [[BB1]] ], [ 1, [[BB4]] ], [ 0, [[START]] ] +; CHECK-NEXT: br label [[COMMON_RET]] +; +start: + switch i32 %0, label %bb2 [ + i32 0, label %bb5 + i32 1, label %bb4 + i32 255, label %bb1 + ] + +bb2: ; preds = %start + ret i8 24 + +bb4: ; preds = %start + br label %bb5 + +bb1: ; preds = %start + br label %bb5 + +bb5: ; preds = %start, %bb1, %bb4 + %.0 = phi i8 [ -1, %bb1 ], [ 1, %bb4 ], [ 0, %start ] + ret i8 %.0 +} From d766aff1e06599d9cf5f773cd3e76305e01bf2b2 Mon Sep 17 
00:00:00 2001 From: Yingwei Zheng Date: Thu, 3 Apr 2025 11:31:35 +0800 Subject: [PATCH 2/3] Rebase --- llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 10 ++++--- .../RISCV/switch-of-powers-of-two.ll | 26 ++++++------------- .../Transforms/SimplifyCFG/rangereduce.ll | 19 +++++++------- 3 files changed, 23 insertions(+), 32 deletions(-) diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index ef95cdfe60ebc..b06cae846c4b6 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -7161,7 +7161,7 @@ static bool switchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, /// Try to reduce the range of cases with an unreachable default. static bool -ReduceSwitchRangeWithUnreachableDefault(SwitchInst *SI, +reduceSwitchRangeWithUnreachableDefault(SwitchInst *SI, const SmallVectorImpl &Values, uint64_t Base, IRBuilder<> &Builder) { bool HasDefault = @@ -7203,12 +7203,14 @@ ReduceSwitchRangeWithUnreachableDefault(SwitchInst *SI, CommonBits); llvm::sort(NewValues); - if (!isSwitchDense(NewValues)) + if (!isSwitchDense(NewValues)) { // Transform didn't create a dense switch. 
return false; + } auto *Ty = cast(SI->getCondition()->getType()); - APInt Offset(Ty->getBitWidth(), BestOffset - Base); + APInt Offset(Ty->getBitWidth(), BestOffset - Base, /*isSigned=*/true, + /*implicitTrunc=*/true); auto *Index = Builder.CreateAnd( Builder.CreateAdd(SI->getCondition(), ConstantInt::get(Ty, Offset)), Mask); @@ -7261,7 +7263,7 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, for (auto &V : Values) V -= (uint64_t)(Base); - if (ReduceSwitchRangeWithUnreachableDefault(SI, Values, Base, Builder)) + if (reduceSwitchRangeWithUnreachableDefault(SI, Values, Base, Builder)) return true; // Now we have signed numbers that have been shifted so that, given enough diff --git a/llvm/test/Transforms/SimplifyCFG/RISCV/switch-of-powers-of-two.ll b/llvm/test/Transforms/SimplifyCFG/RISCV/switch-of-powers-of-two.ll index 2ac94afd95910..7b3280a588a80 100644 --- a/llvm/test/Transforms/SimplifyCFG/RISCV/switch-of-powers-of-two.ll +++ b/llvm/test/Transforms/SimplifyCFG/RISCV/switch-of-powers-of-two.ll @@ -157,22 +157,10 @@ return: define i32 @unable_to_create_dense_switch(i32 %x) { ; CHECK-LABEL: @unable_to_create_dense_switch( ; CHECK-NEXT: entry: -; CHECK-NEXT: switch i32 [[X:%.*]], label [[DEFAULT_CASE:%.*]] [ -; CHECK-NEXT: i32 1, label [[RETURN:%.*]] -; CHECK-NEXT: i32 2, label [[BB3:%.*]] -; CHECK-NEXT: i32 4, label [[BB4:%.*]] -; CHECK-NEXT: i32 4096, label [[BB5:%.*]] -; CHECK-NEXT: ] -; CHECK: default_case: -; CHECK-NEXT: unreachable -; CHECK: bb3: -; CHECK-NEXT: br label [[RETURN]] -; CHECK: bb4: -; CHECK-NEXT: br label [[RETURN]] -; CHECK: bb5: -; CHECK-NEXT: br label [[RETURN]] -; CHECK: return: -; CHECK-NEXT: [[P:%.*]] = phi i32 [ 1, [[BB3]] ], [ 0, [[BB4]] ], [ 42, [[BB5]] ], [ 2, [[ENTRY:%.*]] ] +; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[X:%.*]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 4095 +; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [5 x i32], ptr @switch.table.unable_to_create_dense_switch, i32 0, i32 [[TMP1]] 
+; CHECK-NEXT: [[P:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4 ; CHECK-NEXT: ret i32 [[P]] ; entry: @@ -200,11 +188,13 @@ declare i32 @bar(i32) define i32 @unable_to_generate_lookup_table(i32 %x, i32 %y) { ; RV64I-LABEL: @unable_to_generate_lookup_table( ; RV64I-NEXT: entry: -; RV64I-NEXT: switch i32 [[Y:%.*]], label [[DEFAULT_CASE:%.*]] [ +; RV64I-NEXT: [[TMP0:%.*]] = add i32 [[Y1:%.*]], 0 +; RV64I-NEXT: [[Y:%.*]] = and i32 [[TMP0]], 63 +; RV64I-NEXT: switch i32 [[Y]], label [[DEFAULT_CASE:%.*]] [ ; RV64I-NEXT: i32 1, label [[BB2:%.*]] ; RV64I-NEXT: i32 2, label [[BB3:%.*]] ; RV64I-NEXT: i32 8, label [[BB4:%.*]] -; RV64I-NEXT: i32 64, label [[BB5:%.*]] +; RV64I-NEXT: i32 0, label [[BB5:%.*]] ; RV64I-NEXT: ] ; RV64I: default_case: ; RV64I-NEXT: unreachable diff --git a/llvm/test/Transforms/SimplifyCFG/rangereduce.ll b/llvm/test/Transforms/SimplifyCFG/rangereduce.ll index 94ecd6101f741..4e4bd4416d191 100644 --- a/llvm/test/Transforms/SimplifyCFG/rangereduce.ll +++ b/llvm/test/Transforms/SimplifyCFG/rangereduce.ll @@ -369,9 +369,9 @@ define i8 @reduce_masked_common_high_bits_fail(i32 %0) { ; CHECK-LABEL: @reduce_masked_common_high_bits_fail( ; CHECK-NEXT: start: ; CHECK-NEXT: switch i32 [[TMP0:%.*]], label [[BB2:%.*]] [ -; CHECK-NEXT: i32 128, label [[BB5:%.*]] -; CHECK-NEXT: i32 129, label [[BB4:%.*]] -; CHECK-NEXT: i32 511, label [[BB1:%.*]] +; CHECK-NEXT: i32 128, label [[BB5:%.*]] +; CHECK-NEXT: i32 129, label [[BB4:%.*]] +; CHECK-NEXT: i32 511, label [[BB1:%.*]] ; CHECK-NEXT: ] ; CHECK: bb2: ; CHECK-NEXT: unreachable @@ -409,19 +409,18 @@ define i8 @reduce_masked_default_reachable(i32 %0) { ; CHECK-LABEL: @reduce_masked_default_reachable( ; CHECK-NEXT: start: ; CHECK-NEXT: switch i32 [[TMP0:%.*]], label [[COMMON_RET:%.*]] [ -; CHECK-NEXT: i32 0, label [[BB5:%.*]] -; CHECK-NEXT: i32 1, label [[BB4:%.*]] -; CHECK-NEXT: i32 255, label [[BB1:%.*]] +; CHECK-NEXT: i32 0, label [[BB5:%.*]] +; CHECK-NEXT: i32 1, label [[BB4:%.*]] +; CHECK-NEXT: i32 255, label 
[[BB1:%.*]] ; CHECK-NEXT: ] ; CHECK: common.ret: -; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i8 [ [[DOT0:%.*]], [[BB5]] ], [ 24, [[START:%.*]] ] +; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i8 [ 24, [[START:%.*]] ], [ -1, [[BB1]] ], [ 1, [[BB4]] ], [ 0, [[BB5]] ] ; CHECK-NEXT: ret i8 [[COMMON_RET_OP]] ; CHECK: bb4: -; CHECK-NEXT: br label [[BB5]] +; CHECK-NEXT: br label [[COMMON_RET]] ; CHECK: bb1: -; CHECK-NEXT: br label [[BB5]] +; CHECK-NEXT: br label [[COMMON_RET]] ; CHECK: bb5: -; CHECK-NEXT: [[DOT0]] = phi i8 [ -1, [[BB1]] ], [ 1, [[BB4]] ], [ 0, [[START]] ] ; CHECK-NEXT: br label [[COMMON_RET]] ; start: From d732038a8db4b720b55e60ea8a927834deb1be0e Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Thu, 3 Apr 2025 14:10:16 +0800 Subject: [PATCH 3/3] Adjust code. NFC. --- llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 31 +++++++++++------------ 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index b06cae846c4b6..77eec60b6a86f 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -7164,9 +7164,7 @@ static bool reduceSwitchRangeWithUnreachableDefault(SwitchInst *SI, const SmallVectorImpl &Values, uint64_t Base, IRBuilder<> &Builder) { - bool HasDefault = - !isa(SI->getDefaultDest()->getFirstNonPHIOrDbg()); - if (HasDefault) + if (!SI->defaultDestUndefined()) return false; // Try reducing the range to (idx + offset) & mask @@ -7177,30 +7175,30 @@ reduceSwitchRangeWithUnreachableDefault(SwitchInst *SI, CommonOnes &= (uint64_t)V; CommonZeros &= ~(uint64_t)V; } - uint64_t CommonBits = countl_one(CommonOnes | CommonZeros); - unsigned LowBits = 64 - CommonBits; - uint64_t Mask = (1ULL << LowBits) - 1; - if (Mask == std::numeric_limits::max()) + unsigned CommonPrefixLen = countl_one(CommonOnes | CommonZeros); + if (CommonPrefixLen == 64 || CommonPrefixLen == 0) return false; + uint64_t Mask = std::numeric_limits::max() >> 
CommonPrefixLen; // Now we have some case values in the additive group Z/(2**k)Z. // Find the largest hole in the group and move it to back. uint64_t MaxHole = 0; uint64_t BestOffset = 0; for (unsigned I = 0; I < Values.size(); ++I) { - uint64_t Hole = ((uint64_t)Values[I] - - (uint64_t)(I == 0 ? Values.back() : Values[I - 1])) & - Mask; + uint64_t LastVal = + static_cast(I == 0 ? Values.back() : Values[I - 1]); + uint64_t Hole = (static_cast(Values[I]) - LastVal) & Mask; if (Hole > MaxHole) { MaxHole = Hole; - BestOffset = Mask - (uint64_t)Values[I] + 1; + BestOffset = (-static_cast(Values[I])) & Mask; } } SmallVector NewValues; for (auto &V : Values) NewValues.push_back( - (((int64_t)(((uint64_t)V + BestOffset) & Mask)) << CommonBits) >> - CommonBits); + ((static_cast((static_cast(V) + BestOffset) & Mask)) + << CommonPrefixLen) >> + CommonPrefixLen); llvm::sort(NewValues); if (!isSwitchDense(NewValues)) { @@ -7211,15 +7209,16 @@ reduceSwitchRangeWithUnreachableDefault(SwitchInst *SI, auto *Ty = cast(SI->getCondition()->getType()); APInt Offset(Ty->getBitWidth(), BestOffset - Base, /*isSigned=*/true, /*implicitTrunc=*/true); - auto *Index = Builder.CreateAnd( + Value *Index = Builder.CreateAnd( Builder.CreateAdd(SI->getCondition(), ConstantInt::get(Ty, Offset)), Mask); SI->replaceUsesOfWith(SI->getCondition(), Index); for (auto Case : SI->cases()) { auto *Orig = Case.getCaseValue(); - auto CaseVal = - (Orig->getValue() + Offset).trunc(LowBits).sext(Ty->getBitWidth()); + APInt CaseVal = (Orig->getValue() + Offset) + .trunc(64 - CommonPrefixLen) + .sext(Ty->getBitWidth()); Case.setValue(cast(ConstantInt::get(Ty, CaseVal))); }