Skip to content

Commit 9a263a8

Browse files
committed
[TFLite] added scalar axis value handling in reduce
Axis value in reduce can now be specified as scalar
1 parent 448278d commit 9a263a8

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1638,7 +1638,8 @@ def _convert_reduce(self, relay_op, op):
16381638
in_expr = self.get_expr(input_tensor.tensor_idx)
16391639

16401640
# axis
1641-
axis = tuple(self.get_tensor_value(input_tensors[1]))
1641+
axis_value = self.get_tensor_value(input_tensors[1])
1642+
axis = tuple(axis_value) if len(axis_value.shape) > 0 else tuple((axis_value.item(),))
16421643

16431644
# Options - keep_dims (bool)
16441645
assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions

tests/python/frontend/tflite/test_forward.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,18 +2143,22 @@ def _test_forward_reduce(testop, dtype="float32"):
21432143
if dtype == "bool":
21442144
data0 = [np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype), None]
21452145
data1 = [
2146+
np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
2147+
np.array(1, dtype=np.int32),
2148+
]
2149+
data2 = [
21462150
np.random.choice(a=[False, True], size=(16, 16, 16, 16)).astype(dtype),
21472151
np.array([1, 2], dtype=np.int32),
21482152
]
21492153
else:
21502154
data0 = [np.random.rand(16, 16, 16, 16).astype(dtype), None]
2151-
data1 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array([1, 2], dtype=np.int32)]
2152-
testop(data0)
2153-
testop(data0, keep_dims=False)
2154-
testop(data0, keep_dims=True)
2155-
testop(data1)
2156-
testop(data1, keep_dims=False)
2157-
testop(data1, keep_dims=True)
2155+
data1 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array(1, dtype=np.int32)]
2156+
data2 = [np.random.rand(16, 16, 16, 16).astype(dtype), np.array([1, 2], dtype=np.int32)]
2157+
2158+
for data in [data0, data1, data2]:
2159+
testop(data)
2160+
testop(data, keep_dims=False)
2161+
testop(data, keep_dims=True)
21582162

21592163

21602164
def _test_forward_reduce_quantized(testop):

0 commit comments

Comments
 (0)