Skip to content

Commit b74a4f9

Browse files
committed
fast tanh
1 parent 29cfac6 commit b74a4f9

File tree

1 file changed

+55
-3
lines changed

1 file changed

+55
-3
lines changed

topi/include/topi/elemwise.h

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -31,6 +31,7 @@
3131
#include "tvm/tvm.h"
3232
#include "tvm/ir.h"
3333
#include "tvm/ir_pass.h"
34+
#include "broadcast.h"
3435

3536
namespace topi {
3637
using namespace tvm;
@@ -46,7 +47,6 @@ using namespace tvm;
4647
}
4748

4849
TOPI_DECLARE_UNARY_OP(exp);
49-
TOPI_DECLARE_UNARY_OP(tanh);
5050
TOPI_DECLARE_UNARY_OP(sigmoid);
5151
TOPI_DECLARE_UNARY_OP(sqrt);
5252
TOPI_DECLARE_UNARY_OP(log);
@@ -56,6 +56,58 @@ TOPI_DECLARE_UNARY_OP(round);
5656
TOPI_DECLARE_UNARY_OP(trunc);
5757
TOPI_DECLARE_UNARY_OP(abs);
5858

59+
/*!
60+
* \brief Creates an operation that returns hyperbolic tan of a given tensor
61+
* Same as the fast_tanh_float implementation from Eigen
62+
*
63+
* \param in The input tensor
64+
* \param name The name of the operation
65+
* \param tag The tag to mark the operation
66+
*
67+
* \return A Tensor whose op member is tanh
68+
*/
69+
70+
inline Tensor tanh(const Tensor& in,
71+
std::string name = "T_tanh",
72+
std::string tag = kElementWise) {
73+
// Clamp the inputs to the range [-9, 9] since anything outside
74+
// this range is +/-1.0f in single-precision.
75+
auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0));
76+
77+
// The monomial coefficients of the numerator polynomial (odd).
78+
auto alpha_1 = make_const(in->dtype, 4.89352455891786e-03);
79+
auto alpha_3 = make_const(in->dtype, 6.37261928875436e-04);
80+
auto alpha_5 = make_const(in->dtype, 1.48572235717979e-05);
81+
auto alpha_7 = make_const(in->dtype, 5.12229709037114e-08);
82+
auto alpha_9 = make_const(in->dtype, -8.60467152213735e-11);
83+
auto alpha_11 = make_const(in->dtype, 2.00018790482477e-13);
84+
auto alpha_13 = make_const(in->dtype, -2.76076847742355e-16);
85+
86+
// The monomial coefficients of the denominator polynomial (even).
87+
auto beta_0 = make_const(in->dtype, 4.89352518554385e-03);
88+
auto beta_2 = make_const(in->dtype, 2.26843463243900e-03);
89+
auto beta_4 = make_const(in->dtype, 1.18534705686654e-04);
90+
auto beta_6 = make_const(in->dtype, 1.19825839466702e-06);
91+
92+
return compute(x->shape,
93+
[&](const Array<Var>& i) {
94+
auto x2 = x(i) * x(i);
95+
auto p = x2 * alpha_13 + alpha_11;
96+
p = x2 * p + alpha_9;
97+
p = x2 * p + alpha_7;
98+
p = x2 * p + alpha_5;
99+
p = x2 * p + alpha_3;
100+
p = x2 * p + alpha_1;
101+
p = x(i) * p;
102+
103+
auto q = x2 * beta_6 + beta_4;
104+
q = x2 * q + beta_2;
105+
q = x2 * q + beta_0;
106+
return p / q;
107+
},
108+
name, tag);
109+
}
110+
59111
/*!
60112
* \brief Creates an operation that returns identity of a given tensor
61113
*

0 commit comments

Comments
 (0)