Skip to content

Commit cf439ec

Browse files
author
Matthew Brookhart
authored
support slicing with out of order axes (#8959)
1 parent 1854e10 commit cf439ec

File tree

2 files changed

+13
-21
lines changed

2 files changed

+13
-21
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,20 +1419,13 @@ class Slice(OnnxOpConverter):
14191419

14201420
@classmethod
14211421
def _common(cls, starts, ends, axes):
1422-
new_axes = []
1423-
new_starts = []
1424-
new_ends = []
1425-
pop_index = 0
1426-
for i in range(max(axes) + 1):
1427-
if i in axes:
1428-
new_axes.append(i)
1429-
new_starts.append(starts[pop_index])
1430-
new_ends.append(ends[pop_index])
1431-
pop_index += 1
1432-
else:
1433-
new_axes.append(i)
1434-
new_starts.append(0)
1435-
new_ends.append(np.iinfo(np.int32).max)
1422+
N = max(axes) + 1
1423+
new_axes = list(range(N))
1424+
new_starts = [0] * N
1425+
new_ends = [np.iinfo(np.int32).max] * N
1426+
for i, axis in enumerate(axes):
1427+
new_starts[axis] = starts[i]
1428+
new_ends[axis] = ends[i]
14361429
return new_starts, new_ends, new_axes
14371430

14381431
@classmethod
@@ -1445,13 +1438,10 @@ def _impl_v1(cls, inputs, attr, params):
14451438
# Update the starts and ends according to axes if required.
14461439
if isinstance(attr["axes"], int):
14471440
attr["axes"] = (attr["axes"],)
1448-
if (max(attr["axes"]) + 1) != len(attr["axes"]):
1449-
new_starts, new_ends, new_axes = cls._common(
1450-
attr["starts"], attr["ends"], attr["axes"]
1451-
)
1452-
attr["axes"] = new_axes
1453-
attr["starts"] = new_starts
1454-
attr["ends"] = new_ends
1441+
new_starts, new_ends, new_axes = cls._common(attr["starts"], attr["ends"], attr["axes"])
1442+
attr["axes"] = new_axes
1443+
attr["starts"] = new_starts
1444+
attr["ends"] = new_ends
14551445
except KeyError:
14561446
pass
14571447
begin = list(attr["starts"])

tests/python/frontend/onnx/test_forward.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,10 +885,12 @@ def add_noop_to_input_attr(attr_name, attr):
885885

886886
x = np.random.randn(20, 10, 5).astype(np.float32)
887887
_test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1))
888+
_test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(10, 3), axes=(1, 0))
888889
_test_slice_iteration_v1(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4))
889890
_test_slice_iteration_v1(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,))
890891
_test_slice_iteration_v1(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,))
891892
_test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1))
893+
_test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(10, 3), axes=(1, 0))
892894
_test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4))
893895
_test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,))
894896
_test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,))

0 commit comments

Comments
 (0)