Skip to content

Commit 89acfeb

Browse files
icemelon and Laurawly
authored and committed
[Relay][Frontend] Add ops in mxnet converter (#2844)
* Add ops in mxnet converter * trigger ci
1 parent f81e287 commit 89acfeb

File tree

2 files changed

+122
-5
lines changed

2 files changed

+122
-5
lines changed

python/tvm/relay/frontend/mxnet.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _mx_slice_axis(inputs, attrs):
213213
ax_end = attrs.get_str("end")
214214
if axis < 0:
215215
axis += len(shape)
216-
assert axis >= 0 and axis < len(shape)
216+
assert 0 <= axis < len(shape)
217217
if ax_end == "None":
218218
ax_end = int(shape[axis])
219219
else:
@@ -222,8 +222,8 @@ def _mx_slice_axis(inputs, attrs):
222222
ax_beg += int(shape[axis])
223223
if ax_end < 0:
224224
ax_end += int(shape[axis])
225-
assert ax_beg >= 0 and ax_beg < int(shape[axis])
226-
assert ax_end > ax_beg and ax_end <= int(shape[axis])
225+
assert 0 <= ax_beg < int(shape[axis])
226+
assert ax_beg < ax_end <= int(shape[axis])
227227
begin = []
228228
end = []
229229
for i, dim in enumerate(shape):
@@ -527,11 +527,53 @@ def _mx_shape_array(inputs, attrs):
527527
return _op.shape_of(inputs[0], dtype='int64')
528528

529529

530+
def _mx_full(inputs, attrs):
531+
assert len(inputs) == 0
532+
val = attrs.get_float("value")
533+
shape = attrs.get_int_tuple("shape")
534+
dtype = attrs.get_str("dtype", "float32")
535+
return _op.full(_expr.const(val, dtype), shape, dtype)
536+
537+
538+
def _mx_squeeze(inputs, attrs):
539+
assert len(inputs) == 1
540+
axis = attrs.get_int_tuple("axis", None)
541+
return _op.squeeze(inputs[0], axis)
542+
543+
544+
def _mx_broadcast_axis(inputs, attrs):
545+
assert len(inputs) == 1
546+
axis = attrs.get_int_tuple("axis", [])
547+
size = attrs.get_int_tuple("size", [])
548+
assert len(axis) == len(size)
549+
if len(axis) == 0:
550+
return inputs[0]
551+
src_shape = ir_pass.infer_type(inputs[0])._checked_type_.shape
552+
tgt_shape = []
553+
for i, dim in enumerate(src_shape):
554+
if i not in axis:
555+
tgt_shape.append(dim)
556+
else:
557+
assert int(dim) == 1
558+
idx = axis.index(i)
559+
tgt_shape.append(size[idx])
560+
return _op.broadcast_to(inputs[0], tgt_shape)
561+
562+
563+
def _mx_embedding(inputs, _):
564+
assert len(inputs) == 2
565+
indices, weight = inputs
566+
return _op.take(weight, indices.astype('int32'), axis=0)
567+
568+
530569
# Note: due to attribute conversion constraint
531570
# ops in the identity set must be attribute free
532571
_identity_list = [
533572
"log",
534573
"exp",
574+
"sqrt",
575+
"floor",
576+
"ceil",
535577
"sigmoid",
536578
"tanh",
537579
"negative",
@@ -567,7 +609,6 @@ def _mx_shape_array(inputs, attrs):
567609
"Flatten" : _rename(_op.nn.batch_flatten),
568610
# scalar power
569611
"square" : _mx_make_power(2),
570-
"sqrt" : _mx_make_power(1/2),
571612
"rsqrt" : _mx_make_power(-1/2),
572613
"cbrt" : _mx_make_power(1/3),
573614
"rcbrt" : _mx_make_power(-1/3),
@@ -649,11 +690,15 @@ def _mx_shape_array(inputs, attrs):
649690
"batch_dot" : _mx_batch_dot,
650691
"LeakyReLU" : _mx_leaky_relu,
651692
"_arange" : _mx_arange,
693+
"_full" : _mx_full,
652694
"repeat" : _mx_repeat,
653695
"tile" : _mx_tile,
654696
"reverse" : _mx_reverse,
697+
"squeeze" : _mx_squeeze,
698+
"broadcast_axis": _mx_broadcast_axis,
655699
"BlockGrad" : _mx_BlockGrad,
656700
"shape_array" : _mx_shape_array,
701+
"Embedding" : _mx_embedding,
657702
"SoftmaxOutput" : _mx_softmax_output,
658703
"SoftmaxActivation" : _mx_softmax_activation,
659704
# vision

tests/python/frontend/mxnet/test_forward.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ def test_forward_l2_normalize():
379379
mx_sym = mx.sym.L2Normalization(data, mode="channel")
380380
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5))
381381

