Skip to content

Commit 5235cac

Browse files
committed
review comments:
- inline the lowerCarryAddition function to make it clearer that we check that each integer is of type i32
- switch the computation of the carry to reduce instructions and simplify
1 parent 8bbc9f2 commit 5235cac

File tree

2 files changed

+27
-77
lines changed

2 files changed

+27
-77
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -133,48 +133,6 @@ Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
133133
loc, mulOp->getResultTypes().front(), llvm::ArrayRef({low, high}));
134134
}
135135

136-
Value lowerCarryAddition(Operation *addOp, PatternRewriter &rewriter, Value lhs,
137-
Value rhs) {
138-
Location loc = addOp->getLoc();
139-
Type argTy = lhs.getType();
140-
// Emulate 64-bit addition by splitting each input element of type i32 to
141-
// i16 similar to above in lowerExtendedMultiplication. We then expand
142-
// to 3 additions:
143-
// - Add two low digits into low result
144-
// - Add two high digits into high result
145-
// - Add the carry from low result to high result
146-
Value cstLowMask = rewriter.create<ConstantOp>(
147-
loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
148-
auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
149-
return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
150-
};
151-
152-
Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
153-
getScalarOrSplatAttr(argTy, 16));
154-
auto getHighDigit = [&rewriter, loc, cst16](Value val) {
155-
return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
156-
};
157-
158-
Value lhsLow = getLowDigit(lhs);
159-
Value lhsHigh = getHighDigit(lhs);
160-
Value rhsLow = getLowDigit(rhs);
161-
Value rhsHigh = getHighDigit(rhs);
162-
163-
Value low = rewriter.create<IAddOp>(loc, lhsLow, rhsLow);
164-
Value high = rewriter.create<IAddOp>(loc, lhsHigh, rhsHigh);
165-
Value highWithCarry = rewriter.create<IAddOp>(loc, high, getHighDigit(low));
166-
167-
auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
168-
Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
169-
return rewriter.create<BitwiseOrOp>(loc, low, highBits);
170-
};
171-
Value out = combineDigits(getLowDigit(highWithCarry), getLowDigit(low));
172-
Value carry = getHighDigit(highWithCarry);
173-
174-
return rewriter.create<CompositeConstructOp>(
175-
loc, addOp->getResultTypes().front(), llvm::ArrayRef({out, carry}));
176-
}
177-
178136
//===----------------------------------------------------------------------===//
179137
// Rewrite Patterns
180138
//===----------------------------------------------------------------------===//
@@ -220,13 +178,26 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
220178

221179
// Currently, WGSL only supports 32-bit integer types. Any other integer
222180
// types should already have been promoted/demoted to i32.
223-
auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
181+
Type argTy = lhs.getType();
182+
auto elemTy = cast<IntegerType>(getElementTypeOrSelf(argTy));
224183
if (elemTy.getIntOrFloatBitWidth() != 32)
225184
return rewriter.notifyMatchFailure(
226185
loc,
227186
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
228187

229-
Value add = lowerCarryAddition(op, rewriter, lhs, rhs);
188+
Value one =
189+
rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 1));
190+
Value zero =
191+
rewriter.create<ConstantOp>(loc, argTy, getScalarOrSplatAttr(argTy, 0));
192+
193+
// Emulate 64-bit unsigned addition by allowing our addition to overflow,
194+
// and then set the carry accordingly.
195+
Value out = rewriter.create<IAddOp>(loc, lhs, rhs);
196+
Value cmp = rewriter.create<ULessThanOp>(loc, out, lhs);
197+
Value carry = rewriter.create<SelectOp>(loc, cmp, one, zero);
198+
199+
Value add = rewriter.create<CompositeConstructOp>(
200+
loc, op->getResultTypes().front(), llvm::ArrayRef({out, carry}));
230201

231202
rewriter.replaceOp(op, add);
232203
return success();

mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -147,47 +147,26 @@ spirv.func @smul_extended_i16(%arg : i16) -> !spirv.struct<(i16, i16)> "None" {
147147

148148
// CHECK-LABEL: func @iaddcarry_i32
149149
// CHECK-SAME: ([[A:%.+]]: i32, [[B:%.+]]: i32)
150-
// CHECK-NEXT: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
151-
// CHECK-NEXT: [[CST16:%.+]] = spirv.Constant 16 : i32
152-
// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[A]], [[CSTMASK]] : i32
153-
// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[A]], [[CST16]] : i32
154-
// CHECK-DAG: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[B]], [[CSTMASK]] : i32
155-
// CHECK-DAG: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[B]], [[CST16]] : i32
156-
// CHECK-DAG: [[LOW:%.+]] = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : i32
157-
// CHECK-DAG: [[HI:%.+]] = spirv.IAdd [[LHSHI]], [[RHSHI]]
158-
// CHECK-DAG: [[LOWCRY:%.+]] = spirv.ShiftRightLogical [[LOW]], [[CST16]] : i32
159-
// CHECK-DAG: [[HI_TTL:%.+]] = spirv.IAdd [[HI]], [[LOWCRY]]
160-
// CHECK-DAG: spirv.ShiftRightLogical
161-
// CHECK-DAG: spirv.BitwiseAnd
162-
// CHECK-DAG: spirv.BitwiseAnd
163-
// CHECK-DAG: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
164-
// CHECK-DAG: spirv.BitwiseOr
165-
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
150+
// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant 1 : i32
151+
// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant 0 : i32
152+
// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
153+
// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
154+
// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
155+
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (i32, i32) -> !spirv.struct<(i32, i32)>
166156
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
167157
spirv.func @iaddcarry_i32(%a : i32, %b : i32) -> !spirv.struct<(i32, i32)> "None" {
168158
%0 = spirv.IAddCarry %a, %b : !spirv.struct<(i32, i32)>
169159
spirv.ReturnValue %0 : !spirv.struct<(i32, i32)>
170160
}
171161

172-
173162
// CHECK-LABEL: func @iaddcarry_vector_i32
174163
// CHECK-SAME: ([[A:%.+]]: vector<3xi32>, [[B:%.+]]: vector<3xi32>)
175-
// CHECK-NEXT: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
176-
// CHECK-NEXT: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
177-
// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[A]], [[CSTMASK]] : vector<3xi32>
178-
// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[A]], [[CST16]] : vector<3xi32>
179-
// CHECK-DAG: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[B]], [[CSTMASK]] : vector<3xi32>
180-
// CHECK-DAG: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[B]], [[CST16]] : vector<3xi32>
181-
// CHECK-DAG: [[LOW:%.+]] = spirv.IAdd [[LHSLOW]], [[RHSLOW]] : vector<3xi32>
182-
// CHECK-DAG: [[HI:%.+]] = spirv.IAdd [[LHSHI]], [[RHSHI]]
183-
// CHECK-DAG: [[LOWCRY:%.+]] = spirv.ShiftRightLogical [[LOW]], [[CST16]] : vector<3xi32>
184-
// CHECK-DAG: [[HI_TTL:%.+]] = spirv.IAdd [[HI]], [[LOWCRY]]
185-
// CHECK-DAG: spirv.ShiftRightLogical
186-
// CHECK-DAG: spirv.BitwiseAnd
187-
// CHECK-DAG: spirv.BitwiseAnd
188-
// CHECK-DAG: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : vector<3xi32>
189-
// CHECK-DAG: spirv.BitwiseOr
190-
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
164+
// CHECK-NEXT: [[ONE:%.+]] = spirv.Constant dense<1> : vector<3xi32>
165+
// CHECK-NEXT: [[ZERO:%.+]] = spirv.Constant dense<0> : vector<3xi32>
166+
// CHECK-NEXT: [[OUT:%.+]] = spirv.IAdd [[A]], [[B]]
167+
// CHECK-NEXT: [[CMP:%.+]] = spirv.ULessThan [[OUT]], [[A]]
168+
// CHECK-NEXT: [[CARRY:%.+]] = spirv.Select [[CMP]], [[ONE]], [[ZERO]]
169+
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[OUT]], [[CARRY]] : (vector<3xi32>, vector<3xi32>) -> !spirv.struct<(vector<3xi32>, vector<3xi32>)>
191170
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
192171
spirv.func @iaddcarry_vector_i32(%a : vector<3xi32>, %b : vector<3xi32>)
193172
-> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {

0 commit comments

Comments
 (0)