diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index b16b5e28bf34..a4929d0b839d 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -26,6 +26,7 @@
 from ..contrib import graph_runtime as _graph_rt
 from . import ir_pass
 from . import expr as _expr
+from . import ty as _ty
 from .backend import interpreter as _interpreter
 from .backend import graph_runtime_codegen as _graph_gen
 
@@ -427,6 +428,8 @@ def __init__(self, mod, ctx, target):
         self.target = target
 
     def _make_executor(self, func):
+        ret_type = ir_pass.infer_type(func).ret_type
+        num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
         graph_json, mod, params = build(func, target=self.target)
         gmodule = _graph_rt.create(graph_json, mod, self.ctx)
         if params:
@@ -440,7 +443,12 @@ def _graph_wrapper(*args, **kwargs):
             # Run the module, and fetch the output.
             gmodule.run()
             # make a copy so multiple invocation won't hurt perf.
-            return gmodule.get_output(0).copyto(_nd.cpu(0))
+            if num_outputs == 1:
+                return gmodule.get_output(0).copyto(_nd.cpu(0))
+            outputs = []
+            for i in range(num_outputs):
+                outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
+            return outputs
 
         return _graph_wrapper
 
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index f1bf6788ea20..b93bd5b244eb 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -34,6 +34,12 @@
 
 __all__ = ['from_mxnet']
 
+_activation_map = {
+    "sigmoid": _op.sigmoid,
+    "tanh"   : _op.tanh,
+    "relu"   : _op.nn.relu
+}
+
 def _mx_fully_connected(inputs, attrs):
     import mxnet as mx
     units = attrs.get_int("num_hidden")
@@ -66,12 +72,6 @@ def _get_channel_axis(layout, op_name):
 def _mx_activations(inputs, attrs):
     act_type = attrs.get_str("act_type")
     assert len(inputs) == 1
-    if act_type == "sigmoid":
-        return _op.sigmoid(inputs[0])
-    if act_type == "tanh":
-        return _op.tanh(inputs[0])
-    if act_type == "relu":
-        return _op.nn.relu(inputs[0])
     if act_type == "softrelu":
         def _stable_softrelu(x):
             # log(1 + exp(-abs(x))) + relu(x)
@@ -80,8 +80,10 @@ def _stable_softrelu(x):
             return _op.add(_op.log(_op.add(one, exp_neg_abs_x)),
                            _op.nn.relu(x))
         return _stable_softrelu(inputs[0])
-    raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported for frontend MXNet.'.format(act_type))
+    if act_type not in _activation_map:
+        raise tvm.error.OpNotImplemented(
+            'Operator {} is not supported for frontend MXNet.'.format(act_type))
+    return _activation_map[act_type](inputs[0])
 
 
 def _mx_compare(new_op, wrapper):
@@ -189,7 +191,8 @@ def _pool2d(new_op, is_avg):
 def _mx_adaptive_avg_pooling(inputs, attrs):
     output_size = attrs.get_int_tuple("output_size", [])
     if output_size != (1,):
-        raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
+        raise tvm.error.OpAttributeUnimplemented(
+            "AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
     return _op.nn.global_avg_pool2d(inputs[0])
 
 
@@ -471,7 +474,7 @@ def _mx_take(inputs, attrs):
     assert len(inputs) == 2
     mode = attrs.get_str("mode", "clip")
     if mode == "raise":
-        raise RuntimeError("take doesn't support raise mode")
+        raise tvm.error.OpAttributeUnimplemented("take with raise mode is not supported yet")
     axis = attrs.get_int("axis", 0)
     return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)
 
@@ -571,13 +574,13 @@ def _mx_l2_normalize(inputs, attrs):
 def _mx_shape_array(inputs, attrs):
     assert len(inputs) == 1
     if attrs.get_int("lhs_begin", None) is not None:
-        raise RuntimeError("shape_array doesn't support lhs_begin")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_begin")
     if attrs.get_int("lhs_end", None) is not None:
-        raise RuntimeError("shape_array doesn't support lhs_end")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_end")
     if attrs.get_int("rhs_begin", None) is not None:
-        raise RuntimeError("shape_array doesn't support rhs_begin")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_begin")
     if attrs.get_int("rhs_end", None) is not None:
-        raise RuntimeError("shape_array doesn't support rhs_end")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_end")
     return _op.shape_of(inputs[0], dtype='int64')
 
 
@@ -657,6 +660,101 @@ def _mx_argsort(inputs, attrs):
     return _op.argsort(inputs[0], **new_attrs)
 
 
+def _mx_rnn_param_concat(inputs, _):
+    # We don't need to concatenate RNN params because we will unravel the RNN op
+    return [inputs]
+
+
+def _mx_rnn_layer(inputs, attrs):
+    def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activation):
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        out = _activation_map[activation](i2h + h2h)
+        return out, [out]
+
+    def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
+        dtype = ir_pass.infer_type(data).checked_type.dtype
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1)
+        h2h_r, h2h_z, h2h = _op.split(h2h, indices_or_sections=3, axis=1)
+        reset_gate = _activation_map["sigmoid"](i2h_r + h2h_r)
+        update_gate = _activation_map["sigmoid"](i2h_z + h2h_z)
+        next_h_tmp = _activation_map["tanh"](reset_gate * h2h + i2h)
+        next_h = (_expr.const(1, dtype) - update_gate) * next_h_tmp + update_gate * states[0]
+        return next_h, [next_h]
+
+    def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        gates = i2h + h2h
+        slice_gates = _op.split(gates, indices_or_sections=4, axis=1)
+        in_gate = _activation_map["sigmoid"](slice_gates[0])
+        forget_gate = _activation_map["sigmoid"](slice_gates[1])
+        in_transform = _activation_map["tanh"](slice_gates[2])
+        out_gate = _activation_map["sigmoid"](slice_gates[3])
+        next_c = forget_gate * states[1] + in_gate * in_transform
+        next_h = out_gate * _activation_map["tanh"](next_c)
+        return next_h, [next_h, next_c]
+
+    num_layers = attrs.get_int("num_layers", 1)
+    mode = attrs.get_str("mode")
+    if mode.startswith("rnn"):
+        mode, activation = mode.split('_')
+    assert mode in ["rnn", "gru", "lstm"]
+    bidirectional = attrs.get_bool("bidirectional", False)
+    if bidirectional:
+        raise tvm.error.OpAttributeUnimplemented(
+            "Bidirectional RNN op is not supported yet")
+    layout = attrs.get_str("layout", "TNC")
+    if layout != "TNC":
+        raise tvm.error.OpAttributeUnimplemented(
+            "RNN with layout other than TNC is not supported yet")
+    num_states = 2 if mode == 'lstm' else 1
+    assert len(inputs) == num_states + 2
+
+    seq_data = inputs[0]
+    concat_weight = inputs[1]
+    concat_states = inputs[2:]
+    seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
+    assert len(concat_weight) == num_layers * 4
+
+    weights = []
+    bias = []
+    states = []
+    for i in range(num_layers):
+        w = []
+        b = []
+        s = []
+        for j in range(2):
+            w.append(concat_weight[i*2 + j].args[0])
+            b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
+        for state in concat_states:
+            s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
+        weights.append(w)
+        bias.append(b)
+        states.append(s)
+
+    seq_output = []
+    for t in range(seq_len):
+        data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
+        for l in range(num_layers):
+            if mode == "rnn":
+                out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
+            elif mode == "gru":
+                out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
+            else:  # mode == "lstm"
+                out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
+            states[l] = new_states
+            data = out
+        seq_output.append(out)
+
+    outputs = [_op.stack(seq_output, axis=0)]
+    for i in range(num_states):
+        outputs.append(_op.stack([s[i] for s in states], axis=0))
+    return outputs
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
 _identity_list = [
@@ -807,6 +905,9 @@ def _mx_argsort(inputs, attrs):
     "_contrib_box_nms" : _mx_box_nms,
     "_contrib_DeformableConvolution" : _mx_deformable_convolution,
     "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
+    # NLP
+    "RNN" : _mx_rnn_layer,
+    "_rnn_param_concat" : _mx_rnn_param_concat,
     # List of missing operators that are present in NNVMv1
     # TODO(tvm-tvm): support all operators.
     #
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index d00efb39e16f..067c356830bb 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -527,6 +527,54 @@ def test_forward_bilinear_resize():
     mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10)
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
+def test_forward_rnn_layer():
+    def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
+        if mode == "rnn":
+            layer = gluon.rnn.RNN(hidden_size, num_layers)
+        elif mode == "gru":
+            layer = gluon.rnn.GRU(hidden_size, num_layers)
+        else:  # mode == "lstm"
+            layer = gluon.rnn.LSTM(hidden_size, num_layers)
+        num_states = 2 if mode == "lstm" else 1
+        layer.initialize()
+
+        dtype = "float32"
+        data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
+        states_np = []
+        states_mx = []
+        shape_dict = {'data0': data_np.shape}
+        inputs = {'data0': data_np}
+        for i in range(num_states):
+            s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+            states_np.append(s)
+            states_mx.append(mx.nd.array(s))
+            shape_dict['data%s' % (i+1)] = s.shape
+            inputs['data%s' % (i+1)] = s
+
+        layer.hybridize()
+        mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
+        mx_res = [mx_out] + mx_states
+        mx_sym = layer._cached_graph[1]
+        mx_params = {}
+        for name, param in layer.collect_params().items():
+            mx_params[name] = param._reduce()
+
+        new_sym, params = relay.frontend.from_mxnet(
+            mx_sym, shape=shape_dict, arg_params=mx_params)
+        for target, ctx in ctx_list():
+            # only test graph runtime because debug runtime is too slow
+            for kind in ["graph"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(**inputs, **params)
+                assert len(op_res) == len(mx_res)
+                for i, val in enumerate(op_res):
+                    tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+
+    for mode in ["rnn", "gru", "lstm"]:
+        verify(mode, 64, 10, 64, 1)
+        verify(mode, 64, 10, 64, 2)
+        verify(mode, 64, 10, 32, 2)
+
 if __name__ == '__main__':
     test_forward_mlp()
@@ -566,3 +614,4 @@ def test_forward_bilinear_resize():
     test_forward_take()
     test_forward_gather_nd()
     test_forward_bilinear_resize()
+    test_forward_rnn_layer()
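
Usage sketch (reviewer note, not part of the patch): the snippet below is a minimal sketch of how the new "RNN" handler and the multi-output graph-executor path are expected to be driven from user code. It mirrors test_forward_rnn_layer above; the LSTM sizes, the "llvm" target, and CPU context are illustrative assumptions rather than anything mandated by this change.

# Minimal sketch, not part of the patch: convert a Gluon LSTM through the new
# "RNN" converter and run it with the Relay graph executor, which now returns
# a list when the function produces a tuple (sequence output, final h, final c).
import numpy as np
import mxnet as mx
from mxnet import gluon
import tvm
from tvm import relay

seq_len, batch, input_size, hidden_size, num_layers = 10, 1, 64, 64, 1  # illustrative sizes

layer = gluon.rnn.LSTM(hidden_size, num_layers)
layer.initialize()
layer.hybridize()

data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype("float32")
states_np = [np.zeros((num_layers, batch, hidden_size), dtype="float32") for _ in range(2)]

# Run once so Gluon caches the symbolic graph, then pull it out for conversion,
# the same way the test above does.
layer(mx.nd.array(data_np), [mx.nd.array(s) for s in states_np])
mx_sym = layer._cached_graph[1]
mx_params = {name: param._reduce() for name, param in layer.collect_params().items()}

shape_dict = {'data0': data_np.shape,
              'data1': states_np[0].shape,
              'data2': states_np[1].shape}
new_sym, params = relay.frontend.from_mxnet(mx_sym, shape=shape_dict, arg_params=mx_params)

intrp = relay.create_executor("graph", ctx=tvm.cpu(), target="llvm")
outputs = intrp.evaluate(new_sym)(data0=data_np, data1=states_np[0], data2=states_np[1], **params)
# outputs[0]: stacked per-step hidden states; outputs[1] / outputs[2]: final h and c.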