diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 162cc36a89e5..ace5fdd96a52 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -82,7 +82,8 @@ def __init__(self, model, subgraph, exp_tab): 'PACK': self.convert_pack, 'LOGISTIC': self.convert_logistic, 'SPLIT': self.convert_split, - 'TRANSPOSE': self.convert_transpose + 'TRANSPOSE': self.convert_transpose, + 'TILE': self.convert_tile } def check_unsupported_ops(self): @@ -769,6 +770,28 @@ def convert_transpose(self, op): return out + def convert_tile(self, op): + """tile implementation.""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + input_tensor = input_tensors[0] + input_tensor_idx = input_tensor.tensor_idx + + in_expr = self.get_expr(input_tensor_idx) + + # reps (tuple of int) – The number of times repeating the tensor data. + reps = tuple(self.get_tensor_value(input_tensors[1])) + + out = _op.tile(in_expr, reps) + + return out + def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index a78225cd5646..771226056ed4 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -229,6 +229,26 @@ def test_forward_transpose(): _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2)) _test_forward_transpose((2, 3, 4, 5), ()) +####################################################################### +# tile +# --------- + + +def _test_forward_tile(in_shape, reps, dtype): + data = np.random.uniform(-5, 5, size=in_shape).astype(dtype) + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + + out = array_ops.tile(in_data, reps) + + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + + +def test_forward_tile(): + _test_forward_tile((2, ), (3, ), "int32") + _test_forward_tile((2, 2), (2, 3), "float32") + ####################################################################### # Pooling @@ -856,6 +876,9 @@ def test_forward_ssd_mobilenet_v1(): # Transpose test_forward_transpose() + # Tile + test_forward_tile() + # Transforms test_forward_concatenation() test_forward_pad()