Skip to content

Commit 68c8106

Browse files
apivovarovwweic
authored andcommitted
Add PAD operator to relay tflite frontend (apache#3310)
1 parent 88b4873 commit 68c8106

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(self, model, subgraph, exp_tab):
6666
'ADD': self.convert_add,
6767
'MUL': self.convert_mul,
6868
'FULLY_CONNECTED': self.convert_fully_connected,
69+
'PAD': self.convert_pad,
6970
}
7071

7172
def check_unsupported_ops(self):
@@ -596,6 +597,31 @@ def convert_pool2d(self, op, pool_type):
596597

597598
return out
598599

600+
def convert_pad(self, op):
601+
"""Convert TFLite PAD"""
602+
try:
603+
from tflite.Operator import Operator
604+
except ImportError:
605+
raise ImportError("The tflite package must be installed")
606+
607+
assert isinstance(op, Operator)
608+
input_tensors = self.get_input_tensors(op)
609+
assert len(input_tensors) == 2, "input tensors length should be 2"
610+
611+
# TFLite only support CONSTANT mode and does not support constant_values parameter.
612+
# tensor
613+
input_tensor = input_tensors[0]
614+
in_expr = self.get_expr(input_tensor.tensor_idx)
615+
616+
# paddings
617+
pad_list = self.get_tensor_value(input_tensors[1])
618+
# convert list of lists to tuple of tuples
619+
paddings = tuple(tuple(l) for l in pad_list)
620+
621+
# Use default pad_value 0 because TFLite does not support constant_values parameter
622+
out = _op.nn.pad(in_expr, paddings)
623+
return out
624+
599625
def get_expr(self, input_tensor_idx):
600626
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
601627

tests/python/frontend/tflite/test_forward.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,35 @@ def test_forward_squeeze():
394394
_test_squeeze(np.arange(6).reshape((1, 2, 1, 3)), [0, 2])
395395
_test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3])
396396

397+
398+
#######################################################################
399+
# Pad
400+
# ---
401+
402+
def _test_pad(data):
403+
""" One iteration of PAD """
404+
405+
assert len(data) == 2
406+
407+
# Test with tensor and constant
408+
with tf.Graph().as_default():
409+
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
410+
out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
411+
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
412+
413+
414+
def test_forward_pad():
415+
""" Pad """
416+
_test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 1, 3)),
417+
np.array([[1, 1], [2, 2], [1, 1], [2, 2]], dtype=np.int32)])
418+
_test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)),
419+
np.array([[2, 2], [1, 1], [1, 1]], dtype=np.int32)])
420+
_test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
421+
np.array([[1, 1], [2, 2]], dtype=np.int32)])
422+
_test_pad([np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
423+
np.array([[1, 1], [2, 2]], dtype=np.int32)])
424+
425+
397426
#######################################################################
398427
# Softmax
399428
# -------
@@ -528,6 +557,7 @@ def test_forward_inception_v4_net():
528557
if __name__ == '__main__':
529558
# Transforms
530559
test_forward_concatenation()
560+
test_forward_pad()
531561
test_forward_reshape()
532562
test_forward_squeeze()
533563

0 commit comments

Comments
 (0)