
Commit d1a0c90

Rasterer authored and yzhliu committed

[FRONTEND][TENSORFLOW] Add Split and RealDiv op support (apache#2123)

* Add Split and RealDiv op support
* Fix the pad calculation in the case of dilated convolution

1 parent 9c1195e · commit d1a0c90

File tree

2 files changed: +114 −15 lines changed

  nnvm/python/nnvm/frontend/tensorflow.py
  nnvm/tests/python/frontend/tensorflow/test_forward.py

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 35 additions & 15 deletions
@@ -215,7 +215,7 @@ def _impl(inputs, attr, params):
             attr['channels'] = input_shape[3] * depth_mult

             if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
+                attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
             attr['strides'] = (attr['strides'][1], attr['strides'][2])
         elif attr['data_format'] == 'NCHW':
             depth_mult, _, kernel_h, kernel_w = weights_shape
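In NHWC layout TensorFlow stores per-dimension node attributes in [batch, height, width, channel] order, so the spatial dilations live at indices 1 and 2, mirroring how the strides are already sliced on the next line. A quick illustration:

```python
# NHWC attribute layout: [batch, height, width, channel]
dilations = [1, 2, 2, 1]                 # as recorded in the TF node attr
spatial = (dilations[1], dilations[2])   # (2, 2) -- what the conversion needs
```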
@@ -252,8 +252,12 @@ def _impl(inputs, attr, params):
                 in_h = input_shape[2]
                 in_w = input_shape[3]

-            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
-            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
+            dilation_h = attr['dilations'][0]
+            dilation_w = attr['dilations'][1]
+            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
+            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

             if attr['data_format'] == 'NHWC':
                 inputs[0] = _sym.pad(data=inputs[0],
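For context, a kernel of size k with dilation d covers (k − 1) · d + 1 input pixels per axis, which is why SAME padding must be computed from the dilated kernel size. A minimal sketch of the effect, assuming a `get_pad_pair` helper with the usual TensorFlow SAME-padding behavior (the frontend's actual `_get_pad_pair` may differ in detail):

```python
import math

def get_pad_pair(in_size, kernel, stride):
    # Assumed SAME-padding rule: pad so out_size == ceil(in_size / stride),
    # splitting the total pad between the two sides.
    out_size = math.ceil(in_size / stride)
    total_pad = max((out_size - 1) * stride + kernel - in_size, 0)
    return [total_pad // 2, total_pad - total_pad // 2]

kernel_h, dilation_h, stride_h, in_h = 3, 2, 1, 28
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1      # 5, not 3
print(get_pad_pair(in_h, dilated_kernel_h, stride_h))   # [2, 2]
print(get_pad_pair(in_h, kernel_h, stride_h))           # [1, 1] -- the old, wrong pad
```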
@@ -783,6 +787,15 @@ def _impl(inputs, attr, params):
                        )(inputs, attr)
     return _impl

+def _split():
+    def _impl(inputs, attr, params):
+        axis = params.pop(inputs[0].list_output_names()[0])
+        return AttrCvt(
+            op_name="split", ignores=['T'],
+            transforms={'num_split': 'indices_or_sections'},
+            extras={'axis': axis.asnumpy()[0]})(inputs[1], attr)
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
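In a TensorFlow GraphDef, `Split` takes the split axis as its first input (a scalar constant) and the tensor to split as its second, with `num_split` as an attribute; the converter above therefore pops the axis value out of `params` and forwards only `inputs[1]`, mapping `num_split` onto NNVM's `indices_or_sections`. A small TF 1.x sketch of the op this handles:

```python
import numpy as np
import tensorflow as tf

with tf.Graph().as_default():
    data = tf.placeholder('float32', (2, 6), name='in_data')
    # An integer num_or_size_splits emits a Split node (a list would emit SplitV).
    parts = tf.split(data, num_or_size_splits=3, axis=1)
    with tf.Session() as sess:
        out = sess.run(parts, {'in_data:0': np.ones((2, 6), 'float32')})
        print([p.shape for p in out])  # [(2, 2), (2, 2), (2, 2)]
```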

@@ -813,6 +826,7 @@ def _impl(inputs, attr, params):
     'Add'                               : _elemwise('add'),
     'Sub'                               : _elemwise('sub'),
     'Mul'                               : _elemwise('mul'),
+    'RealDiv'                           : _elemwise('div'),
     'Maximum'                           : _elemwise('max'),
     'Minimum'                           : _elemwise('min'),
     'Sum'                               : _sum(),
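`RealDiv` is what TensorFlow's `/` operator lowers to for float tensors, so mapping it onto NNVM's elementwise `div` covers ordinary division in exported graphs. A quick check:

```python
import tensorflow as tf

with tf.Graph().as_default():
    a = tf.placeholder('float32', (2, 3))
    b = tf.placeholder('float32', (2, 3))
    c = a / b              # Python truediv on float tensors
    print(c.op.type)       # RealDiv
```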
@@ -849,6 +863,7 @@ def _impl(inputs, attr, params):
     'GreaterEqual'                      : _broadcast('greater_equal'),
     'Equal'                             : _broadcast('equal'),
     'NotEqual'                          : _broadcast('not_equal'),
+    'Split'                             : _split(),
 }

 # _convert_map_rnn defines maps of rnn operator name to
@@ -1144,21 +1159,26 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             # Pass the target layout
             attr["_target_layout"] = layout

-            # ToDo: Some of the TensorFlow operators internally maintain
-            # execution layers, and their output name carries the layer number
-            # along with the graph node name, e.g. node name
-            # 'Model/RNN/cell_0/RnnCell' but output name
-            # 'Model/RNN/cell_0/RnnCell:0'. In this case the digit has to be ignored.
-            if ":" in node.input[0]:
-                in_name, _ = node.input[0].split(':')
-                node.input[0] = in_name
-
             # Fill shapes for all inputs in a list
             inputs = []
             for i in node.input:
-                if i in self._nodes:
-                    inputs.append(self._nodes[i])
-                    input_shapes[self._nodes[i]] = self._output_shapes[i]
+                # ToDo: Some of the TensorFlow operators internally maintain
+                # execution layers, and their output name carries the layer number
+                # along with the graph node name, e.g. node name
+                # 'Model/RNN/cell_0/RnnCell' but output name
+                # 'Model/RNN/cell_0/RnnCell:0'. In this case the digit has to be ignored.
+                tensor_name = i.split(':')
+                node_name = tensor_name[0]
+                if node_name in self._nodes:
+                    in_sym = self._nodes[node_name]
+                    if len(in_sym.list_output_names()) > 1:
+                        tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
+                        in_sym = in_sym[tensor_slot]
+                        input_shape = self._output_shapes[node_name][tensor_slot]
+                    else:
+                        input_shape = self._output_shapes[node_name][0]
+                    inputs.append(in_sym)
+                    input_shapes[in_sym] = [input_shape]
             attr['_input_shapes'] = input_shapes

             inputs = self._fix_extranodes(node.op, attr, inputs)
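For reference, the `name:slot` convention this hunk generalizes: an input such as 'Model/RNN/cell_0/RnnCell:1' names output slot 1 of a multi-output node (like Split), while a bare name means slot 0. A standalone sketch of just the parsing step (the surrounding `_nodes`/`_output_shapes` bookkeeping is the frontend's own):

```python
def parse_input_name(name):
    # 'node:1' -> ('node', 1); 'node' -> ('node', 0)
    parts = name.split(':')
    return parts[0], int(parts[1]) if len(parts) > 1 else 0

assert parse_input_name('Model/RNN/cell_0/RnnCell:1') == ('Model/RNN/cell_0/RnnCell', 1)
assert parse_input_name('in_data') == ('in_data', 0)
```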

nnvm/tests/python/frontend/tensorflow/test_forward.py

Lines changed: 79 additions & 0 deletions
@@ -502,6 +502,83 @@ def test_forward_gather():
     _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')


+#######################################################################
+# Split
+# -----
+
+def _test_split(in_shape, axis, num_split, dtype):
+    """ One iteration of a Split """
+
+    with tf.Graph().as_default():
+        in_data = tf.placeholder(dtype, in_shape, name="in_data")
+        tf.split(in_data, num_split, axis)
+        np_data = np.random.uniform(size=in_shape).astype(dtype)
+        compare_tf_with_tvm(np_data, 'in_data:0', 'split:0')
+
+def test_forward_split():
+    '''test split layer'''
+    # rank 1
+    _test_split((3,), 0, 1, 'float32')
+    _test_split((3,), 0, 3, 'float32')
+    _test_split((6,), 0, 3, 'float32')
+    # rank 2
+    _test_split((6, 2), 0, 3, 'float32')
+    _test_split((2, 6), 1, 3, 'float32')
+    # rank 3
+    _test_split((6, 2, 4), 0, 3, 'float32')
+    _test_split((2, 6, 4), 1, 3, 'float32')
+    _test_split((2, 4, 6), 2, 3, 'float32')
+    # rank 4
+    _test_split((6, 1, 3, 5), 0, 3, 'float32')
+    _test_split((1, 6, 3, 5), 1, 3, 'float32')
+    _test_split((1, 3, 6, 5), 2, 3, 'float32')
+    _test_split((1, 3, 5, 6), 3, 3, 'float32')
+    # split along negative axis
+    _test_split((6, 1, 3, 5), -4, 3, 'float32')
+    _test_split((1, 6, 3, 5), -3, 3, 'float32')
+    _test_split((1, 3, 6, 5), -2, 3, 'float32')
+    _test_split((1, 3, 5, 6), -1, 3, 'float32')
+
+
+#######################################################################
+# Split followed by concat
+# ------------------------
+
+def _test_split_concat(in_shape, axis, num_split, dtype):
+    """ One iteration of a split_concat pair """
+
+    with tf.Graph().as_default():
+        in_data = tf.placeholder(dtype, in_shape, name="in_data")
+        splitted = tf.split(in_data, num_split, axis)
+        tf.concat(splitted, axis)
+        np_data = np.random.uniform(size=in_shape).astype(dtype)
+        compare_tf_with_tvm(np_data, 'in_data:0', 'concat:0')
+
+def test_forward_split_concat():
+    '''test split followed by concat layers'''
+    # rank 1
+    _test_split_concat((3,), 0, 1, 'float32')
+    _test_split_concat((3,), 0, 3, 'float32')
+    _test_split_concat((6,), 0, 3, 'float32')
+    # rank 2
+    _test_split_concat((6, 2), 0, 3, 'float32')
+    _test_split_concat((2, 6), 1, 3, 'float32')
+    # rank 3
+    _test_split_concat((6, 2, 4), 0, 3, 'float32')
+    _test_split_concat((2, 6, 4), 1, 3, 'float32')
+    _test_split_concat((2, 4, 6), 2, 3, 'float32')
+    # rank 4
+    _test_split_concat((6, 1, 3, 5), 0, 3, 'float32')
+    _test_split_concat((1, 6, 3, 5), 1, 3, 'float32')
+    _test_split_concat((1, 3, 6, 5), 2, 3, 'float32')
+    _test_split_concat((1, 3, 5, 6), 3, 3, 'float32')
+    # split along negative axis
+    _test_split_concat((6, 1, 3, 5), -4, 3, 'float32')
+    _test_split_concat((1, 6, 3, 5), -3, 3, 'float32')
+    _test_split_concat((1, 3, 6, 5), -2, 3, 'float32')
+    _test_split_concat((1, 3, 5, 6), -1, 3, 'float32')
+
+
 #######################################################################
 # Multi Input to graph
 # --------------------
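One note on the negative-axis cases above: a negative axis counts from the last dimension, matching NumPy semantics, so `axis=-1` on a rank-4 tensor is the same split as `axis=3`. A quick NumPy check:

```python
import numpy as np

x = np.random.uniform(size=(1, 3, 5, 6)).astype('float32')
neg = np.split(x, 3, axis=-1)
pos = np.split(x, 3, axis=3)
assert all((a == b).all() for a, b in zip(neg, pos))
```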
@@ -1061,6 +1138,8 @@ def test_forward_rel_ops():
     test_forward_pad()
     test_forward_gather()
     test_forward_stridedslice()
+    test_forward_split()
+    test_forward_split_concat()

     # Activations
     test_forward_sigmoid()
