Skip to content

Commit e86b338

Browse files
authored
[DAGCombine] Support (shl %x, constant) in foldPartialReduceMLAMulOp. (#160663)
Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)). PR: #160663
1 parent 88c668d commit e86b338

File tree

2 files changed

+88
-22
lines changed

2 files changed

+88
-22
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12994,13 +12994,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1299412994
SDValue Op1 = N->getOperand(1);
1299512995
SDValue Op2 = N->getOperand(2);
1299612996

12997-
APInt C;
12998-
if (Op1->getOpcode() != ISD::MUL ||
12999-
!ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
12997+
unsigned Opc = Op1->getOpcode();
12998+
if (Opc != ISD::MUL && Opc != ISD::SHL)
1300012999
return SDValue();
1300113000

1300213001
SDValue LHS = Op1->getOperand(0);
1300313002
SDValue RHS = Op1->getOperand(1);
13003+
13004+
// Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
13005+
if (Opc == ISD::SHL) {
13006+
APInt C;
13007+
if (!ISD::isConstantSplatVector(RHS.getNode(), C))
13008+
return SDValue();
13009+
13010+
RHS =
13011+
DAG.getSplatVector(RHS.getValueType(), DL,
13012+
DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL,
13013+
RHS.getValueType().getScalarType()));
13014+
Opc = ISD::MUL;
13015+
}
13016+
13017+
APInt C;
13018+
if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
13019+
!C.isOne())
13020+
return SDValue();
13021+
1300413022
unsigned LHSOpcode = LHS->getOpcode();
1300513023
if (!ISD::isExtOpcode(LHSOpcode))
1300613024
return SDValue();

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,21 +1257,55 @@ entry:
12571257
}
12581258

12591259
define <4 x i32> @partial_reduce_shl_sext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
1260-
; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs6:
1260+
; CHECK-NODOT-LABEL: partial_reduce_shl_sext_const_rhs6:
1261+
; CHECK-NODOT: // %bb.0:
1262+
; CHECK-NODOT-NEXT: sshll v2.8h, v0.8b, #0
1263+
; CHECK-NODOT-NEXT: sshll2 v0.8h, v0.16b, #0
1264+
; CHECK-NODOT-NEXT: sshll v3.4s, v0.4h, #6
1265+
; CHECK-NODOT-NEXT: sshll2 v4.4s, v2.8h, #6
1266+
; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #6
1267+
; CHECK-NODOT-NEXT: sshll2 v0.4s, v0.8h, #6
1268+
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
1269+
; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
1270+
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
1271+
; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
1272+
; CHECK-NODOT-NEXT: ret
1273+
;
1274+
; CHECK-DOT-LABEL: partial_reduce_shl_sext_const_rhs6:
1275+
; CHECK-DOT: // %bb.0:
1276+
; CHECK-DOT-NEXT: movi v2.16b, #64
1277+
; CHECK-DOT-NEXT: sdot v1.4s, v0.16b, v2.16b
1278+
; CHECK-DOT-NEXT: mov v0.16b, v1.16b
1279+
; CHECK-DOT-NEXT: ret
1280+
;
1281+
; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_sext_const_rhs6:
1282+
; CHECK-DOT-I8MM: // %bb.0:
1283+
; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
1284+
; CHECK-DOT-I8MM-NEXT: sdot v1.4s, v0.16b, v2.16b
1285+
; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
1286+
; CHECK-DOT-I8MM-NEXT: ret
1287+
%ext = sext <16 x i8> %l to <16 x i32>
1288+
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
1289+
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
1290+
ret <4 x i32> %red
1291+
}
1292+
1293+
define <4 x i32> @partial_reduce_shl_sext_const_rhs7(<16 x i8> %l, <4 x i32> %part) {
1294+
; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs7:
12611295
; CHECK-COMMON: // %bb.0:
12621296
; CHECK-COMMON-NEXT: sshll v2.8h, v0.8b, #0
12631297
; CHECK-COMMON-NEXT: sshll2 v0.8h, v0.16b, #0
1264-
; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #6
1265-
; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #6
1266-
; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #6
1267-
; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #6
1298+
; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #7
1299+
; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #7
1300+
; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #7
1301+
; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #7
12681302
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
12691303
; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
12701304
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
12711305
; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
12721306
; CHECK-COMMON-NEXT: ret
12731307
%ext = sext <16 x i8> %l to <16 x i32>
1274-
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
1308+
%shift = shl nsw <16 x i32> %ext, splat (i32 7)
12751309
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
12761310
ret <4 x i32> %red
12771311
}
@@ -1331,19 +1365,33 @@ define <4 x i32> @partial_reduce_shl_sext_non_const_rhs(<16 x i8> %l, <4 x i32>
13311365
}
13321366

13331367
define <4 x i32> @partial_reduce_shl_zext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
1334-
; CHECK-COMMON-LABEL: partial_reduce_shl_zext_const_rhs6:
1335-
; CHECK-COMMON: // %bb.0:
1336-
; CHECK-COMMON-NEXT: ushll v2.8h, v0.8b, #0
1337-
; CHECK-COMMON-NEXT: ushll2 v0.8h, v0.16b, #0
1338-
; CHECK-COMMON-NEXT: ushll v3.4s, v0.4h, #6
1339-
; CHECK-COMMON-NEXT: ushll2 v4.4s, v2.8h, #6
1340-
; CHECK-COMMON-NEXT: ushll v2.4s, v2.4h, #6
1341-
; CHECK-COMMON-NEXT: ushll2 v0.4s, v0.8h, #6
1342-
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
1343-
; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
1344-
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
1345-
; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
1346-
; CHECK-COMMON-NEXT: ret
1368+
; CHECK-NODOT-LABEL: partial_reduce_shl_zext_const_rhs6:
1369+
; CHECK-NODOT: // %bb.0:
1370+
; CHECK-NODOT-NEXT: ushll v2.8h, v0.8b, #0
1371+
; CHECK-NODOT-NEXT: ushll2 v0.8h, v0.16b, #0
1372+
; CHECK-NODOT-NEXT: ushll v3.4s, v0.4h, #6
1373+
; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #6
1374+
; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #6
1375+
; CHECK-NODOT-NEXT: ushll2 v0.4s, v0.8h, #6
1376+
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
1377+
; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
1378+
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
1379+
; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
1380+
; CHECK-NODOT-NEXT: ret
1381+
;
1382+
; CHECK-DOT-LABEL: partial_reduce_shl_zext_const_rhs6:
1383+
; CHECK-DOT: // %bb.0:
1384+
; CHECK-DOT-NEXT: movi v2.16b, #64
1385+
; CHECK-DOT-NEXT: udot v1.4s, v0.16b, v2.16b
1386+
; CHECK-DOT-NEXT: mov v0.16b, v1.16b
1387+
; CHECK-DOT-NEXT: ret
1388+
;
1389+
; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_zext_const_rhs6:
1390+
; CHECK-DOT-I8MM: // %bb.0:
1391+
; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
1392+
; CHECK-DOT-I8MM-NEXT: udot v1.4s, v0.16b, v2.16b
1393+
; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
1394+
; CHECK-DOT-I8MM-NEXT: ret
13471395
%ext = zext <16 x i8> %l to <16 x i32>
13481396
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
13491397
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)

0 commit comments

Comments
 (0)