Skip to content

Commit 0586d44

Browse files
icemelonwweic
authored andcommitted
[Relay][Frontend] Add Crop op converter (apache#3241)
* Add Crop op converter * lint * x
1 parent 82b29db commit 0586d44

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

nnvm/python/nnvm/frontend/mxnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def _crop_like(inputs, attrs):
269269
raise tvm.error.OpAttributeUnimplemented(
270270
'Center crop is not supported in operator crop_like.')
271271
if len(inputs) < 2:
272-
raise RuntimeError("Only support crop_like pattern.")
272+
raise tvm.error.OpAttributeUnimplemented("Only support crop_like pattern.")
273273
new_attrs["axis"] = [2, 3]
274274
return get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs)
275275

python/tvm/relay/frontend/mxnet.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def _mx_conv2d_transpose(inputs, attrs):
149149
new_attrs["groups"] = attrs.get_int("num_group", 1)
150150
new_attrs["data_layout"] = data_layout
151151
new_attrs["kernel_layout"] = kernel_layout
152-
use_bias = not attrs.get_bool("no_bias", False)
152+
use_bias = not attrs.get_bool("no_bias", True)
153153
res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs)
154154

155155
if use_bias:
@@ -277,6 +277,28 @@ def _mx_slice_axis(inputs, attrs):
277277
return _op.strided_slice(inputs[0], begin, end)
278278

279279

280+
def _mx_crop_like(inputs, attrs):
281+
if len(inputs) < 2:
282+
raise tvm.error.OpAttributeUnimplemented(
283+
"Only support crop_like pattern for operator Crop.")
284+
if attrs.get_bool("center_crop", False):
285+
raise tvm.error.OpAttributeUnimplemented(
286+
"Center crop is not supported in operator Crop.")
287+
if attrs.get_int_tuple("h_w", (0, 0)) != (0, 0):
288+
raise tvm.error.OpAttributeUnimplemented(
289+
"Doesn't support h_w in operator Crop.")
290+
offset = attrs.get_int_tuple("offset", (0, 0))
291+
new_attrs = {}
292+
if offset == (0, 0):
293+
new_attrs["axes"] = (2, 3)
294+
return _op.slice_like(*inputs, **new_attrs)
295+
like_shape = ir_pass.infer_type(inputs[1]).checked_type.shape
296+
new_attrs['begin'] = [0, 0, offset[0], offset[1]]
297+
new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
298+
offset[1]+like_shape[3]]
299+
return _op.strided_slice(inputs[0], **new_attrs)
300+
301+
280302
def _mx_split(inputs, attrs):
281303
axis = attrs.get_int("axis", 1)
282304
new_attrs = {}
@@ -300,6 +322,10 @@ def _mx_softmax_output(inputs, attrs):
300322
return _op.nn.softmax(inputs[0])
301323

302324

325+
def _mx_linear_regression_output(inputs, _):
326+
return inputs[0]
327+
328+
303329
def _mx_concat(inputs, attrs):
304330
axis = attrs.get_int("dim", 1)
305331
return _op.concatenate(tuple(inputs), axis=axis)
@@ -890,6 +916,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
890916
"argsort" : _mx_argsort,
891917
"SoftmaxOutput" : _mx_softmax_output,
892918
"SoftmaxActivation" : _mx_softmax_activation,
919+
"LinearRegressionOutput" : _mx_linear_regression_output,
893920
"smooth_l1" : _mx_smooth_l1,
894921
# vision
895922
"_contrib_BilinearResize2D" : _mx_resize,
@@ -905,11 +932,12 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
905932
# NLP
906933
"RNN" : _mx_rnn_layer,
907934
"_rnn_param_concat" : _mx_rnn_param_concat,
935+
# Depricated:
936+
"Crop" : _mx_crop_like,
908937
# List of missing operators that are present in NNVMv1
909938
# TODO(tvm-tvm): support all operators.
910939
#
911940
# "broadcast_to",
912-
# "Crop" : _crop_like,
913941
}
914942

915943
# set identity list

tests/python/frontend/mxnet/test_forward.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,31 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
583583
verify(mode, 64, 10, 64, 2)
584584
verify(mode, 64, 10, 32, 2)
585585

586+
def test_forward_Crop():
587+
def verify(xshape, yshape, offset=None):
588+
x_data = np.random.uniform(size=xshape).astype("float32")
589+
y_data = np.random.uniform(size=yshape).astype("float32")
590+
if offset is None:
591+
mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"))
592+
ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data))
593+
else:
594+
mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"), offset=offset)
595+
ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data), offset=offset)
596+
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape})
597+
for target, ctx in ctx_list():
598+
for kind in ["graph", "debug"]:
599+
intrp = relay.create_executor(kind, ctx=ctx, target=target)
600+
if offset is None or offset == (0, 0):
601+
op_res = intrp.evaluate(new_sym)(x_data, y_data)
602+
else:
603+
op_res = intrp.evaluate(new_sym)(x_data)
604+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
605+
verify((1, 3, 40, 40), (1, 3, 20, 20))
606+
verify((1, 3, 40, 40), (1, 3, 20, 20), (0, 0))
607+
verify((1, 3, 40, 40), (1, 3, 20, 20), (10, 10))
608+
verify((5, 32, 40, 40), (5, 32, 25, 25))
609+
verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))
610+
586611

587612
if __name__ == '__main__':
588613
test_forward_mlp()
@@ -624,3 +649,4 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
624649
test_forward_gather_nd()
625650
test_forward_bilinear_resize()
626651
test_forward_rnn_layer()
652+
test_forward_Crop()

0 commit comments

Comments
 (0)