Skip to content

Commit 66744d9

Browse files
authored
[TFLite] pack operation extedned with const args (#6984)
pack operation now accepts constant arguments
1 parent 968b6f6 commit 66744d9

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2524,9 +2524,6 @@ def convert_pack(self, op):
25242524
raise ImportError("The tflite package must be installed")
25252525

25262526
input_tensors = self.get_input_tensors(op)
2527-
assert len(input_tensors) >= 1, "input tensors should greater than 1"
2528-
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]
2529-
25302527
output_tensors = self.get_output_tensors(op)
25312528
assert len(output_tensors) == 1, "output tensors length should be 1"
25322529

@@ -2535,8 +2532,11 @@ def convert_pack(self, op):
25352532
pack_options = PackOptions()
25362533
pack_options.Init(op_options.Bytes, op_options.Pos)
25372534
pack_axis = pack_options.Axis()
2535+
pack_values_count = pack_options.ValuesCount()
2536+
assert len(input_tensors) == pack_values_count, "Discordance in input values count"
25382537

2539-
in_exprs_reshaped = [_op.expand_dims(i, axis=pack_axis, num_newaxis=1) for i in in_exprs]
2538+
in_exprs = [self.get_tensor_expr(_) for _ in input_tensors]
2539+
in_exprs_reshaped = [_op.expand_dims(_, axis=pack_axis, num_newaxis=1) for _ in in_exprs]
25402540
out = _op.concatenate(in_exprs_reshaped, pack_axis)
25412541
return out
25422542

tests/python/frontend/tflite/test_forward.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2750,34 +2750,44 @@ def test_forward_one_hot():
27502750
# ----
27512751

27522752

2753-
def _test_pack(data, axis):
2753+
def _test_pack(data, is_var, axis):
27542754
""" One iteration of pack """
27552755

27562756
assert len(data) >= 1
2757+
assert len(data) == len(is_var)
27572758

27582759
with tf.Graph().as_default():
27592760
in_data = [
2760-
array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
2761-
for idx, tensor in enumerate(data)
2761+
array_ops.placeholder(shape=d.shape, dtype=d.dtype, name="in_" + str(idx))
2762+
if is_var[idx]
2763+
else constant_op.constant(
2764+
d, shape=d.shape, dtype=d.dtype, name="in_constant_" + str(idx)
2765+
)
2766+
for idx, d in enumerate(data)
27622767
]
2763-
out = array_ops.pack(in_data, axis=axis)
2764-
name = ["in_{}:0".format(idx) for idx in range(len(data))]
27652768

2766-
compare_tflite_with_tvm(data, name, in_data, [out])
2769+
out = array_ops.pack(in_data, axis=axis)
2770+
name = [_.name for _ in in_data]
2771+
compare_tflite_with_tvm(data, name, in_data, [out], experimental_new_converter=True)
27672772

27682773

27692774
def test_forward_pack():
27702775
""" Pack """
2771-
_test_pack([np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], 1)
2776+
_test_pack([np.int32(1), np.int32(5)], [False, False], 0)
2777+
_test_pack([np.array([1, 4]), np.array([2, 5]), np.array([3, 6])], [True, False, False], 0)
2778+
_test_pack(
2779+
[np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], [True, True], 1
2780+
)
27722781

2773-
_test_pack([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], 1)
2782+
_test_pack([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], [True, True], 1)
27742783

27752784
_test_pack(
27762785
[
27772786
np.arange(6).reshape((2, 1, 1, 3)),
27782787
np.arange(6).reshape((2, 1, 1, 3)),
27792788
np.arange(6).reshape((2, 1, 1, 3)),
27802789
],
2790+
[True, True, True],
27812791
1,
27822792
)
27832793

0 commit comments

Comments
 (0)