
__all__ = ['from_mxnet']

+_activation_map = {
+    "sigmoid": _op.sigmoid,
+    "tanh": _op.tanh,
+    "relu": _op.nn.relu
+}
+
def _mx_fully_connected(inputs, attrs):
    import mxnet as mx
    units = attrs.get_int("num_hidden")
@@ -66,12 +72,6 @@ def _get_channel_axis(layout, op_name):
def _mx_activations(inputs, attrs):
    act_type = attrs.get_str("act_type")
    assert len(inputs) == 1
-    if act_type == "sigmoid":
-        return _op.sigmoid(inputs[0])
-    if act_type == "tanh":
-        return _op.tanh(inputs[0])
-    if act_type == "relu":
-        return _op.nn.relu(inputs[0])
    if act_type == "softrelu":
        def _stable_softrelu(x):
            # log(1 + exp(-abs(x))) + relu(x)
@@ -80,8 +80,10 @@ def _stable_softrelu(x):
            return _op.add(_op.log(_op.add(one, exp_neg_abs_x)),
                           _op.nn.relu(x))
        return _stable_softrelu(inputs[0])
-    raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported for frontend MXNet.'.format(act_type))
+    if act_type not in _activation_map:
+        raise tvm.error.OpNotImplemented(
+            'Operator {} is not supported for frontend MXNet.'.format(act_type))
+    return _activation_map[act_type](inputs[0])


def _mx_compare(new_op, wrapper):
@@ -189,7 +191,8 @@ def _pool2d(new_op, is_avg):
def _mx_adaptive_avg_pooling(inputs, attrs):
    output_size = attrs.get_int_tuple("output_size", [])
    if output_size != (1,):
-        raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
+        raise tvm.error.OpAttributeUnimplemented(
+            "AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
    return _op.nn.global_avg_pool2d(inputs[0])


@@ -471,7 +474,7 @@ def _mx_take(inputs, attrs):
    assert len(inputs) == 2
    mode = attrs.get_str("mode", "clip")
    if mode == "raise":
-        raise RuntimeError("take doesn't support raise mode")
+        raise tvm.error.OpAttributeUnimplemented("take with raise mode is not supported yet")
    axis = attrs.get_int("axis", 0)
    return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)

@@ -571,13 +574,13 @@ def _mx_l2_normalize(inputs, attrs):
def _mx_shape_array(inputs, attrs):
    assert len(inputs) == 1
    if attrs.get_int("lhs_begin", None) is not None:
-        raise RuntimeError("shape_array doesn't support lhs_begin")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_begin")
    if attrs.get_int("lhs_end", None) is not None:
-        raise RuntimeError("shape_array doesn't support lhs_end")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support lhs_end")
    if attrs.get_int("rhs_begin", None) is not None:
-        raise RuntimeError("shape_array doesn't support rhs_begin")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_begin")
    if attrs.get_int("rhs_end", None) is not None:
-        raise RuntimeError("shape_array doesn't support rhs_end")
+        raise tvm.error.OpAttributeUnimplemented("shape_array doesn't support rhs_end")
    return _op.shape_of(inputs[0], dtype='int64')


@@ -657,6 +660,101 @@ def _mx_argsort(inputs, attrs):
    return _op.argsort(inputs[0], **new_attrs)


+def _mx_rnn_param_concat(inputs, _):
+    # We don't need to concatenate RNN params because we will unravel the RNN op
+    return [inputs]
+
+
+def _mx_rnn_layer(inputs, attrs):
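+    # Vanilla RNN cell: next_h = activation(dense(x, W_i2h) + b_i2h + dense(h, W_h2h) + b_h2h).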
+    def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activation):
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        out = _activation_map[activation](i2h + h2h)
+        return out, [out]
+
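+    # GRU gating as computed below: r = sigmoid(i2h_r + h2h_r), z = sigmoid(i2h_z + h2h_z),
+    # h_tmp = tanh(r * h2h + i2h), next_h = (1 - z) * h_tmp + z * prev_h.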
+    def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
+        dtype = ir_pass.infer_type(data).checked_type.dtype
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1)
+        h2h_r, h2h_z, h2h = _op.split(h2h, indices_or_sections=3, axis=1)
+        reset_gate = _activation_map["sigmoid"](i2h_r + h2h_r)
+        update_gate = _activation_map["sigmoid"](i2h_z + h2h_z)
+        next_h_tmp = _activation_map["tanh"](reset_gate * h2h + i2h)
+        next_h = (_expr.const(1, dtype) - update_gate) * next_h_tmp + update_gate * states[0]
+        return next_h, [next_h]
+
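+    # LSTM gating as computed below: the fused gates are split into input, forget,
+    # cell-candidate (tanh) and output gates; next_c = f * c + i * g, next_h = o * tanh(next_c).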
+    def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
+        i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1)
+        h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1)
+        gates = i2h + h2h
+        slice_gates = _op.split(gates, indices_or_sections=4, axis=1)
+        in_gate = _activation_map["sigmoid"](slice_gates[0])
+        forget_gate = _activation_map["sigmoid"](slice_gates[1])
+        in_transform = _activation_map["tanh"](slice_gates[2])
+        out_gate = _activation_map["sigmoid"](slice_gates[3])
+        next_c = forget_gate * states[1] + in_gate * in_transform
+        next_h = out_gate * _activation_map["tanh"](next_c)
+        return next_h, [next_h, next_c]
+
+    num_layers = attrs.get_int("num_layers", 1)
+    mode = attrs.get_str("mode")
+    if mode.startswith("rnn"):
+        mode, activation = mode.split('_')
+    assert mode in ["rnn", "gru", "lstm"]
+    bidirectional = attrs.get_bool("bidirectional", False)
+    if bidirectional:
+        raise tvm.error.OpAttributeUnimplemented(
+            "Bidirectional RNN op is not supported yet")
+    layout = attrs.get_str("layout", "TNC")
+    if layout != "TNC":
+        raise tvm.error.OpAttributeUnimplemented(
+            "RNN with layout other than TNC is not supported yet")
+    num_states = 2 if mode == 'lstm' else 1
+    assert len(inputs) == num_states + 2
+
+    seq_data = inputs[0]
+    concat_weight = inputs[1]
+    concat_states = inputs[2:]
+    seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
+    assert len(concat_weight) == num_layers * 4
+
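+    # The packed RNN parameters are assumed to hold the i2h/h2h weights for every
+    # layer first, followed by the matching i2h/h2h biases; the indexing below
+    # relies on this layout.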
+    weights = []
+    bias = []
+    states = []
+    for i in range(num_layers):
+        w = []
+        b = []
+        s = []
+        for j in range(2):
+            w.append(concat_weight[i*2 + j].args[0])
+            b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
+        for state in concat_states:
+            s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
+        weights.append(w)
+        bias.append(b)
+        states.append(s)
+
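+    # Unroll the RNN over the time axis; at each step the output of one layer is
+    # fed as the input to the next layer, while states are carried across time steps.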
+    seq_output = []
+    for t in range(seq_len):
+        data = _op.take(seq_data, _expr.const(t, "int32"), axis=0)
+        for l in range(num_layers):
+            if mode == "rnn":
+                out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation)
+            elif mode == "gru":
+                out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l])
+            else:  # mode == "lstm"
+                out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l])
+            states[l] = new_states
+            data = out
+        seq_output.append(out)
+
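+    # Return the per-step outputs stacked over time, followed by the final hidden
+    # (and, for LSTM, cell) states stacked over layers.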
+    outputs = [_op.stack(seq_output, axis=0)]
+    for i in range(num_states):
+        outputs.append(_op.stack([s[i] for s in states], axis=0))
+    return outputs
+
+
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
@@ -807,6 +905,9 @@ def _mx_argsort(inputs, attrs):
807905 "_contrib_box_nms" : _mx_box_nms ,
808906 "_contrib_DeformableConvolution" : _mx_deformable_convolution ,
809907 "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling ,
908+ # NLP
909+ "RNN" : _mx_rnn_layer ,
910+ "_rnn_param_concat" : _mx_rnn_param_concat ,
810911 # List of missing operators that are present in NNVMv1
811912 # TODO(tvm-tvm): support all operators.
812913 #