@@ -26,7 +26,7 @@ namespace mlir {
2626using namespace mlir ;
2727
2828namespace {
29- // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
29+
3030struct AbsOpConversion : public OpConversionPattern <complex ::AbsOp> {
3131 using OpConversionPattern<complex ::AbsOp>::OpConversionPattern;
3232
@@ -35,49 +35,27 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
3535 ConversionPatternRewriter &rewriter) const override {
3636 mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
3737
38- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
38+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr (). getValue ();
3939
4040 Type elementType = op.getType ();
41- Value arg = adaptor.getComplex ();
42-
43- Value zero =
44- b.create <arith::ConstantOp>(elementType, b.getZeroAttr (elementType));
4541 Value one = b.create <arith::ConstantOp>(elementType,
4642 b.getFloatAttr (elementType, 1.0 ));
4743
48- Value real = b.create <complex ::ReOp>(elementType, arg);
49- Value imag = b.create <complex ::ImOp>(elementType, arg);
50-
51- Value realIsZero =
52- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
53- Value imagIsZero =
54- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
44+ Value real = b.create <complex ::ReOp>(adaptor.getComplex ());
45+ Value imag = b.create <complex ::ImOp>(adaptor.getComplex ());
46+ Value absReal = b.create <math::AbsFOp>(real, fmf);
47+ Value absImag = b.create <math::AbsFOp>(imag, fmf);
5548
56- // Real > Imag
57- Value imagDivReal = b.create <arith::DivFOp>(imag, real, fmf.getValue ());
58- Value imagSq =
59- b.create <arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue ());
60- Value imagSqPlusOne = b.create <arith::AddFOp>(imagSq, one, fmf.getValue ());
61- Value imagSqrt = b.create <math::SqrtOp>(imagSqPlusOne, fmf.getValue ());
62- Value realAbs = b.create <math::AbsFOp>(real, fmf.getValue ());
63- Value absImag = b.create <arith::MulFOp>(imagSqrt, realAbs, fmf.getValue ());
64-
65- // Real <= Imag
66- Value realDivImag = b.create <arith::DivFOp>(real, imag, fmf.getValue ());
67- Value realSq =
68- b.create <arith::MulFOp>(realDivImag, realDivImag, fmf.getValue ());
69- Value realSqPlusOne = b.create <arith::AddFOp>(realSq, one, fmf.getValue ());
70- Value realSqrt = b.create <math::SqrtOp>(realSqPlusOne, fmf.getValue ());
71- Value imagAbs = b.create <math::AbsFOp>(imag, fmf.getValue ());
72- Value absReal = b.create <arith::MulFOp>(realSqrt, imagAbs, fmf.getValue ());
73-
74- rewriter.replaceOpWithNewOp <arith::SelectOp>(
75- op, realIsZero, imagAbs,
76- b.create <arith::SelectOp>(
77- imagIsZero, realAbs,
78- b.create <arith::SelectOp>(
79- b.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
80- absImag, absReal)));
49+ Value max = b.create <arith::MaximumFOp>(absReal, absImag, fmf);
50+ Value min = b.create <arith::MinimumFOp>(absReal, absImag, fmf);
51+ Value ratio = b.create <arith::DivFOp>(min, max, fmf);
52+ Value ratioSq = b.create <arith::MulFOp>(ratio, ratio, fmf);
53+ Value ratioSqPlusOne = b.create <arith::AddFOp>(ratioSq, one, fmf);
54+ Value sqrt = b.create <math::SqrtOp>(ratioSqPlusOne, fmf);
55+ Value result = b.create <arith::MulFOp>(max, sqrt, fmf);
56+ Value isNaN =
57+ b.create <arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
58+ rewriter.replaceOpWithNewOp <arith::SelectOp>(op, isNaN, min, result);
8159
8260 return success ();
8361 }
0 commit comments