Skip to content

Commit ad7a37f

Browse files
author
Matthew Brookhart
committed
respond to review comments
1 parent 19c70a8 commit ad7a37f

File tree

5 files changed

+24
-23
lines changed

5 files changed

+24
-23
lines changed

python/tvm/relay/op/transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -836,9 +836,9 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
836836
end = const(list(end))
837837
if isinstance(strides, (tuple, list)):
838838
strides = const(list(strides))
839-
begin = _make.where(begin < cast_like(const(0), begin),
840-
begin + cast_like(shape_of(data), begin), begin)
841-
return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
839+
normalized_begin = _make.where(begin < cast_like(const(0), begin),
840+
begin + cast_like(shape_of(data), begin), begin)
841+
return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
842842
return _make.strided_slice(data, begin, end, strides, slice_mode)
843843

844844

src/relay/op/dyn/tensor/transform.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ RELAY_REGISTER_OP("dyn.full")
436436

437437
bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
438438
const TypeReporter& reporter) {
439+
// [data, begin, end, strides, out]
439440
CHECK_EQ(types.size(), 5);
440441
const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
441442
if (param == nullptr) {
@@ -487,12 +488,12 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
487488
te::Tensor end = inputs[2];
488489
te::Tensor strides = inputs[3];
489490
// Dynamic computation
490-
int64_t attr_size = data->shape.size();
491-
CHECK(begin->shape[0].as<IntImmNode>()->value == attr_size &&
492-
end->shape[0].as<IntImmNode>()->value == attr_size &&
493-
strides->shape[0].as<IntImmNode>()->value == attr_size)
491+
int64_t data_rank = data->shape.size();
492+
CHECK(begin->shape[0].as<IntImmNode>()->value == data_rank &&
493+
end->shape[0].as<IntImmNode>()->value == data_rank &&
494+
strides->shape[0].as<IntImmNode>()->value == data_rank)
494495
<< "begin, end, and strides are required to have the same length"
495-
<< " if they are non-constant.";
496+
<< " if they are dynamic variables.";
496497
return Array<te::Tensor>{DynamicStridedSlice(data, begin, end, strides)};
497498
}
498499

src/relay/op/tensor/transform.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2070,7 +2070,9 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
20702070
oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
20712071
}
20722072
} else {
2073-
CHECK(false) << "strided_slice recieved invalid params";
2073+
CHECK(param->begin) << "strided_slice recieved invalid begin";
2074+
CHECK(param->end) << "strided_slice recieved invalid end";
2075+
CHECK(param->strides) << "strided_slice recieved invalid strides";
20742076
}
20752077
reporter->Assign(types[1], TensorType(oshape, data->dtype));
20762078
return true;

src/relay/transforms/dynamic_to_static.cc

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,17 @@ class DynamicToStaticMutator : public MixedModeMutator {
141141
}},
142142
{Op::Get("dyn.strided_slice"),
143143
[](const CallNode* call_node) {
144-
if (const ConstantNode* begin = call_node->args[1].as<ConstantNode>()) {
145-
if (const ConstantNode* end = call_node->args[2].as<ConstantNode>()) {
146-
if (const ConstantNode* stride = call_node->args[3].as<ConstantNode>()) {
147-
CHECK_EQ(begin->data->ndim, 1);
148-
CHECK_EQ(end->data->ndim, 1);
149-
CHECK_EQ(stride->data->ndim, 1);
150-
const StridedSliceAttrs* param = call_node->attrs.as<StridedSliceAttrs>();
151-
CHECK(param);
152-
return MakeStridedSlice(call_node->args[0], ToVector(begin->data),
153-
ToVector(end->data), ToVector(stride->data),
154-
param->slice_mode);
155-
}
156-
}
144+
const ConstantNode* begin = call_node->args[1].as<ConstantNode>();
145+
const ConstantNode* end = call_node->args[2].as<ConstantNode>();
146+
const ConstantNode* stride = call_node->args[3].as<ConstantNode>();
147+
if (begin && end && stride) {
148+
CHECK_EQ(begin->data->ndim, 1);
149+
CHECK_EQ(end->data->ndim, 1);
150+
CHECK_EQ(stride->data->ndim, 1);
151+
const StridedSliceAttrs* param = call_node->attrs.as<StridedSliceAttrs>();
152+
CHECK(param);
153+
return MakeStridedSlice(call_node->args[0], ToVector(begin->data), ToVector(end->data),
154+
ToVector(stride->data), param->slice_mode);
157155
}
158156
return Expr(nullptr);
159157
}},

tests/python/relay/test_op_level4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end",
337337
text = func.astext()
338338
assert "begin=" in text
339339
assert "end=" in text
340-
340+
341341
if output:
342342
assert func.body.checked_type == relay.ty.TensorType(output, "float32")
343343

0 commit comments

Comments
 (0)