@@ -149,7 +149,7 @@ def _mx_conv2d_transpose(inputs, attrs):
149149 new_attrs ["groups" ] = attrs .get_int ("num_group" , 1 )
150150 new_attrs ["data_layout" ] = data_layout
151151 new_attrs ["kernel_layout" ] = kernel_layout
152- use_bias = not attrs .get_bool ("no_bias" , False )
152+ use_bias = not attrs .get_bool ("no_bias" , True )
153153 res = _op .nn .conv2d_transpose (inputs [0 ], inputs [1 ], ** new_attrs )
154154
155155 if use_bias :
@@ -277,6 +277,28 @@ def _mx_slice_axis(inputs, attrs):
277277 return _op .strided_slice (inputs [0 ], begin , end )
278278
279279
280+ def _mx_crop_like (inputs , attrs ):
281+ if len (inputs ) < 2 :
282+ raise tvm .error .OpAttributeUnimplemented (
283+ "Only support crop_like pattern for operator Crop." )
284+ if attrs .get_bool ("center_crop" , False ):
285+ raise tvm .error .OpAttributeUnimplemented (
286+ "Center crop is not supported in operator Crop." )
287+ if attrs .get_int_tuple ("h_w" , (0 , 0 )) != (0 , 0 ):
288+ raise tvm .error .OpAttributeUnimplemented (
289+ "Doesn't support h_w in operator Crop." )
290+ offset = attrs .get_int_tuple ("offset" , (0 , 0 ))
291+ new_attrs = {}
292+ if offset == (0 , 0 ):
293+ new_attrs ["axes" ] = (2 , 3 )
294+ return _op .slice_like (* inputs , ** new_attrs )
295+ like_shape = ir_pass .infer_type (inputs [1 ]).checked_type .shape
296+ new_attrs ['begin' ] = [0 , 0 , offset [0 ], offset [1 ]]
297+ new_attrs ['end' ] = [like_shape [0 ], like_shape [1 ], offset [0 ]+ like_shape [2 ],
298+ offset [1 ]+ like_shape [3 ]]
299+ return _op .strided_slice (inputs [0 ], ** new_attrs )
300+
301+
280302def _mx_split (inputs , attrs ):
281303 axis = attrs .get_int ("axis" , 1 )
282304 new_attrs = {}
@@ -300,6 +322,10 @@ def _mx_softmax_output(inputs, attrs):
300322 return _op .nn .softmax (inputs [0 ])
301323
302324
325+ def _mx_linear_regression_output (inputs , _ ):
326+ return inputs [0 ]
327+
328+
303329def _mx_concat (inputs , attrs ):
304330 axis = attrs .get_int ("dim" , 1 )
305331 return _op .concatenate (tuple (inputs ), axis = axis )
@@ -890,6 +916,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
890916 "argsort" : _mx_argsort ,
891917 "SoftmaxOutput" : _mx_softmax_output ,
892918 "SoftmaxActivation" : _mx_softmax_activation ,
919+ "LinearRegressionOutput" : _mx_linear_regression_output ,
893920 "smooth_l1" : _mx_smooth_l1 ,
894921 # vision
895922 "_contrib_BilinearResize2D" : _mx_resize ,
@@ -905,11 +932,12 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
905932 # NLP
906933 "RNN" : _mx_rnn_layer ,
907934 "_rnn_param_concat" : _mx_rnn_param_concat ,
935+ # Depricated:
936+ "Crop" : _mx_crop_like ,
908937 # List of missing operators that are present in NNVMv1
909938 # TODO(tvm-tvm): support all operators.
910939 #
911940 # "broadcast_to",
912- # "Crop" : _crop_like,
913941}
914942
915943# set identity list
0 commit comments