Skip to content

Commit 6737739

Browse files
Ashutosh Parkhisrkreddy1238
authored andcommitted
[Tensorflow] Support for Crop (apache#2285)
fixes fixes
1 parent a9bd559 commit 6737739

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,21 @@ def _impl(inputs, attr, params):
388388

389389
return _impl
390390

391+
def _slice():
392+
def _impl(inputs, attr, params):
393+
begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist()
394+
size = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist()
395+
data_shape = attr['_input_shapes'][inputs[0]]
396+
data_dim = len(data_shape)
397+
end = size
398+
for i in range(data_dim):
399+
if size[i] == -1:
400+
end[i] = data_shape[i] - begin[i]
401+
else:
402+
end[i] += begin[i]
403+
return _sym.strided_slice(inputs[0], begin=begin, end=size)
404+
return _impl
405+
391406
def _reshape():
392407
def _impl(inputs, attr, params):
393408
try:
@@ -883,6 +898,7 @@ def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
883898
'Sum' : _sum(),
884899
'Square' : _square(),
885900
'Pack' : _pack(),
901+
'Slice' : _slice(),
886902
'LeakyRelu' : AttrCvt('leaky_relu'),
887903
'Relu' : AttrCvt('relu'),
888904
'Reshape' : _reshape(),

nnvm/tests/python/frontend/tensorflow/test_forward.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,23 @@ def test_forward_resize_bilinear():
655655
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
656656

657657

658+
#######################################################################
659+
# Crop to bounding box
660+
# --------------------
661+
662+
def _test_crop(in_shape, off_h, off_w, tar_h, tar_w):
663+
""" Crop to bounding box """
664+
data = np.random.uniform(size=in_shape).astype('float32')
665+
with tf.Graph().as_default():
666+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
667+
tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w)
668+
compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0')
669+
670+
def test_forward_crop():
671+
""" Crop to bounding box """
672+
_test_crop((1, 224, 224, 3), 20, 20, 120, 120)
673+
674+
658675
#######################################################################
659676
# LSTM
660677
# ----
@@ -1139,6 +1156,7 @@ def test_forward_rel_ops():
11391156
test_forward_squeeze()
11401157
test_forward_pack()
11411158
test_forward_resize_bilinear()
1159+
test_forward_crop()
11421160
test_forward_pad()
11431161
test_forward_gather()
11441162
test_forward_stridedslice()

0 commit comments

Comments
 (0)