@@ -91,40 +91,34 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
9191}
9292
9393// / Expands tanh op into
94- // / 1-exp^{-2x} / 1+exp^{-2x}
95- // / To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
96- // / We compute a "signs" value which is -1 if input is negative and +1 if input
97- // / is positive. Then multiply the input by this value, guaranteeing that the
98- // / result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
99- // / 1]. Expand the computation on the input `x * sign(x)`, then multiply the
100- // / result by `sign(x)` to retain sign of the real result.
94+ // / 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
95+ // / 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
10196static LogicalResult convertTanhOp (math::TanhOp op, PatternRewriter &rewriter) {
10297 auto floatType = op.getOperand ().getType ();
10398 Location loc = op.getLoc ();
104- Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
10599 Value one = createFloatConst (loc, floatType, 1.0 , rewriter);
106- Value negTwo = createFloatConst (loc, floatType, -2.0 , rewriter);
107-
108- // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
109- Value sign = rewriter.create <arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
110- op.getOperand (), zero);
111- sign = rewriter.create <arith::SIToFPOp>(loc, floatType, sign);
112- sign = rewriter.create <arith::MulFOp>(loc, sign, negTwo);
113- sign = rewriter.create <arith::AddFOp>(loc, sign, one);
100+ Value two = createFloatConst (loc, floatType, 2.0 , rewriter);
101+ Value doubledX = rewriter.create <arith::MulFOp>(loc, op.getOperand (), two);
114102
115- // Normalize input to positive value: y = sign(x) * x
116- Value positiveX = rewriter.create <arith::MulFOp>(loc, sign, op.getOperand ());
117-
118- // Decompose on normalized input
119- Value negDoubledX = rewriter.create <arith::MulFOp>(loc, negTwo, positiveX);
103+ // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
104+ Value negDoubledX = rewriter.create <arith::NegFOp>(loc, doubledX);
120105 Value exp2x = rewriter.create <math::ExpOp>(loc, negDoubledX);
121106 Value dividend = rewriter.create <arith::SubFOp>(loc, one, exp2x);
122107 Value divisor = rewriter.create <arith::AddFOp>(loc, one, exp2x);
123108 Value positiveRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
124109
125- // Multiply result by sign(x) to retain signs from negative inputs
126- rewriter.replaceOpWithNewOp <arith::MulFOp>(op, sign, positiveRes);
110+ // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
111+ exp2x = rewriter.create <math::ExpOp>(loc, doubledX);
112+ dividend = rewriter.create <arith::SubFOp>(loc, exp2x, one);
113+ divisor = rewriter.create <arith::AddFOp>(loc, exp2x, one);
114+ Value negativeRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
127115
116+ // tanh(x) = x >= 0 ? positiveRes : negativeRes
117+ Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
118+ Value cmpRes = rewriter.create <arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
119+ op.getOperand (), zero);
120+ rewriter.replaceOpWithNewOp <arith::SelectOp>(op, cmpRes, positiveRes,
121+ negativeRes);
128122 return success ();
129123}
130124
0 commit comments