
Commit f0470d8

{relay,topi}.reinterpret support
= Motivation

It's useful to expose the tvm::reinterpret functionality to Relay/TOPI users, as this allows them to build (fused) operators that leverage the bitwise reinterpretation of an operand. An example is approximate transcendental functions, which can be implemented similar to:

```python
def C(x):
    return relay.expr.const(x, "float32")

def approx_exp(x):
    x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
    x = C(127.0) + x * C(1.44269504)
    xf = relay.floor(x)
    i = relay.cast(xf, "int32")
    x = x - xf
    Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
    exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
    exponent = relay.reinterpret(exponent, "float32")
    return exponent * Y

def approx_sigmoid(x):
    # <2.0e-5 absolute error over [-5, 5]
    y = approx_exp(x)
    return y / (y + C(1.0))

def approx_tanh(x):
    # <4.0e-5 absolute error over [-5, 5]
    x = x * C(2.0)
    y = approx_exp(x)
    return (y - C(1.0)) / (y + C(1.0))
```

See the unit tests for implementations of these approximate transcendentals.
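The trick in `approx_exp` relies on the IEEE-754 float32 layout: placing a biased exponent in the exponent bits and reinterpreting the word as a float yields an exact power of two, which the polynomial `Y` then scales by the fractional part. A quick NumPy sanity check of that reinterpretation step (illustrative only, not part of this commit) looks roughly like:

```python
import numpy as np

# float32 layout: [sign:1][exponent:8][mantissa:23].
# Writing (i + 127) into the exponent field and reinterpreting the bits as
# float32 produces exactly 2**i, which is what
# relay.reinterpret(left_shift(i, 23), "float32") computes above
# (the +127 bias is already folded into `x` there).
i = np.arange(-5, 6, dtype=np.int32)
bits = (i + 127) << 23
powers = bits.view(np.float32)        # bitwise reinterpret, no value conversion
np.testing.assert_allclose(powers, 2.0 ** i)
print(powers)
```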
1 parent 19eb829 commit f0470d8

File tree

9 files changed: +219 -8 lines changed


docs/api/python/topi.rst

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@ List of operators
    topi.sigmoid
    topi.clip
    topi.cast
+   topi.reinterpret
    topi.transpose
    topi.flip
    topi.strided_slice
@@ -133,6 +134,7 @@ topi
 .. autofunction:: topi.sigmoid
 .. autofunction:: topi.clip
 .. autofunction:: topi.cast
+.. autofunction:: topi.reinterpret
 .. autofunction:: topi.transpose
 .. autofunction:: topi.flip
 .. autofunction:: topi.strided_slice

python/tvm/relay/op/_transform.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@
 _reg.register_schedule("repeat", schedule_broadcast)
 _reg.register_schedule("tile", schedule_broadcast)
 _reg.register_schedule("cast", schedule_injective)
+_reg.register_schedule("reinterpret", schedule_injective)
 _reg.register_schedule("strided_slice", schedule_injective)
 _reg.register_schedule("slice_like", schedule_injective)
 _reg.register_schedule("split", schedule_injective)

python/tvm/relay/op/transform.py

Lines changed: 20 additions & 0 deletions
@@ -40,6 +40,26 @@ def cast(data, dtype):
     return _relay_make.cast(data, dtype)
 
 
+def reinterpret(data, dtype):
+    """Reinterpret input tensor to data type.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    dtype: str
+        The target data type
+
+    Returns
+    -------
+    result : relay.Expr
+        The reinterpreted result.
+    """
+    from .. import _make as _relay_make
+    return _relay_make.reinterpret(data, dtype)
+
+
 def expand_dims(data, axis, num_newaxis=1):
     """Insert `num_newaxis` axises at the position given by `axis`.
 
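For context, the new Python wrapper above can be exercised along these lines (a minimal sketch assuming a TVM build that includes this commit; the level-3 op test later in this diff does the same with more data):

```python
import numpy as np
import tvm
from tvm import relay

# Reinterpret float32 bits as int32 without any value conversion.
x = relay.var("x", relay.TensorType((4,), "float32"))
f = relay.Function([x], relay.reinterpret(x, "int32"))

data = np.array([1.0, -2.0, 0.5, 3.25], dtype="float32")
out = relay.create_executor().evaluate(f)(data)
np.testing.assert_equal(out.asnumpy(), data.view("int32"))
```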

src/relay/op/tensor/transform.cc

Lines changed: 46 additions & 0 deletions
@@ -97,6 +97,52 @@ RELAY_REGISTER_OP("cast")
 .set_attr<TOpPattern>("TOpPattern", kElemWise)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
+// relay.reinterpret
+bool ReinterpretRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "Reinterpret: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<CastAttrs>();
+  reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
+  return true;
+}
+
+Array<Tensor> ReinterpretCompute(const Attrs& attrs, const Array<Tensor>& inputs,
+                                 const Type& out_type, const Target& target) {
+  const CastAttrs* param = attrs.as<CastAttrs>();
+  CHECK(param != nullptr);
+  DataType dtype = param->dtype;
+  return {topi::reinterpret(inputs[0], dtype)};
+}
+
+Expr MakeReinterpret(Expr data, DataType dtype) {
+  auto attrs = make_node<CastAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("reinterpret");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay._make.reinterpret").set_body([](const TVMArgs& args, TVMRetValue* rv) {
+  runtime::detail::unpack_call<Expr, 2>(MakeReinterpret, args, rv);
+});
+
+RELAY_REGISTER_OP("reinterpret")
+.describe(R"code(Reinterpret the data into a new data type.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.CastAttrs")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Reinterpret", CastRel)
+.set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
+.set_attr<TOpPattern>("TOpPattern", kElemWise)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
 // relay.expand_dims
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
 

tests/python/relay/test_op_level3.py

Lines changed: 64 additions & 0 deletions
@@ -75,6 +75,7 @@ def test_cast():
     assert "dtype=" in yy.astext()
     assert yy.checked_type == relay.TensorType((8, 9, 4), "int32")
 
+
 def test_clip():
     a = relay.var("a", relay.TensorType((10, 4), "float32"))
     y = relay.clip(a, 1., 4.)
@@ -88,6 +89,69 @@ def test_clip():
     np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
 
+def test_reinterpret():
+    a = relay.var("a", relay.TensorType((1000, 4), "float32"))
+    y = relay.reinterpret(a, "int32")
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000, 4), "int32")
+
+    data = np.random.randn(1000, 4).astype('float32') * 1000
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+    ref_res = data.view("int32")
+    np.testing.assert_equal(op_res.asnumpy(), ref_res)
+
+
+def test_approximate_transcendental():
+    def C(x):
+        return relay.expr.const(x, "float32")
+
+    def approx_exp(x):
+        # An approximation derived from Opus,
+        # https://github.com/xiph/opus/blob/c1c247/celt/mathops.h#L147-L165
+        x = relay.minimum(relay.maximum(x, C(-88.0)), C(88.0))
+        x = C(127.0) + x * C(1.44269504)
+        xf = relay.floor(x)
+        i = relay.cast(xf, "int32")
+        x = x - xf
+        Y = C(0.99992522) + x * (C(0.69583354) + x * (C(0.22606716) + x * C(0.078024523)))
+        exponent = relay.left_shift(i, relay.expr.const(23, "int32"))
+        exponent = relay.reinterpret(exponent, "float32")
+        return exponent * Y
+
+    def approximate_sigmoid(x):
+        y = approx_exp(x)
+        return y / (y + C(1.0))
+
+    def approximate_tanh(x):
+        x = x * C(2.0)
+        y = approx_exp(x)
+        return (y - C(1.0)) / (y + C(1.0))
+
+    a = relay.var("a", relay.TensorType((1000,), "float32"))
+    y = approximate_sigmoid(a)
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000,), "float32")
+    data = np.linspace(-5, 5, 1000).astype("float32")
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+
+    def reference_sigmoid(x):
+        return np.exp(-np.logaddexp(0, -x))
+    np.testing.assert_allclose(op_res.asnumpy(), reference_sigmoid(data), atol=2e-5, rtol=1e-9)
+
+    y = approximate_tanh(a)
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1000,), "float32")
+    data = np.linspace(-5, 5, 1000).astype("float32")
+    intrp = create_executor()
+    op_res = intrp.evaluate(y, {a: relay.const(data)})
+
+    def reference_tanh(x):
+        return np.tanh(x)
+    np.testing.assert_allclose(op_res.asnumpy(), reference_tanh(data), atol=4e-5, rtol=1e-9)
+
+
 def test_squeeze():
     def verify_squeeze(shape, dtype, axis):
         x = relay.var("x", relay.TensorType(shape, dtype))

topi/include/topi/elemwise.h

Lines changed: 28 additions & 8 deletions
@@ -269,14 +269,34 @@ inline Tensor cast(const Tensor& x,
 }
 
 /*!
- * \brief Creates an operation that sum each element of a tensor
- *
- * \param xs The input tensor array
- * \param name The name of the operation
- * \param tag The tag to mark the operation
- *
- * \return A Tensor whose op member is the sum operation
- */
+ * \brief Reinterpret each element of x to the given type.
+
+ * \param x The input tensor
+ * \param type The type to cast to
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the reinterpret operation
+ */
+inline Tensor reinterpret(const Tensor& x, Type type, std::string name = "tensor",
+                          std::string tag = kElementWise) {
+  return compute(x->shape,
+                 [&](const Array<Var>& i) {
+                   return tvm::ir::Call::make(type, "reinterpret", {x(i)},
+                                              tvm::ir::Call::PureIntrinsic);
+                 },
+                 name, tag);
+}
+
+/*!
+ * \brief Creates an operation that sum each element of a tensor
+ *
+ * \param xs The input tensor array
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the sum operation
+ */
 inline Tensor elemwise_sum(const Array<Tensor>& xs,
                            std::string name = "T_elemwise_sum",
                            std::string tag = kElementWise) {

topi/python/topi/math.py

Lines changed: 18 additions & 0 deletions
@@ -343,3 +343,21 @@ def cast(x, dtype):
         return tvm.compute(
             x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
     return tvm.make._cast(dtype, x)
+
+def reinterpret(x, dtype):
+    """Reinterpret input to specified data type.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    dtype : str
+        Data type.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return cpp.reinterpret(x, dtype)
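A standalone sketch of driving the TOPI wrapper above (assuming an LLVM-enabled TVM build of this vintage; the transform test further down runs the same check over all backends):

```python
import numpy as np
import tvm
import topi

# Declare a float32 placeholder and reinterpret its bits as int32.
A = tvm.placeholder((64,), name="A", dtype="float32")
B = topi.reinterpret(A, "int32")

# Build for the CPU and compare against numpy's bitwise view().
s = tvm.create_schedule(B.op)
func = tvm.build(s, [A, B], "llvm", name="reinterpret")

ctx = tvm.cpu(0)
a_np = (np.random.randn(64) * 1000).astype("float32")
a_nd = tvm.nd.array(a_np, ctx)
b_nd = tvm.nd.array(np.empty(64, dtype="int32"), ctx)
func(a_nd, b_nd)
np.testing.assert_equal(b_nd.asnumpy(), a_np.view("int32"))
```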

topi/src/topi.cc

Lines changed: 6 additions & 0 deletions
@@ -193,6 +193,12 @@ TVM_REGISTER_GLOBAL("topi.cast")
   *rv = cast(args[0], args[1]);
   });
 
+
+TVM_REGISTER_GLOBAL("topi.reinterpret")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = reinterpret(args[0], args[1]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.elemwise_sum")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = elemwise_sum(args[0]);

topi/tests/python/test_topi_transform.py

Lines changed: 34 additions & 0 deletions
@@ -45,6 +45,29 @@ def check_device(device):
         check_device(device)
 
 
+def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
+    A = tvm.placeholder(shape=in_shape, name="A", dtype=in_dtype)
+    B = topi.reinterpret(A, out_dtype)
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_elemwise(B)
+        foo = tvm.build(s, [A, B], device, name="reinterpret")
+        data_npy = generator(in_shape).astype(in_dtype)
+        out_npy = data_npy.view(B.dtype)
+        data_nd = tvm.nd.array(data_npy, ctx)
+        out_nd = tvm.nd.array(np.empty(in_shape).astype(B.dtype), ctx)
+        foo(data_nd, out_nd)
+        np.testing.assert_equal(out_nd.asnumpy(), out_npy)
+
+    for device in get_all_backend():
+        check_device(device)
+
+
 def verify_transpose(in_shape, axes):
     A = tvm.placeholder(shape=in_shape, name="A")
     B = topi.transpose(A, axes)
@@ -434,6 +457,17 @@ def test_expand_dims():
     verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
 
 
+def test_reinterpret():
+    verify_reinterpret((1000,), "float32", "int32",
+                       lambda shape: np.random.randn(*shape) * 1000)
+    verify_reinterpret((1000,), "float16", "int16",
+                       lambda shape: np.random.randn(*shape) * 100)
+    verify_reinterpret((1000,), "int16", "uint16",
+                       lambda shape: np.random.randint(-1000, 1000, size=shape))
+    verify_reinterpret((1000,), "uint32", "int32",
+                       lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape))
+
+
 def test_transpose():
     verify_transpose((3, 10, 2), (1, 0, 2))
     verify_transpose((3, 10, 5), (2, 0, 1))
