diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index b34e6c723645..ea1abc843c20 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1005,6 +1005,13 @@ def _impl(inputs, attr, params, mod): return _impl +def _identityn(): + def _impl(inputs, attr, params, mod): + return inputs + + return _impl + + def _concatV2(): def _impl(inputs, attr, params, mod): pop_node = inputs.pop(len(inputs) - 1) @@ -2378,6 +2385,7 @@ def _impl(inputs, attr, params, mod): "Greater": _broadcast("greater"), "GreaterEqual": _broadcast("greater_equal"), "Identity": _identity(), + "IdentityN": _identityn(), "IsFinite": AttrCvt("isfinite"), "IsInf": AttrCvt("isinf"), "IsNan": AttrCvt("isnan"), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 34ee0f3528ae..fd4b9f49e6a4 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -4074,6 +4074,56 @@ def test_forward_dilation(): _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID") +def _test_identityn(data_np_list): + with tf.Graph().as_default(): + data_tensors = [] + data_tensors_name = [] + for index, data_np in enumerate(data_np_list): + tensor_name = f"data_{index}" + data_tensors_name.append(tensor_name + ":0") + data_tensors.append( + tf.placeholder(shape=data_np.shape, dtype=str(data_np.dtype), name=tensor_name) + ) + + output = tf.identity_n(data_tensors) + output_names = [out.name for out in output] + compare_tf_with_tvm( + data_np_list, + data_tensors_name, + output_names, + ) + + +@pytest.mark.parametrize( + "data_np_list", + [ + ( + [ + np.array([[1, 1], [0, 3], [0, 1], [2, 0], [3, 1]], dtype=np.int64), + np.array([1, 2, 3, 4, 5], dtype=np.int64), + np.array([5, 6], dtype=np.int64), + ] + ), + ( + [ + np.array([[1, 1], [0, 3], [2, 0], [3, 1]], dtype=np.int64), + np.array([1, 2, 3, 4], dtype=np.int64), + np.array([5, 6], dtype=np.int64), + np.array([True, False, True]), + ] + ), + ( + [ + np.array([]), + np.array([[]]), + ] + ), + ], +) +def test_forward_identityn(data_np_list): + _test_identityn(data_np_list) + + ####################################################################### # Sparse To Dense # ---------------