Skip to content

Commit 7e05d15

Browse files
Matthew Brookhart and trevor-m
authored and committed
Dynamic Strided Slice (apache#6316)
* Dynamic Strided Slice * fix clang-format lint * remove debug print * respond to review comments * respond to yongwww's comments * fix bad rebase * revert hybrid-script assert * reformat mxnet change * use new testing api * while getting test to work with the new testing API, refactor all of the tests in the dyn directory
1 parent eda3e0d commit 7e05d15

30 files changed

+634
-274
lines changed

python/tvm/relay/frontend/keras.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,8 +622,8 @@ def _convert_cropping(inexpr, keras_layer, _):
622622
raise tvm.error.OpNotImplemented(
623623
'Operator {} is not supported for frontend Keras.'.format(crop_type))
624624
int32_max = np.iinfo(np.int32).max
625-
return _op.strided_slice(inexpr, begin=_expr.const([0, 0, crop_t, crop_l]), \
626-
end=_expr.const([int32_max, int32_max, in_h-crop_b, in_w-crop_r]))
625+
return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
626+
end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])
627627

628628

629629
def _convert_batchnorm(inexpr, keras_layer, etab):

python/tvm/relay/frontend/mxnet.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -500,11 +500,11 @@ def _mx_slice(inputs, attrs):
500500
for i, ed in enumerate(end):
501501
if ed is None:
502502
end[i] = input_shape[i]
503-
new_attrs = {'begin': _expr.const(list(begin), dtype="int32"),
504-
'end': _expr.const(list(end), dtype="int32")}
503+
new_attrs = {'begin': list(begin),
504+
'end': list(end)}
505505
if stride is not None:
506506
stride = (x if x is not None else 1 for x in stride)
507-
new_attrs['strides'] = _expr.const(list(stride), dtype="int32")
507+
new_attrs['strides'] = list(stride)
508508
return _op.strided_slice(inputs[0], **new_attrs)
509509

510510

@@ -544,9 +544,7 @@ def _mx_slice_axis(inputs, attrs):
544544
else:
545545
begin.append(ax_beg)
546546
end.append(ax_end)
547-
return _op.strided_slice(inputs[0],
548-
_expr.const(begin, dtype="int32"),
549-
_expr.const(end, dtype="int32"))
547+
return _op.strided_slice(inputs[0], begin, end)
550548

551549

552550
def _mx_crop_like(inputs, attrs):
@@ -566,9 +564,9 @@ def _mx_crop_like(inputs, attrs):
566564
return _op.slice_like(*inputs, **new_attrs)
567565
expr = _infer_type(inputs[1])
568566
like_shape = expr.checked_type.shape
569-
new_attrs['begin'] = _expr.const([0, 0, offset[0], offset[1]], dtype="int32")
570-
new_attrs['end'] = _expr.const([like_shape[0], like_shape[1], offset[0]+like_shape[2],
571-
offset[1]+like_shape[3]], dtype="int32")
567+
new_attrs['begin'] = [0, 0, offset[0], offset[1]]
568+
new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
569+
offset[1]+like_shape[3]]
572570
return _op.strided_slice(inputs[0], **new_attrs)
573571

574572

python/tvm/relay/frontend/onnx.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,8 +1049,8 @@ def _impl_v1(cls, inputs, attr, params):
10491049
end = list(attr['ends'])
10501050

10511051
return _op.strided_slice(inputs[0],
1052-
begin=_expr.const(begin, dtype="int64"),
1053-
end=_expr.const(end, dtype="int64"))
1052+
begin=begin,
1053+
end=end)
10541054

10551055
@classmethod
10561056
def _impl_v10(cls, inputs, attr, params):
@@ -1070,8 +1070,8 @@ def _impl_v10(cls, inputs, attr, params):
10701070
attrs['starts'] = new_starts
10711071
attrs['ends'] = new_ends
10721072
return _op.strided_slice(inputs[0],
1073-
begin=_expr.const(attrs['starts'], dtype="int64"),
1074-
end=_expr.const(attrs['ends'], dtype="int64"))
1073+
begin=list(attrs['starts']),
1074+
end=list(attrs['ends']))
10751075

10761076

10771077
class Gather(OnnxOpConverter):

python/tvm/relay/frontend/pytorch.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,9 @@ def _impl(inputs, input_types):
309309
strides[dim] = int(inputs[4])
310310

311311
return _op.transform.strided_slice(data,
312-
begin=_expr.const(begin),
313-
end=_expr.const(end),
314-
strides=_expr.const(strides),
312+
begin=begin,
313+
end=end,
314+
strides=strides,
315315
slice_mode="end")
316316
return _impl
317317

