
Commit 25ba3ea

icemelon authored and wweic committed
[Relay][Frontend] Add a few mxnet ops in relay frontend (apache#2704)
1 parent 22eb2d8 commit 25ba3ea

File tree

2 files changed: +136 -26 lines changed


python/tvm/relay/frontend/mxnet.py

Lines changed: 53 additions & 26 deletions
@@ -64,6 +64,13 @@ def _stable_softrelu(x):
     raise RuntimeError("Do not support act_type: {}".format(act_type))
 
 
+def _mx_compare(new_op, wrapper):
+    def impl(inputs, attrs):
+        dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype
+        return wrapper(new_op)(inputs, attrs).astype(dtype)
+    return impl
+
+
 def _mx_conv2d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 2:
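The new _mx_compare helper handles a dtype mismatch: Relay's comparison ops (equal, greater, ...) return boolean tensors, while the corresponding MXNet ops return 0/1 values in the dtype of their inputs. The helper therefore infers the dtype of the first input and casts the comparison result back to it before returning. A minimal sketch of the Relay expression this produces (names and shapes here are illustrative, not from the commit):

from tvm import relay

# What _mx_compare(_op.greater, _rename) effectively builds for
# MXNet's broadcast_greater on float32 inputs:
x = relay.var("x", shape=(2, 2), dtype="float32")
y = relay.var("y", shape=(2, 2), dtype="float32")
out = relay.greater(x, y).astype("float32")  # bool result cast back to float32
func = relay.Function([x, y], out)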
@@ -333,32 +340,52 @@ def _mx_roi_align(inputs, attrs):
     ]
 
 _convert_map = {
-    "_copy"          : _rename(_op.copy),
-    "relu"           : _rename(_op.nn.relu),
-    "broadcast_add"  : _rename(_op.add),
-    "broadcast_sub"  : _rename(_op.subtract),
-    "broadcast_mul"  : _rename(_op.multiply),
-    "broadcast_div"  : _rename(_op.divide),
-    "elemwise_add"   : _rename(_op.add),
-    "elemwise_sub"   : _rename(_op.subtract),
-    "elemwise_mul"   : _rename(_op.multiply),
-    "elemwise_div"   : _rename(_op.divide),
-    "flatten"        : _rename(_op.nn.batch_flatten),
-    "Flatten"        : _rename(_op.nn.batch_flatten),
-    "_plus_scalar"   : _binop_scalar(_op.add),
-    "__add_scalar__" : _binop_scalar(_op.add),
-    "__sub_scalar__" : _binop_scalar(_op.subtract),
-    "_minus_scalar"  : _binop_scalar(_op.subtract),
-    "__mul_scalar__" : _binop_scalar(_op.multiply),
-    "_mul_scalar"    : _binop_scalar(_op.multiply),
-    "__div_scalar__" : _binop_scalar(_op.divide),
-    "_div_scalar"    : _binop_scalar(_op.divide),
-    "__pow_scalar__" : _binop_scalar(_op.power),
-    "_rminus_scalar" : _rbinop_scalar(_op.subtract),
-    "__rsub_scalar__": _rbinop_scalar(_op.subtract),
-    "_rdiv_scalar"   : _rbinop_scalar(_op.divide),
-    "__rdiv_scalar__": _rbinop_scalar(_op.divide),
-    "__rpow_scalar__": _rbinop_scalar(_op.power),
+    "_copy"                  : _rename(_op.copy),
+    "relu"                   : _rename(_op.nn.relu),
+    "broadcast_add"          : _rename(_op.add),
+    "broadcast_sub"          : _rename(_op.subtract),
+    "broadcast_mul"          : _rename(_op.multiply),
+    "broadcast_div"          : _rename(_op.divide),
+    "broadcast_mod"          : _rename(_op.mod),
+    "broadcast_maximum"      : _rename(_op.maximum),
+    "broadcast_minimum"      : _rename(_op.minimum),
+    "broadcast_equal"        : _mx_compare(_op.equal, _rename),
+    "broadcast_not_equal"    : _mx_compare(_op.not_equal, _rename),
+    "broadcast_greater"      : _mx_compare(_op.greater, _rename),
+    "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
+    "broadcast_lesser"       : _mx_compare(_op.less, _rename),
+    "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
+    "elemwise_add"           : _rename(_op.add),
+    "elemwise_sub"           : _rename(_op.subtract),
+    "elemwise_mul"           : _rename(_op.multiply),
+    "elemwise_div"           : _rename(_op.divide),
+    "_maximum"               : _rename(_op.maximum),
+    "_minimum"               : _rename(_op.minimum),
+    "flatten"                : _rename(_op.nn.batch_flatten),
+    "Flatten"                : _rename(_op.nn.batch_flatten),
+    "__add_scalar__"         : _binop_scalar(_op.add),
+    "_plus_scalar"           : _binop_scalar(_op.add),
+    "__sub_scalar__"         : _binop_scalar(_op.subtract),
+    "_minus_scalar"          : _binop_scalar(_op.subtract),
+    "__mul_scalar__"         : _binop_scalar(_op.multiply),
+    "_mul_scalar"            : _binop_scalar(_op.multiply),
+    "__div_scalar__"         : _binop_scalar(_op.divide),
+    "_div_scalar"            : _binop_scalar(_op.divide),
+    "__pow_scalar__"         : _binop_scalar(_op.power),
+    "_power_scalar"          : _binop_scalar(_op.power),
+    "__rsub_scalar__"        : _rbinop_scalar(_op.subtract),
+    "_rminus_scalar"         : _rbinop_scalar(_op.subtract),
+    "__rdiv_scalar__"        : _rbinop_scalar(_op.divide),
+    "_rdiv_scalar"           : _rbinop_scalar(_op.divide),
+    "__rpow_scalar__"        : _rbinop_scalar(_op.power),
+    "_equal_scalar"          : _mx_compare(_op.equal, _binop_scalar),
+    "_not_equal_scalar"      : _mx_compare(_op.not_equal, _binop_scalar),
+    "_greater_scalar"        : _mx_compare(_op.greater, _binop_scalar),
+    "_greater_equal_scalar"  : _mx_compare(_op.greater_equal, _binop_scalar),
+    "_lesser_scalar"         : _mx_compare(_op.less, _binop_scalar),
+    "_lesser_equal_scalar"   : _mx_compare(_op.less_equal, _binop_scalar),
+    "_maximum_scalar"        : _binop_scalar(_op.maximum),
+    "_minimum_scalar"        : _binop_scalar(_op.minimum),
     # reduction ops
     "max"           : _reduce(_op.max),
     "min"           : _reduce(_op.min),

tests/python/frontend/mxnet/test_forward.py

Lines changed: 83 additions & 0 deletions
@@ -1,4 +1,5 @@
 import numpy as np
+import operator
 
 import tvm
 from tvm.contrib import graph_runtime
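The new operator import is what lets test_forward_scalar_ops below exercise tensor-scalar ops generically: applying a function from the operator module to an MXNet symbol triggers the same overloads as the infix syntax. For example (illustrative):

import operator
import mxnet as mx

# operator.gt(sym, 2.3) builds the same graph as (sym > 2.3),
# which the frontend maps to "_greater_scalar".
sym = operator.gt(mx.sym.var('a'), 2.3)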
@@ -256,6 +257,85 @@ def verify(start, stop, step):
     verify(20, 1, -1)
     verify(20, 1, -1.5)
 
+def _mx_symbol(F, op_name, inputs):
+    op = getattr(F, op_name)
+    return op(*inputs)
+
+def test_forward_broadcast_ops():
+    for op in ["broadcast_add", "broadcast_sub", "broadcast_mul",
+               "broadcast_div", "broadcast_mod", "broadcast_maximum",
+               "broadcast_minimum", "broadcast_equal", "broadcast_not_equal",
+               "broadcast_greater", "broadcast_greater_equal",
+               "broadcast_lesser", "broadcast_lesser_equal"]:
+        a_shape = (3, 4, 5)
+        b_shape = (4, 5)
+        if op == "broadcast_mod":
+            dtype = 'int32'
+            a_np = np.random.randint(1, 100, size=a_shape).astype(dtype)
+            b_np = np.random.randint(1, 100, size=b_shape).astype(dtype)
+        else:
+            dtype = 'float32'
+            a_np = np.random.uniform(size=a_shape).astype(dtype)
+            b_np = np.random.uniform(size=b_shape).astype(dtype)
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
+        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
+        shapes = {'a': a_shape, 'b': b_shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np, b_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+def test_forward_elemwise_ops():
+    for op in ["elemwise_add", "elemwise_sub", "elemwise_mul",
+               "elemwise_div", "maximum", "minimum"]:
+        shape = (3, 4, 5)
+        dtype = 'float32'
+        a_np = np.random.uniform(size=shape).astype(dtype)
+        b_np = np.random.uniform(size=shape).astype(dtype)
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
+        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
+        shapes = {'a': shape, 'b': shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np, b_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+def test_forward_scalar_ops():
+    for op in [operator.add, operator.sub, operator.mul, operator.truediv,
+               operator.pow, operator.lt, operator.le, operator.eq,
+               operator.ne, operator.gt, operator.ge]:
+        dtype = 'float32'
+        a_shape = (3, 4, 5)
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        b_scalar = 2.3
+        mx_sym = op(mx.sym.var('a'), b_scalar)
+        ref_res = op(mx.nd.array(a_np), b_scalar)
+        shapes = {'a': a_shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    for op in ["maximum", "minimum"]:
+        dtype = 'float32'
+        a_shape = (3, 4, 5)
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        b_scalar = 2.3
+        mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar])
+        ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar])
+        shapes = {'a': a_shape}
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(a_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
 
 if __name__ == '__main__':
     test_forward_mlp()
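One detail worth noting in test_forward_broadcast_ops above: broadcast_mod is tested with int32 operands drawn from [1, 100), presumably so the MXNet reference and the TVM result compare exactly (floating-point mod is more sensitive to backend rounding, and nonzero divisors avoid division-by-zero). The reference semantics are ordinary broadcasting mod, as in this NumPy check (illustrative):

import numpy as np

a = np.random.randint(1, 100, size=(3, 4, 5)).astype('int32')
b = np.random.randint(1, 100, size=(4, 5)).astype('int32')
ref = np.mod(a, b)            # shape (4, 5) broadcasts against (3, 4, 5)
assert ref.shape == (3, 4, 5)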
@@ -280,3 +360,6 @@ def verify(start, stop, step):
     test_forward_argmin()
     test_forward_where()
     test_forward_arange()
+    test_forward_broadcast_ops()
+    test_forward_elemwise_ops()
+    test_forward_scalar_ops()
