diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ba076cc2819f..7319d5eb4a7e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -777,12 +777,12 @@ def _impl(inputs, attr, params): ignores=['name', 'Tidx'])([inputs[0]], attr) return _impl -def _reduce_all(): +def _reduce(op): def _impl(inputs, attr, params): axis = params.pop(inputs[1].name_hint).asnumpy() axis = tuple(axis) return AttrCvt( - op_name='all', + op_name=op, extras={'axis': axis}, transforms={'keep_dims':'keepdims'}, ignores=['name', 'Tidx'])([inputs[0]], attr) @@ -807,6 +807,14 @@ def _impl(inputs, attr, params): 'Taxis', '_class'])(new_input, attr) return _impl +def _gather_nd(): + """GatherNd""" + def _impl(inputs, attr, params): + return AttrCvt(op_name="gather_nd", + ignores=['Tindices', 'Tparams',\ + 'Taxis', '_class'])(inputs, attr) + return _impl + def _stridedSlice(): def _impl(inputs, attr, params): """Strided Slice. @@ -971,15 +979,18 @@ def _impl(inputs, attr, params): def _range(): def _impl(inputs, attr, params): - start = _get_num_param(params, inputs[0]) - limit = _get_num_param(params, inputs[1]) - delta = _get_num_param(params, inputs[2]) - - name = attr["_node_name"] - params[name] = tvm.nd.array([start, limit, delta]) - return [_expr.var(name, - shape=params[name].shape, - dtype='int32')] + start = params.pop(inputs[0].name_hint).asnumpy()[0] + limit = params.pop(inputs[1].name_hint).asnumpy()[0] \ + if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0] + delta = params.pop(inputs[2].name_hint).asnumpy()[0] + dtype = attr['dtype'].name if 'dtype' in attr else "int32" + return AttrCvt( + op_name="arange", + ignores=['Tidx'], + extras={'start': start, + "stop": limit, + 'step': delta, + 'dtype': dtype})([], attr) return _impl def _elu(): @@ -1099,6 +1110,13 @@ def _impl(inputs, attr, params): extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) return _impl +def _floordiv(): + def _impl(inputs, attr, params): + assert len(inputs) == 2 + div = AttrCvt('divide')(inputs, attr) + return _get_relay_op('floor')(div) + return _impl + def _logical(name): def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) @@ -1207,8 +1225,9 @@ def _impl(inputs, attr, params): # for 1 to N mapping(composed), use custom callable functions # for N to 1 mapping, currently not supported(?) _convert_map = { + 'Abs' : AttrCvt('abs'), 'Add' : _elemwise('add'), - 'All' : _reduce_all(), + 'All' : _reduce('all'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), 'AvgPool' : _pooling('avg_pool'), @@ -1232,26 +1251,33 @@ def _impl(inputs, attr, params): 'ExpandDims' : _expand_dims(), 'Fill' : _fill(), 'Floor' : AttrCvt('floor'), + 'FloorDiv' : _floordiv(), 'FusedBatchNorm' : _fused_batch_norm(), 'FusedBatchNormV2' : _fused_batch_norm(), 'Gather' : _gather(), + 'GatherNd' : _gather_nd(), 'GatherV2' : _gather(), 'Greater' : _broadcast('greater'), 'GreaterEqual' : _broadcast('greater_equal'), 'Identity' : _identity(), 'LeakyRelu' : AttrCvt('leaky_relu'), + 'LeftShift' : AttrCvt('left_shift'), 'Less' : _broadcast('less'), 'LessEqual' : _broadcast('less_equal'), 'Log' : AttrCvt('log'), 'LogicalAnd' : _logical('logical_and'), 'LogicalOr' : _logical('logical_or'), 'LogicalNot' : _logical('logical_not'), + 'LogSoftmax' : AttrCvt('log_softmax'), 'LRN' : _lrn(), 'MatMul' : _matmul(), + 'Max' : _reduce('max'), 'MaxPool' : _pooling('max_pool'), 'Maximum' : _elemwise('maximum'), 'Mean' : _mean(), + 'Min' : _reduce('min'), 'Minimum' : _elemwise('minimum'), + 'Mod' : _elemwise('mod'), 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), 'NotEqual' : _broadcast('not_equal'), @@ -1269,6 +1295,7 @@ def _impl(inputs, attr, params): 'ResizeBilinear' : _resize_bilinear(), 'ResizeBicubic' : _resize_bilinear(), 'ReverseV2' : _reverse_v2(), + 'RightShift' : AttrCvt('right_shift'), 'Round' : AttrCvt('round'), 'Rsqrt' : _rsqrt(), 'Select' : _where(), @@ -1292,7 +1319,9 @@ def _impl(inputs, attr, params): 'Tile' : _tile(), 'TopKV2' : _topk(), 'Transpose' : _transpose(), + 'TruncateMod' : _elemwise('mod'), 'Unpack' : _unpack(), + 'ZerosLike' : AttrCvt('zeros_like'), } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 498c4735a9e8..6fc825a8924c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -64,6 +64,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, layout=layout, shape=shape_dict, outputs=out_names) + with relay.build_config(opt_level=opt_level): graph, lib, params = relay.build(sym, target, target_host, params) @@ -642,10 +643,53 @@ def test_forward_stridedslice(): 'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=8) +####################################################################### +# FloorDiv, RealDiv +# ----------------- + +def _test_forward_divide(ip_shape, dtype): + np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) + np_denomin = np.random.uniform(1, 100, size=ip_shape).astype(dtype) + tf.reset_default_graph() + numerator = tf.placeholder(dtype, ip_shape, name="numer") + denominator = tf.placeholder(dtype, ip_shape, name="denomin") + tf.math.divide(numerator, denominator, name='RealDiv') + compare_tf_with_tvm([np_numer, np_denomin], ['numer:0', 'denomin:0'], 'RealDiv:0') + +def _test_forward_floordiv(ip_shape, dtype): + np_numer = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) + tf.reset_default_graph() + numerator = tf.placeholder(dtype, ip_shape, name="numer") + tf.math.floordiv(numerator, tf.constant(5, dtype=dtype), name='FloorDiv') + compare_tf_with_tvm([np_numer], ['numer:0'], 'FloorDiv:0') + +def test_forward_divide(): + '''test FloorDiv, RealDiv''' + _test_forward_divide((4,), 'int32') + _test_forward_divide((4, 3, 7), 'float32') + _test_forward_floordiv((4, 3, 7), 'float32') + ####################################################################### -# Gather, GatherV2 -# ---------------- +# TruncateMod +# ----------- +def _test_forward_truncatemod(ip_shape, dtype): + np_data_1 = np.random.uniform(-100, 100, size=ip_shape).astype(dtype) + np_data_2 = np.random.uniform(1, 10, size=ip_shape).astype(dtype) + tf.reset_default_graph() + in_data_1 = tf.placeholder(dtype, ip_shape, name="in_data_1") + in_data_2 = tf.placeholder(dtype, ip_shape, name="in_data_2") + tf.truncatemod(in_data_1, in_data_2, name='truncatemod') + compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'truncatemod:0') + +def test_forward_truncatemod(): + '''test TruncateMod''' + _test_forward_truncatemod((4, 3, 7), 'int32') + + +####################################################################### +# Gather, GatherV2, GatherNd +# -------------------------- def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): """ One iteration of a GatherV2 """ @@ -718,6 +762,33 @@ def test_forward_gather_v1(): _test_gather_v1((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 'float32') +def test_forward_gather_nd(): + """test operator GatherNd""" + np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (2, 2), name="in_data") + tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd") + compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0') + + +####################################################################### +# BiasAdd +# ------- +def test_forward_bias_add(): + """test Op BiasAdd""" + def check_bias_add(lh_shpae, rh_shape, dtype): + tf.reset_default_graph() + lh_data = np.random.uniform(size=lh_shpae).astype(dtype) + rh_data = np.random.uniform(size=rh_shape).astype(dtype) + lft_data = tf.placeholder(dtype, name="lft_data") + rgt_data = tf.placeholder(dtype, name="rgt_data") + tf.nn.bias_add(lft_data, rgt_data, name="BiasAdd") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'BiasAdd:0') + + check_bias_add((10, 8, 16, 32), (32,), dtype="int32") + check_bias_add((10, 20), (20,), dtype="float32") + + ####################################################################### # Split # ----- @@ -1109,6 +1180,32 @@ def test_forward_pack(): _test_pack(axis, [3]) _test_pack(0, []) + +####################################################################### +# Unpack +# ------ +def _test_forward_unpack(in_shape, axis, dtype): + """test operator Unpack""" + np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.unstack(in_data, axis=axis, name="Unpack") + compare_tf_with_tvm([np_data], ['in_data:0'], 'Unpack:0') + +def test_forward_unpack(): + _test_forward_unpack((3,), 0, 'int32') + _test_forward_unpack((3,), -1, 'int16') + _test_forward_unpack((21, 23, 3), 2, 'float32') + +####################################################################### +# Range +# ----- +def test_forward_range(): + """test operator Range""" + tf.reset_default_graph() + tf.range(1, 18, 3, name="range") + compare_tf_with_tvm([], [], 'range:0') + ####################################################################### # Pad # --- @@ -1182,7 +1279,7 @@ def test_forward_logical(): ####################################################################### # Where, Select # ------------- -def test_where(): +def test_forward_where(): ''' Where: return elements depending on conditions''' with tf.Graph().as_default(): with tf.Session() as sess: @@ -1553,6 +1650,22 @@ def test_forward_tanh(): tf.nn.tanh(in1) compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Tanh:0') + +####################################################################### +# Softmax +# ------- +def test_forward_softmax(): + """test operator Softmax """ + def check_softmax(in_shape, axis, dtype): + np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.nn.softmax(in_data, axis=axis, name="Softmax") + compare_tf_with_tvm([np_data], ['in_data:0'], 'Softmax:0') + check_softmax((2, 3, 5), 2, "float32") + check_softmax((2, 3, 5), -1, "float32") + + ####################################################################### # Tensor # ------ @@ -1565,6 +1678,29 @@ def test_forward_round(): tf.round(in_data, name="round") compare_tf_with_tvm([np_data], ['in_data:0'], 'round:0') +def test_forward_abs(): + """test operator Abs""" + np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (9, 11), name="in_data") + tf.math.abs(in_data, name="abs") + compare_tf_with_tvm([np_data], ['in_data:0'], 'abs:0') + +def _test_forward_zeros_like(in_shape, dtype): + np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, in_shape, name="in_data") + tf.zeros_like(in_data, name="zeros_like") + compare_tf_with_tvm([np_data], ['in_data:0'], 'zeros_like:0') + +def test_forward_zeros_like(): + if tf.__version__ < LooseVersion('1.2'): + _test_forward_zeros_like((2, 3), "int32") + _test_forward_zeros_like((2, 3, 5), "int8") + _test_forward_zeros_like((2, 3, 5, 7), "uint16") + _test_forward_zeros_like((2, 3, 11), "float32") + _test_forward_zeros_like((2, 3, 11), "float64") + def _test_forward_reverse_v2(in_shape, axis, dtype): np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype) tf.reset_default_graph() @@ -1588,6 +1724,14 @@ def test_forward_sign(): tf.sign(in_data, name="sign") compare_tf_with_tvm([np_data], ['in_data:0'], 'sign:0') +def test_forward_square(): + """test operator Square """ + np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") + tf.square(in_data, name="square") + compare_tf_with_tvm([np_data], ['in_data:0'], 'square:0') + def test_forward_pow_exp(): """test Pow and Exp """ np_in1 = np.random.uniform(-2, 2, size=(5, 7, 11)).astype(np.float32) @@ -1616,6 +1760,14 @@ def test_forward_negative(): tf.negative(in_data, name="negative") compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0') +def test_forward_log_softmax(): + """test operator LogSoftmax""" + np_data = np.random.uniform(1, 100, size=(9, 11)).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, (9, 11), name="in_data") + tf.math.log_softmax(in_data, name="LogSoftmax") + compare_tf_with_tvm([np_data], ['in_data:0'], 'LogSoftmax:0') + def test_forward_softplus(): """test operator Softplus""" np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32) @@ -1640,6 +1792,34 @@ def test_forward_sqrt(): tf.sqrt(in_data, name="sqrt") compare_tf_with_tvm([np_data], ['in_data:0'], 'sqrt:0') +def _test_forward_right_shift(in_shape, dtype): + """test operator RightShift""" + lh_data = np.random.randint(1, 3, size=in_shape).astype(dtype) + rh_data = np.random.randint(1, 8, size=in_shape).astype(dtype) + tf.reset_default_graph() + lft_data = tf.placeholder(dtype, in_shape, name="lft_data") + rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data") + tf.bitwise.right_shift(lft_data, rgt_data, name="RightShift") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'RightShift:0') + +def test_forward_right_shift(): + _test_forward_right_shift((7,), 'int32') + _test_forward_right_shift((3, 11), 'int16') + +def _test_forward_left_shift(in_shape, dtype): + """test operator LeftShift""" + lh_data = np.random.randint(100, 1000000, size=in_shape).astype(dtype) + rh_data = np.random.randint(1, 3, size=in_shape).astype(dtype) + tf.reset_default_graph() + lft_data = tf.placeholder(dtype, in_shape, name="lft_data") + rgt_data = tf.placeholder(dtype, in_shape, name="rgt_data") + tf.bitwise.left_shift(lft_data, rgt_data, name="LeftShift") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'LeftShift:0') + +def test_forward_left_shift(): + _test_forward_left_shift((10,), 'int32') + _test_forward_left_shift((224, 224, 3), 'int16') + ####################################################################### # Mean # ---- @@ -1652,13 +1832,13 @@ def check_mean(ishape, **kwargs): compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Mean:0', no_gpu=True) check_mean((10, 8, 16, 32)) - check_mean((10, 8, 16, 32), axis=(2,3)) - check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True) + check_mean((10, 8, 16, 32), axis=(2, 3)) + check_mean((10, 8, 16, 32), axis=(1, 2), keepdims=True) ####################################################################### -# All -# --- -def test_forward_all(): +# All, Max, Min +# ------------- +def test_forward_reduce_all(): """Test the All operator.""" np_data = np.random.choice([True, False], size=(5, 7, 11)) tf.reset_default_graph() @@ -1666,6 +1846,30 @@ def test_forward_all(): tf.reduce_all(in_data, name="all") compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0') +def test_forward_reduce_max(): + def check_max(ishape, axis, keepdims, dtype): + tf.reset_default_graph() + np_data = np.random.uniform(size=ishape).astype(dtype) + in_data = tf.placeholder(dtype, name="in_data") + tf.math.reduce_max(in_data, axis=axis, keepdims=keepdims, name="reduce_max") + compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0') + + check_max((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32") + check_max((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32") + check_max((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32') + +def test_forward_reduce_min(): + def check_min(ishape, axis, keepdims, dtype): + tf.reset_default_graph() + np_data = np.random.uniform(size=ishape).astype(dtype) + in_data = tf.placeholder(dtype, name="in_data") + tf.math.reduce_min(in_data, axis=axis, keepdims=keepdims, name="reduce_max") + compare_tf_with_tvm([np_data], ['in_data:0'], 'reduce_max:0') + + check_min((10, 8, 16, 32), axis=(-1), keepdims=True, dtype="int32") + check_min((10, 8, 16, 32), axis=(2, 3), keepdims=True, dtype="float32") + check_min((10, 8, 16, 32), axis=(1, 2), keepdims=True, dtype='float32') + ####################################################################### # Relational operators # -------------------- @@ -1723,6 +1927,38 @@ def test_forward_reduce_prod(): _test_forward_reduce_prod((5, 5), 1, True) +####################################################################### +# Maximum, Minimum +# ---------------- +def test_forward_maximum(): + """test Op Maximum""" + def check_maximum(lh_shape, rh_shape, dtype): + tf.reset_default_graph() + lh_data = np.random.uniform(size=lh_shape).astype(dtype) + rh_data = np.random.uniform(size=rh_shape).astype(dtype) + lft_data = tf.placeholder(dtype, name="lft_data") + rgt_data = tf.placeholder(dtype, name="rgt_data") + tf.math.maximum(lft_data, rgt_data, name="maximum") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'maximum:0') + + check_maximum((10, 8, 16, 32), (1,), dtype="int32") + check_maximum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32") + +def test_forward_minimum(): + """test Op Minimum""" + def check_minimum(lh_shape, rh_shape, dtype): + tf.reset_default_graph() + lh_data = np.random.uniform(size=lh_shape).astype(dtype) + rh_data = np.random.uniform(size=rh_shape).astype(dtype) + lft_data = tf.placeholder(dtype, name="lft_data") + rgt_data = tf.placeholder(dtype, name="rgt_data") + tf.math.minimum(lft_data, rgt_data, name="minimum") + compare_tf_with_tvm([lh_data, rh_data], ['lft_data:0', 'rgt_data:0'], 'minimum:0') + + check_minimum((10, 8, 16, 32), (1,), dtype="int32") + check_minimum((10, 8, 16, 32), (10, 8, 16, 32), dtype="float32") + + ####################################################################### # PlaceholderWithDefault # ---------------------- @@ -1740,6 +1976,7 @@ def test_placeholder(): compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True) + ####################################################################### # Main # ---- @@ -1756,14 +1993,22 @@ def test_placeholder(): test_forward_fill() test_forward_crop() test_forward_pad() + test_forward_unpack() test_forward_gather() test_forward_gather_v1() + test_forward_gather_nd() test_forward_stridedslice() test_forward_split() test_forward_unstack() test_forward_tile() test_forward_top_k_v2() test_forward_clip_by_value() + test_forward_maximum() + test_forward_minimum() + test_forward_range() + test_forward_right_shift() + test_forward_left_shift() + test_forward_truncatemod() # Activations test_forward_sigmoid() @@ -1780,17 +2025,26 @@ def test_placeholder(): test_forward_sign() test_forward_log() test_forward_negative() + test_forward_divide() + test_forward_abs() test_forward_softplus() test_forward_sqrt() test_forward_rsqrt() test_forward_expand_dims() + test_forward_square() + test_forward_softmax() + test_forward_log_softmax() + test_forward_bias_add() + test_forward_zeros_like() # Reductions test_forward_argminmax() test_forward_reduce() test_forward_mean() test_forward_reduce_prod() - test_forward_all() + test_forward_reduce_all() + test_forward_reduce_max() + test_forward_reduce_min() # General test_forward_multi_input() @@ -1826,7 +2080,7 @@ def test_placeholder(): # Relational ops test_forward_rel_ops() test_forward_logical() - test_where() + test_forward_where() test_forward_matmul() # TODO missing tests: rank, range