@@ -1373,9 +1373,9 @@ def _impl(inputs, input_types):
13731373
stride = [1] * len(shape)
13741374

13751375
chunk_out = _op.transform.strided_slice(data,
1376-
begin=_expr.const(begin),
1377-
end=_expr.const(end),
1378-
strides=_expr.const(stride))
1376+
begin=begin,
1377+
end=end,
1378+
strides=stride)
13791379
chunks.append(chunk_out)
13801380

13811381
if dim % num_chunks:
@@ -1386,9 +1386,9 @@ def _impl(inputs, input_types):
13861386
stride = [1] * len(shape)
13871387

13881388
chunk_out = _op.transform.strided_slice(data,
1389-
begin=_expr.const(begin),
1390-
end=_expr.const(end),
1391-
strides=_expr.const(stride))
1389+
begin=begin,
1390+
end=end,
1391+
strides=stride)
13921392
chunks.append(chunk_out)
13931393

13941394
return chunks

python/tvm/relay/op/_tensor_grad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,9 +407,9 @@ def conv2d_grad(orig, grad):
407407
assert padded_weight_grad_w >= filter_w
408408
if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
409409
backward_weight = strided_slice(backward_weight,
410-
begin=const([0, 0, 0, 0], dtype="int64"),
411-
end=const([out_channel, in_channel // attrs.groups,
412-
filter_h, filter_w], dtype="int64"))
410+
begin=[0, 0, 0, 0],
411+
end=[out_channel, in_channel // attrs.groups,
412+
filter_h, filter_w])
413413

414414
return [backward_data, backward_weight]
415415

python/tvm/relay/op/_transform.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -127,33 +127,6 @@ def arange_shape_func(attrs, inputs, _):
127127
"""
128128
return [_arange_shape_func(*inputs)]
129129

130-
@script
131-
def _strided_slice_shape_func_input_data(data, begin, end, strides,
132-
slice_mode):
133-
ndim = len(data.shape)
134-
out = output_tensor((ndim,), "int64")
135-
for i in const_range(ndim):
136-
cbegin = 0
137-
cend = data.shape[i]
138-
cstride = 1
139-
if strides.shape[0] > i:
140-
cstride = strides[i]
141-
if begin.shape[0] > i:
142-
cbegin = begin[i]
143-
if end.shape[0] <= i:
144-
cend = data.shape[i]
145-
elif slice_mode != 0:
146-
cstride = 1
147-
if end[i] < 0:
148-
cend = data.shape[i]
149-
else:
150-
cend = cbegin + end[i]
151-
else:
152-
cend = end[i]
153-
assert cstride != 0, "Strides can't be zero."
154-
out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride)))
155-
return out
156-
157130
@script
158131
def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
159132
ndim = data_shape.shape[0]
@@ -166,6 +139,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
166139
cstride = int64(strides[i])
167140
if len(begin) > i:
168141
cbegin = int64(begin[i])
142+
if cbegin < 0:
143+
cbegin += int64(data_shape[i])
169144
if len(end) <= i:
170145
cend = int64(data_shape[i])
171146
elif slice_mode != 0:
@@ -175,23 +150,32 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
175150
else:
176151
cend = cbegin + int64(end[i])
177152
else:
178-
cend = int64(end[i])
153+
if end[i] > data_shape[i]:
154+
cend = int64(data_shape[i])
155+
else:
156+
cend = int64(end[i])
157+
if cend < 0:
158+
cend += int64(data_shape[i])
179159
assert cstride != 0, "Strides can't be zero."
180-
out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride)))
160+
if cstride < 0:
161+
slice_range = cbegin - cend
162+
step = -cstride
163+
else:
164+
slice_range = cend - cbegin
165+
step = cstride
166+
167+
out[i] = int64(ceil_div(slice_range, step))
181168
return out
182169

183170

184-
@_reg.register_shape_func("strided_slice", True)
171+
@_reg.register_shape_func("strided_slice", False)
185172
def strided_slice_shape_func(attrs, inputs, _):
186173
"""
187174
Shape func for strided_slice
188175
"""
189176
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
190-
# data independent if begin, end and strides exist
191-
if attrs.begin and attrs.end and attrs.strides:
192-
return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end,
193-
attrs.strides, slice_mode)]
194-
return [_strided_slice_shape_func_input_data(*inputs, slice_mode)]
177+
return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end,
178+
attrs.strides, slice_mode)]
195179

196180
@script
197181
def _concatenate_shape_func(inputs, axis):

python/tvm/relay/op/dyn/_transform.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
_reg.register_broadcast_schedule("dyn.tile")
2828
_reg.register_injective_schedule("dyn.one_hot")
2929
_reg.register_injective_schedule("dyn.full")
30+
_reg.register_injective_schedule("dyn.strided_slice")
3031

3132
@script
3233
def _reshape_shape_func_input_data(data, newshape, ndim):
@@ -145,3 +146,53 @@ def one_hot_shape_func(attrs, inputs, _):
145146
"""
146147
axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis
147148
return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))]
149+
150+
151+
@script
152+
def _strided_slice_shape_func_input_data(data, begin, end, strides,
153+
slice_mode):
154+
ndim = len(data.shape)
155+
out = output_tensor((ndim,), "int64")
156+
for i in const_range(ndim):
157+
cbegin = int64(0)
158+
cend = int64(data.shape[i])
159+
cstride = int64(1)
160+
if strides.shape[0] > i:
161+
cstride = int64(strides[i])
162+
if begin.shape[0] > i:
163+
cbegin = int64(begin[i])
164+
if cbegin < 0:
165+
cbegin += int64(data.shape[i])
166+
if end.shape[0] <= i:
167+
cend = int64(data.shape[i])
168+
elif slice_mode != 0:
169+
cstride = int64(1)
170+
if end[i] < 0:
171+
cend = int64(data.shape[i])
172+
else:
173+
cend = cbegin + int64(end[i])
174+
else:
175+
if end[i] > data.shape[i]:
176+
cend = int64(data.shape[i])
177+
else:
178+
cend = int64(end[i])
179+
if cend < 0:
180+
cend += int64(data.shape[i])
181+
assert cstride != 0, "Strides can't be zero."
182+
if cstride < 0:
183+
slice_range = cbegin - cend
184+
step = -cstride
185+
else:
186+
slice_range = cend - cbegin
187+
step = cstride
188+
189+
out[i] = int64(ceil_div(slice_range, step))
190+
return out
191+
192+
@_reg.register_shape_func("dyn.strided_slice", True)
193+
def strided_slice_shape_func(attrs, inputs, _):
194+
"""
195+
Shape func for strided_slice
196+
"""
197+
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
198+
return [_strided_slice_shape_func_input_data(*inputs, slice_mode)]