382-
383382
def test_forward_shape_array():
384383
def verify(shape):
385384
x_np = np.random.uniform(size=shape).astype("float32")
@@ -395,6 +394,75 @@ def verify(shape):
395394
verify((3, 4, 5))
396395
verify((3, 4, 5, 6))
397396

397+
def test_forward_squeeze():
398+
def verify(shape, axis):
399+
x_np = np.random.uniform(size=shape).astype("float32")
400+
if axis is None:
401+
ref_res = mx.nd.squeeze(mx.nd.array(x_np))
402+
mx_sym = mx.sym.squeeze(mx.sym.var("x"))
403+
else:
404+
ref_res = mx.nd.squeeze(mx.nd.array(x_np), axis=axis)
405+
mx_sym = mx.sym.squeeze(mx.sym.var("x"), axis=axis)
406+
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
407+
for target, ctx in ctx_list():
408+
for kind in ["graph", "debug"]:
409+
intrp = relay.create_executor(kind, ctx=ctx, target=target)
410+
op_res = intrp.evaluate(new_sym)(x_np)
411+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
412+
verify((1, 3, 1), None)
413+
verify((1, 3, 1), 0)
414+
verify((1, 3, 1), 2)
415+
verify((1, 3, 1), (0, 2))
416+
417+
def test_forward_broadcast_axis():
418+
def verify(shape, axis, size):
419+
x_np = np.random.uniform(size=shape).astype("float32")
420+
ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size)
421+
mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size)
422+
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
423+
for target, ctx in ctx_list():
424+
for kind in ["graph", "debug"]:
425+
intrp = relay.create_executor(kind, ctx=ctx, target=target)
426+
op_res = intrp.evaluate(new_sym)(x_np)
427+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
428+
verify((1, 2, 1), 2, 3)
429+
verify((1, 2, 1), (0, 2), (2, 3))
430+
431+
def test_forward_full():
432+
def verify(val, shape, dtype):
433+
ctx = mx.cpu()
434+
ref_res = mx.nd.full(shape, val, dtype=dtype)
435+
mx_sym = mx.sym.full(shape, val, dtype=dtype)
436+
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {})
437+
for target, ctx in ctx_list():
438+
# Skip testing graph runtime because this op will be optimized out
439+
# by constant folding.
440+
for kind in ["debug"]:
441+
intrp = relay.create_executor(kind, ctx=ctx, target=target)
442+
op_res = intrp.evaluate(new_sym)()
443+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
444+
verify(2, (3, 4), "float32")
445+
verify(2, (3, 4), "int32")
446+
verify(3.5, (1, 3, 4), "float32")
447+
448+
def test_forward_embedding():
449+
def verify(data_shape, weight_shape):
450+
in_dim, out_dim = weight_shape
451+
x_np = np.random.randint(0, weight_shape[0], size=data_shape).astype("float32")
452+
w_np = np.random.uniform(size=weight_shape).astype("float32")
453+
ref_res = mx.nd.Embedding(mx.nd.array(x_np), mx.nd.array(w_np),
454+
input_dim=in_dim, output_dim=out_dim)
455+
mx_sym = mx.sym.Embedding(mx.sym.var("x"), mx.sym.var("w"),
456+
input_dim=in_dim, output_dim=out_dim)
457+
new_sym, _ = relay.frontend.from_mxnet(
458+
mx_sym, {"x": data_shape, "w": weight_shape})
459+
for target, ctx in ctx_list():
460+
for kind in ["graph", "debug"]:
461+
intrp = relay.create_executor(kind, ctx=ctx, target=target)
462+
op_res = intrp.evaluate(new_sym)(x=x_np, w=w_np)
463+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
464+
verify((2, 2), (4, 5))
465+
verify((2, 3, 4), (4, 5))
398466

399467
if __name__ == '__main__':
400468
test_forward_mlp()
@@ -426,3 +494,7 @@ def verify(shape):
426494
test_forward_slice_axis()
427495
test_forward_l2_normalize()
428496
test_forward_shape_array()
497+
test_forward_squeeze()
498+
test_forward_broadcast_axis()
499+
test_forward_full()
500+
test_forward_embedding()

0 commit comments

Comments (0)