@@ -133,48 +133,6 @@ Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
133133 loc, mulOp->getResultTypes ().front (), llvm::ArrayRef ({low, high}));
134134}
135135
136- Value lowerCarryAddition (Operation *addOp, PatternRewriter &rewriter, Value lhs,
137- Value rhs) { // Builds the ops for an add-with-carry of lhs and rhs at addOp's location.
138- Location loc = addOp->getLoc ();
139- Type argTy = lhs.getType ();
140- // Emulate 64-bit addition by splitting each input element of type i32 into
141- // i16 digits, similar to lowerExtendedMultiplication above. We then expand
142- // to 3 additions:
143- // - Add two low digits into low result
144- // - Add two high digits into high result
145- // - Add the carry from low result to high result
146- Value cstLowMask = rewriter.create <ConstantOp>(
147- loc, lhs.getType (), getScalarOrSplatAttr (argTy, (1 << 16 ) - 1 )); // 0xFFFF: keeps the low 16 bits.
148- auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
149- return rewriter.create <BitwiseAndOp>(loc, val, cstLowMask); // val & 0xFFFF
150- };
151-
152- Value cst16 = rewriter.create <ConstantOp>(loc, lhs.getType (),
153- getScalarOrSplatAttr (argTy, 16 ));
154- auto getHighDigit = [&rewriter, loc, cst16](Value val) {
155- return rewriter.create <ShiftRightLogicalOp>(loc, val, cst16); // val >> 16 (logical)
156- };
157-
158- Value lhsLow = getLowDigit (lhs);
159- Value lhsHigh = getHighDigit (lhs);
160- Value rhsLow = getLowDigit (rhs);
161- Value rhsHigh = getHighDigit (rhs);
162-
163- Value low = rewriter.create <IAddOp>(loc, lhsLow, rhsLow);
164- Value high = rewriter.create <IAddOp>(loc, lhsHigh, rhsHigh);
165- Value highWithCarry = rewriter.create <IAddOp>(loc, high, getHighDigit (low)); // low >> 16 is the carry out of the low digits.
166-
167- auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) { // Computes low | (high << 16).
168- Value highBits = rewriter.create <ShiftLeftLogicalOp>(loc, high, cst16);
169- return rewriter.create <BitwiseOrOp>(loc, low, highBits);
170- };
171- Value out = combineDigits (getLowDigit (highWithCarry), getLowDigit (low)); // NOTE(review): combineDigits takes (low, high) but the high-half digit is passed first here, so the two 16-bit halves of `out` are swapped.
172- Value carry = getHighDigit (highWithCarry); // Bits above the low 16 of the high-digit sum.
173-
174- return rewriter.create <CompositeConstructOp>(
175- loc, addOp->getResultTypes ().front (), llvm::ArrayRef ({out, carry})); // Result is the composite {out, carry}.
176- }
177-
178136// ===----------------------------------------------------------------------===//
179137// Rewrite Patterns
180138// ===----------------------------------------------------------------------===//
@@ -220,13 +178,26 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
220178
221179 // Currently, WGSL only supports 32-bit integer types. Any other integer
222180 // types should already have been promoted/demoted to i32.
223- auto elemTy = cast<IntegerType>(getElementTypeOrSelf (lhs.getType ()));
181+ Type argTy = lhs.getType ();
182+ auto elemTy = cast<IntegerType>(getElementTypeOrSelf (argTy));
224183 if (elemTy.getIntOrFloatBitWidth () != 32 )
225184 return rewriter.notifyMatchFailure (
226185 loc,
227186 llvm::formatv (" Unexpected integer type for WebGPU: '{0}'" , elemTy));
228187
229- Value add = lowerCarryAddition (op, rewriter, lhs, rhs);
188+ Value one =
189+ rewriter.create <ConstantOp>(loc, argTy, getScalarOrSplatAttr (argTy, 1 ));
190+ Value zero =
191+ rewriter.create <ConstantOp>(loc, argTy, getScalarOrSplatAttr (argTy, 0 ));
192+
193+ // Emulate 64-bit unsigned addition by allowing our addition to overflow,
194+ // and then set the carry accordingly.
195+ Value out = rewriter.create <IAddOp>(loc, lhs, rhs);
196+ Value cmp = rewriter.create <ULessThanOp>(loc, out, lhs);
197+ Value carry = rewriter.create <SelectOp>(loc, cmp, one, zero);
198+
199+ Value add = rewriter.create <CompositeConstructOp>(
200+ loc, op->getResultTypes ().front (), llvm::ArrayRef ({out, carry}));
230201
231202 rewriter.replaceOp (op, add);
232203 return success ();
0 commit comments