python/tvm/relay/op/transform.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from . import _make
2222
from .dyn import _make as _dyn_make
23+
from .tensor import shape_of
2324
from ..expr import TupleWrapper, const, Expr, Tuple
2425
from ...tir import expr as _expr
2526

@@ -827,13 +828,17 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
827828
ret : relay.Expr
828829
The computed result.
829830
"""
830-
strides = strides or const([1], dtype="int32")
831-
if isinstance(begin, (tuple, list)):
832-
begin = const(list(begin))
833-
if isinstance(end, (tuple, list)):
834-
end = const(list(end))
835-
if isinstance(strides, (tuple, list)):
836-
strides = const(list(strides))
831+
strides = strides or [1]
832+
if (isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr)):
833+
if isinstance(begin, (tuple, list)):
834+
begin = const(list(begin))
835+
if isinstance(end, (tuple, list)):
836+
end = const(list(end))
837+
if isinstance(strides, (tuple, list)):
838+
strides = const(list(strides))
839+
normalized_begin = _make.where(begin < cast_like(const(0), begin),
840+
begin + cast_like(shape_of(data), begin), begin)
841+
return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
837842
return _make.strided_slice(data, begin, end, strides, slice_mode)
838843

839844

python/tvm/topi/cuda/conv2d_alter_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,8 @@ def _conv2d_legalize(attrs, inputs, arg_types):
276276
new_attrs['channels'] = new_out_channel
277277
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
278278
original_out_shape = [x.value for x in output_tensor.shape]
279-
out = relay.strided_slice(out, begin=relay.const([0, 0, 0, 0]),
280-
end=relay.const(original_out_shape))
279+
out = relay.strided_slice(out, begin=[0, 0, 0, 0],
280+
end=original_out_shape)
281281
else:
282282
out = relay.nn.conv2d(data, kernel, **new_attrs)
283283
return out

python/tvm/topi/x86/conv2d_alter_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,8 @@ def _conv2d_legalize(attrs, inputs, arg_types):
313313
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
314314
original_out_shape = [x.value for x in output_tensor.shape]
315315
out = relay.strided_slice(out,
316-
begin=relay.const([0, 0, 0, 0], "int32"),
317-
end=relay.const(original_out_shape, "int32"))
316+
begin=[0, 0, 0, 0],
317+
end=original_out_shape)
318318
else:
319319
out = relay.nn.conv2d(data, kernel, **new_attrs)
320320

0 commit comments

Comments (0)