Skip to content

Commit a50c269

Browse files
committed
[InstCombine] Handle load smaller than one byte in memset forward
APInt::getSplat() requires that the new size is >= the original one. If we're loading less than 8 bits, truncate instead. Fixes #58845.
1 parent 36e8e19, commit a50c269

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

llvm/lib/Analysis/Loads.cpp

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -532,13 +532,17 @@ static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr,
532532
if (IsLoadCSE)
533533
*IsLoadCSE = false;
534534

535+
TypeSize LoadTypeSize = DL.getTypeSizeInBits(AccessTy);
536+
if (LoadTypeSize.isScalable())
537+
return nullptr;
538+
535539
// Make sure the read bytes are contained in the memset.
536-
TypeSize LoadSize = DL.getTypeSizeInBits(AccessTy);
537-
if (LoadSize.isScalable() ||
538-
(Len->getValue() * 8).ult(LoadSize.getFixedSize()))
540+
uint64_t LoadSize = LoadTypeSize.getFixedSize();
541+
if ((Len->getValue() * 8).ult(LoadSize))
539542
return nullptr;
540543

541-
APInt Splat = APInt::getSplat(LoadSize.getFixedSize(), Val->getValue());
544+
APInt Splat = LoadSize >= 8 ? APInt::getSplat(LoadSize, Val->getValue())
545+
: Val->getValue().trunc(LoadSize);
542546
ConstantInt *SplatC = ConstantInt::get(MSI->getContext(), Splat);
543547
if (CastInst::isBitOrNoopPointerCastable(SplatC->getType(), AccessTy, DL))
544548
return SplatC;

llvm/test/Transforms/InstCombine/load-store-forward.ll

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -284,6 +284,16 @@ define i27 @load_after_memset_0_non_byte_sized(ptr %a) {
284284
ret i27 %v
285285
}
286286

287+
define i1 @load_after_memset_0_i1(ptr %a) {
288+
; CHECK-LABEL: @load_after_memset_0_i1(
289+
; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
290+
; CHECK-NEXT: ret i1 false
291+
;
292+
call void @llvm.memset.p0.i64(ptr %a, i8 0, i64 16, i1 false)
293+
%v = load i1, ptr %a
294+
ret i1 %v
295+
}
296+
287297
define <4 x i8> @load_after_memset_0_vec(ptr %a) {
288298
; CHECK-LABEL: @load_after_memset_0_vec(
289299
; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 0, i64 16, i1 false)
@@ -324,6 +334,16 @@ define i27 @load_after_memset_1_non_byte_sized(ptr %a) {
324334
ret i27 %v
325335
}
326336

337+
define i1 @load_after_memset_1_i1(ptr %a) {
338+
; CHECK-LABEL: @load_after_memset_1_i1(
339+
; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)
340+
; CHECK-NEXT: ret i1 true
341+
;
342+
call void @llvm.memset.p0.i64(ptr %a, i8 1, i64 16, i1 false)
343+
%v = load i1, ptr %a
344+
ret i1 %v
345+
}
346+
327347
define <4 x i8> @load_after_memset_1_vec(ptr %a) {
328348
; CHECK-LABEL: @load_after_memset_1_vec(
329349
; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(16) [[A:%.*]], i8 1, i64 16, i1 false)

0 commit comments

Comments
 (0)