Skip to content

Commit dee5246

Browse files
cchung100mjroesch
authored andcommitted
Implementation of tile for TFLite (#3814)
1 parent eef35a5 commit dee5246

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def __init__(self, model, subgraph, exp_tab):
8282
'PACK': self.convert_pack,
8383
'LOGISTIC': self.convert_logistic,
8484
'SPLIT': self.convert_split,
85-
'TRANSPOSE': self.convert_transpose
85+
'TRANSPOSE': self.convert_transpose,
86+
'TILE': self.convert_tile
8687
}
8788

8889
def check_unsupported_ops(self):
@@ -769,6 +770,28 @@ def convert_transpose(self, op):
769770

770771
return out
771772

773+
def convert_tile(self, op):
774+
"""tile implementation."""
775+
try:
776+
from tflite.Operator import Operator
777+
except ImportError:
778+
raise ImportError("The tflite package must be installed")
779+
780+
assert isinstance(op, Operator)
781+
input_tensors = self.get_input_tensors(op)
782+
assert len(input_tensors) == 2, "input tensors length should be 2"
783+
input_tensor = input_tensors[0]
784+
input_tensor_idx = input_tensor.tensor_idx
785+
786+
in_expr = self.get_expr(input_tensor_idx)
787+
788+
# reps (tuple of int) – The number of times repeating the tensor data.
789+
reps = tuple(self.get_tensor_value(input_tensors[1]))
790+
791+
out = _op.tile(in_expr, reps)
792+
793+
return out
794+
772795
def convert_pool2d(self, op, pool_type):
773796
"""pool2d implementation."""
774797
try:

tests/python/frontend/tflite/test_forward.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,26 @@ def test_forward_transpose():
229229
_test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
230230
_test_forward_transpose((2, 3, 4, 5), ())
231231

232+
#######################################################################
233+
# tile
234+
# ---------
235+
236+
237+
def _test_forward_tile(in_shape, reps, dtype):
238+
data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
239+
240+
with tf.Graph().as_default():
241+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
242+
243+
out = array_ops.tile(in_data, reps)
244+
245+
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
246+
247+
248+
def test_forward_tile():
249+
_test_forward_tile((2, ), (3, ), "int32")
250+
_test_forward_tile((2, 2), (2, 3), "float32")
251+
232252

233253
#######################################################################
234254
# Pooling
@@ -856,6 +876,9 @@ def test_forward_ssd_mobilenet_v1():
856876
# Transpose
857877
test_forward_transpose()
858878

879+
# Tile
880+
test_forward_tile()
881+
859882
# Transforms
860883
test_forward_concatenation()
861884
test_forward_pad()

0 commit comments

Comments
 (0)