
Commit 0542b24

icemelon authored and wweic committed
Add MXNet converter for RNN layer ops (apache#3125)
1 parent 5242c1c commit 0542b24
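
At a high level, this change teaches the Relay MXNet frontend to unravel MXNet's fused RNN op into per-timestep cells (RNN, GRU, LSTM) and extends the graph executor so that a function returning a tuple (layer output plus final states) yields all of its outputs. A condensed usage sketch, mirroring the new test added below; the two-layer LSTM and the shapes are illustrative, not prescribed by the commit:

    import mxnet as mx
    from mxnet import gluon
    from tvm import relay

    # Build and hybridize a Gluon LSTM so that its symbolic graph gets cached.
    layer = gluon.rnn.LSTM(64, 2)   # hidden_size=64, num_layers=2
    layer.initialize()
    layer.hybridize()
    data = mx.nd.random.uniform(shape=(10, 1, 64))               # TNC: (seq_len, batch, input_size)
    states = [mx.nd.zeros((2, 1, 64)), mx.nd.zeros((2, 1, 64))]  # initial h and c
    layer(data, states)             # first call records layer._cached_graph

    # Import the cached symbol through the new RNN converter.
    mx_sym = layer._cached_graph[1]
    mx_params = {k: v._reduce() for k, v in layer.collect_params().items()}
    shape_dict = {"data0": (10, 1, 64), "data1": (2, 1, 64), "data2": (2, 1, 64)}
    func, params = relay.frontend.from_mxnet(mx_sym, shape=shape_dict, arg_params=mx_params)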

File tree: 3 files changed (+173, -15)

  python/tvm/relay/build_module.py
  python/tvm/relay/frontend/mxnet.py
  tests/python/frontend/mxnet/test_forward.py

python/tvm/relay/build_module.py

Lines changed: 9 additions & 1 deletion
@@ -26,6 +26,7 @@
 from ..contrib import graph_runtime as _graph_rt
 from . import ir_pass
 from . import expr as _expr
+from . import ty as _ty
 from .backend import interpreter as _interpreter
 from .backend import graph_runtime_codegen as _graph_gen
 
@@ -427,6 +428,8 @@ def __init__(self, mod, ctx, target):
         self.target = target
 
     def _make_executor(self, func):
+        ret_type = ir_pass.infer_type(func).ret_type
+        num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
         graph_json, mod, params = build(func, target=self.target)
         gmodule = _graph_rt.create(graph_json, mod, self.ctx)
         if params:
@@ -440,7 +443,12 @@ def _graph_wrapper(*args, **kwargs):
             # Run the module, and fetch the output.
             gmodule.run()
             # make a copy so multiple invocation won't hurt perf.
-            return gmodule.get_output(0).copyto(_nd.cpu(0))
+            if num_outputs == 1:
+                return gmodule.get_output(0).copyto(_nd.cpu(0))
+            outputs = []
+            for i in range(num_outputs):
+                outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
+            return outputs
 
         return _graph_wrapper
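
The executor change above means relay.create_executor("graph", ...) can now hand back every output of a function whose return type is a tuple. A minimal sketch of the new behaviour with a toy two-output function (the function itself is made up for illustration, not part of the commit):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2))
    # A function whose body is a tuple; previously only output 0 came back.
    func = relay.Function([x], relay.Tuple([x + relay.const(1.0), x * relay.const(2.0)]))

    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    res = intrp.evaluate(func)(np.ones((2, 2), dtype="float32"))
    # res is now a list of two NDArrays instead of a single NDArray.
    print(res[0].asnumpy(), res[1].asnumpy())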

python/tvm/relay/frontend/mxnet.py

Lines changed: 115 additions & 14 deletions
@@ -34,6 +34,12 @@
 
 __all__ = ['from_mxnet']
 
+_activation_map = {
+    "sigmoid": _op.sigmoid,
+    "tanh"   : _op.tanh,
+    "relu"   : _op.nn.relu
+}
+
 def _mx_fully_connected(inputs, attrs):
     import mxnet as mx
     units = attrs.get_int("num_hidden")
@@ -66,12 +72,6 @@ def _get_channel_axis(layout, op_name):
 def _mx_activations(inputs, attrs):
     act_type = attrs.get_str("act_type")
     assert len(inputs) == 1
-    if act_type == "sigmoid":
-        return _op.sigmoid(inputs[0])
-    if act_type == "tanh":
-        return _op.tanh(inputs[0])
-    if act_type == "relu":
-        return _op.nn.relu(inputs[0])
     if act_type == "softrelu":
         def _stable_softrelu(x):
             # log(1 + exp(-abs(x))) + relu(x)
@@ -80,8 +80,10 @@ def _stable_softrelu(x):
             return _op.add(_op.log(_op.add(one, exp_neg_abs_x)),
                            _op.nn.relu(x))
         return _stable_softrelu(inputs[0])
-    raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported for frontend MXNet.'.format(act_type))
+    if act_type not in _activation_map:
+        raise tvm.error.OpNotImplemented(
+            'Operator {} is not supported for frontend MXNet.'.format(act_type))
+    return _activation_map[act_type](inputs[0])
 
 
 def _mx_compare(new_op, wrapper):
@@ -189,7 +191,8 @@ def _pool2d(new_op, is_avg):
 def _mx_adaptive_avg_pooling(inputs, attrs):
     output_size = attrs.get_int_tuple("output_size", [])
     if output_size != (1,):
-        raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
+        raise tvm.error.OpAttributeUnimplemented(
+            "AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
     return _op.nn.global_avg_pool2d(inputs[0])
 
 
@@ -471,7 +474,7 @@ def _mx_take(inputs, attrs):
     assert len(inputs) == 2
     mode = attrs.get_str("mode", "clip")
     if mode == "raise":
-        raise RuntimeError("take doesn't support raise mode")
+        raise tvm.error.OpAttributeUnimplemented("take with raise mode is not supported yet")
     axis = attrs.get_int("axis", 0)
     return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
 
@@ -571,13 +574,13 @@ def _mx_l2_normalize(inputs, attrs):
 def _mx_shape_array(inputs, attrs):
     assert len(inputs) == 1
     if attrs.get_int("lhs_begin", None) is not None:
-        raise RuntimeError("shape_array doesn't support lhs_begin")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_begin")
     if attrs.get_int("lhs_end", None) is not None:
-        raise RuntimeError("shape_array doesn't support lhs_end")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_end")
     if attrs.get_int("rhs_begin", None) is not None:
-        raise RuntimeError("shape_array doesn't support rhs_begin")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_begin")
     if attrs.get_int("rhs_end", None) is not None:
-        raise RuntimeError("shape_array doesn't support rhs_end")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_end")
     return _op.shape_of(inputs[0], dtype='int64')
 
 
@@ -657,6 +660,101 @@ def _mx_argsort(inputs, attrs):
     return _op.argsort(inputs[0], **new_attrs)
 
 
+def _mx_rnn_param_concat(inputs, _):
+    # We don't need to concatenate RNN params because we will unravel the RNN op
+    return [inputs]
+
+
+def _mx_rnn_layer(inputs, attrs):
+    def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activation):
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        out = _activation_map[activation](i2h + h2h)
+        return out, [out]
+
+    def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
+        dtype = ir_pass.infer_type(data).checked_type.dtype
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1)
+        h2h_r, h2h_z, h2h = _op.split(h2h, indices_or_sections=3, axis=1)
+        reset_gate = _activation_map["sigmoid"](i2h_r + h2h_r)
+        update_gate = _activation_map["sigmoid"](i2h_z + h2h_z)
+        next_h_tmp = _activation_map["tanh"](reset_gate * h2h + i2h)
+        next_h = (_expr.const(1, dtype) - update_gate) * next_h_tmp + update_gate * states[0]
+        return next_h, [next_h]
+
+    def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        gates = i2h + h2h
+        slice_gates = _op.split(gates, indices_or_sections=4, axis=1)
+        in_gate = _activation_map["sigmoid"](slice_gates[0])
+        forget_gate = _activation_map["sigmoid"](slice_gates[1])
+        in_transform = _activation_map["tanh"](slice_gates[2])
+        out_gate = _activation_map["sigmoid"](slice_gates[3])
+        next_c = forget_gate * states[1] + in_gate * in_transform
+        next_h = out_gate * _activation_map["tanh"](next_c)
+        return next_h, [next_h, next_c]
+
+    num_layers = attrs.get_int("num_layers", 1)
+    mode = attrs.get_str("mode")
+    if mode.startswith("rnn"):
+        mode, activation = mode.split('_')
+    assert mode in ["rnn", "gru", "lstm"]
+    bidirectional = attrs.get_bool("bidirectional", False)
+    if bidirectional:
+        raise tvm.error.OpAttributeUnimplemented(
+            "Bidirectional RNN op is not supported yet")
+    layout = attrs.get_str("layout", "TNC")
+    if layout != "TNC":
+        raise tvm.error.OpAttributeUnimplemented(
+            "RNN with layout other than TNC is not supported yet")
+    num_states = 2 if mode == 'lstm' else 1
+    assert len(inputs) == num_states + 2
+
+    seq_data = inputs[0]
+    concat_weight = inputs[1]
+    concat_states = inputs[2:]
+    seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
+    assert len(concat_weight) == num_layers * 4
+
+    weights = []
+    bias = []
+    states = []
+    for i in range(num_layers):
+        w = []
+        b = []
+        s = []
+        for j in range(2):
+            w.append(concat_weight[i*2 + j].args[0])
+            b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
+        for state in concat_states:
+            s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
+        weights.append(w)
+        bias.append(b)
+        states.append(s)
+
+    seq_output = []
+    for t in range(seq_len):
+        data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
+        for l in range(num_layers):
+            if mode == "rnn":
+                out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
+            elif mode == "gru":
+                out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
+            else: # mode == "lstm"
+                out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
+            states[l] = new_states
+            data = out
+        seq_output.append(out)
+
+    outputs = [_op.stack(seq_output, axis=0)]
+    for i in range(num_states):
+        outputs.append(_op.stack([s[i] for s in states], axis=0))
+    return outputs
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
 _identity_list = [
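
For reference, the cell updates implemented by _gru_cell and _lstm_cell above are the standard unidirectional GRU and LSTM recurrences; writing i2h = W_i x_t + b_i and h2h = W_h h_{t-1} + b_h for the per-gate slices, the code computes (my reading of the hunk, not spelled out in the commit):

    GRU:   r_t = \sigma(i2h_r + h2h_r),  z_t = \sigma(i2h_z + h2h_z)
           n_t = \tanh(i2h_n + r_t \odot h2h_n)
           h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}

    LSTM:  i_t = \sigma(g_0),  f_t = \sigma(g_1),  g_t = \tanh(g_2),  o_t = \sigma(g_3)
           c_t = f_t \odot c_{t-1} + i_t \odot g_t
           h_t = o_t \odot \tanh(c_t)

where g_0, g_1, g_2, g_3 are the four equal slices of i2h + h2h.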
@@ -807,6 +905,9 @@ def _mx_argsort(inputs, attrs):
     "_contrib_box_nms" : _mx_box_nms,
     "_contrib_DeformableConvolution" : _mx_deformable_convolution,
     "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
+    # NLP
+    "RNN" : _mx_rnn_layer,
+    "_rnn_param_concat" : _mx_rnn_param_concat,
     # List of missing operators that are present in NNVMv1
     # TODO(tvm-tvm): support all operators.
     #
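
One detail worth calling out: _mx_rnn_layer expects the flattened _rnn_param_concat inputs in a fixed order (all per-layer i2h/h2h weights first, then the biases in the same order), which is what the concat_weight[i*2 + j] and concat_weight[num_layers*2 + i*2 + j] indexing encodes, and why the converter asserts len(concat_weight) == num_layers * 4. A tiny sketch that spells the ordering out (hypothetical helper, not part of the commit):

    # Reproduces the index arithmetic used by _mx_rnn_layer above.
    def rnn_param_order(num_layers):
        order = []
        for i in range(num_layers):
            order += ["layer%d_i2h_weight" % i, "layer%d_h2h_weight" % i]
        for i in range(num_layers):
            order += ["layer%d_i2h_bias" % i, "layer%d_h2h_bias" % i]
        return order

    print(rnn_param_order(2))
    # ['layer0_i2h_weight', 'layer0_h2h_weight', 'layer1_i2h_weight', 'layer1_h2h_weight',
    #  'layer0_i2h_bias', 'layer0_h2h_bias', 'layer1_i2h_bias', 'layer1_h2h_bias']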

tests/python/frontend/mxnet/test_forward.py

Lines changed: 49 additions & 0 deletions
@@ -527,6 +527,54 @@ def test_forward_bilinear_resize():
     mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10)
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
+def test_forward_rnn_layer():
+    def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
+        if mode == "rnn":
+            layer = gluon.rnn.RNN(hidden_size, num_layers)
+        elif mode == "gru":
+            layer = gluon.rnn.GRU(hidden_size, num_layers)
+        else: # mode == "lstm"
+            layer = gluon.rnn.LSTM(hidden_size, num_layers)
+        num_states = 2 if mode == "lstm" else 1
+        layer.initialize()
+
+        dtype = "float32"
+        data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
+        states_np = []
+        states_mx = []
+        shape_dict = {'data0': data_np.shape}
+        inputs = {'data0': data_np}
+        for i in range(num_states):
+            s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+            states_np.append(s)
+            states_mx.append(mx.nd.array(s))
+            shape_dict['data%s' % (i+1)] = s.shape
+            inputs['data%s' % (i+1)] = s
+
+        layer.hybridize()
+        mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
+        mx_res = [mx_out] + mx_states
+        mx_sym = layer._cached_graph[1]
+        mx_params = {}
+        for name, param in layer.collect_params().items():
+            mx_params[name] = param._reduce()
+
+        new_sym, params = relay.frontend.from_mxnet(
+            mx_sym, shape=shape_dict, arg_params=mx_params)
+        for target, ctx in ctx_list():
+            # only test graph runtime because debug runtime is too slow
+            for kind in ["graph"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(**inputs, **params)
+                assert len(op_res) == len(mx_res)
+                for i, val in enumerate(op_res):
+                    tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+
+    for mode in ["rnn", "gru", "lstm"]:
+        verify(mode, 64, 10, 64, 1)
+        verify(mode, 64, 10, 64, 2)
+        verify(mode, 64, 10, 32, 2)
+
 
 if __name__ == '__main__':
     test_forward_mlp()

@@ -566,3 +614,4 @@ def test_forward_bilinear_resize():
     test_forward_take()
     test_forward_gather_nd()
     test_forward_bilinear_resize()
+    test_forward_rnn_layer()