From 6f828e6daad90ed7e0d678d579cf6d766f62f6f9 Mon Sep 17 00:00:00 2001
From: Leyuan Wang
Date: Sun, 3 Mar 2019 00:50:25 -0800
Subject: [PATCH 1/6] tile and repeat operators added in relay

---
 docs/api/python/topi.rst                 |   4 +
 docs/langref/relay_op.rst                |   4 +
 include/tvm/relay/attrs/transform.h      |  22 +++
 python/tvm/relay/frontend/mxnet.py       |  22 +++
 python/tvm/relay/op/_transform.py        |   2 +
 python/tvm/relay/op/transform.py         |  46 ++++++
 src/relay/op/tensor/transform.cc         | 169 +++++++++++++++++++++++
 topi/include/topi/transform.h            | 109 +++++++++++++++
 topi/python/topi/transform.py            |  39 ++++++
 topi/src/topi.cc                         |  10 ++
 topi/tests/python/test_topi_transform.py |  56 ++++++++
 11 files changed, 483 insertions(+)

diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index f0fc78909258..06f4f0d61f34 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -73,6 +73,8 @@ List of operators
    topi.logical_not
    topi.arange
    topi.stack
+   topi.repeat
+   topi.tile
    topi.layout_transform
    topi.image.resize
@@ -132,6 +134,8 @@ topi
 .. autofunction:: topi.less
 .. autofunction:: topi.arange
 .. autofunction:: topi.stack
+.. autofunction:: topi.repeat
+.. autofunction:: topi.tile
 .. autofunction:: topi.layout_transform

 topi.nn
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index f706be08009d..0bde4f7f2839 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -97,6 +97,8 @@ This level enables additional math and transform operators.
    tvm.relay.split
    tvm.relay.arange
    tvm.relay.stack
+   tvm.relay.repeat
+   tvm.relay.tile

 **Level 4: Broadcast and Reductions**
@@ -222,6 +224,8 @@ Level 3 Definitions
 .. autofunction:: tvm.relay.split
 .. autofunction:: tvm.relay.arange
 .. autofunction:: tvm.relay.stack
+.. autofunction:: tvm.relay.repeat
+.. autofunction:: tvm.relay.tile

 Level 4 Definitions
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index fea2c960d032..5382017d8c1c 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -124,6 +124,28 @@ struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
   }
 };  // struct StackAttrs

+/*! \brief Attributes used in repeat operators */
+struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
+  Integer repeats;
+  Integer axis;
+  TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
+    TVM_ATTR_FIELD(repeats)
+        .describe("The number of repetitions for each element.");
+    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+        .describe("The axis along which to repeat values.");
+  }
+};  // struct RepeatAttrs
+
+/*! \brief Attributes used in tile operators */
+struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
+  Array<Integer> reps;
+  TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
+    TVM_ATTR_FIELD(reps)
+        .describe("The number of times for repeating the tensor a. "
+                  "Each dim size of reps must be a positive integer.");
+  }
+};  // struct TileAttrs
+
 /*! \brief Attributes used in squeeze operators */
 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   // use axis to make the name numpy compatible.
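(Editorial note, not part of the patch.) The semantics these attributes describe follow NumPy's `repeat` and `tile`, which the tests later in this series use as the reference via `np.repeat`/`np.tile`. A small NumPy sketch of the intended behavior:

.. code-block:: python

    import numpy as np

    x = np.array([[1, 2], [3, 4]])

    # repeat: with no axis the input is flattened first, then each element repeated
    np.repeat(x, 2)            # array([1, 1, 2, 2, 3, 3, 4, 4])
    np.repeat(x, 2, axis=1)    # array([[1, 1, 2, 2],
                               #        [3, 3, 4, 4]])

    # tile: the shorter of reps / x.ndim is left-padded with 1s before
    # multiplying the shapes elementwise
    np.tile(x, (2, 3)).shape   # (4, 6)
    np.tile(x, 2)              # array([[1, 2, 1, 2],
                               #        [3, 4, 3, 4]])
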
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 2e0ccd07fdc1..f05b579c36dd 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -166,6 +166,10 @@ def _mx_dropout(inputs, attrs): return _op.nn.dropout(inputs[0], rate=rate) +def _mx_BlockGrad(inputs, attrs): + return inputs + + def _mx_batch_norm(inputs, attrs): if attrs.get_bool("output_mean_var", False): raise RuntimeError("batch_norm do not support output_mean_var") @@ -357,6 +361,21 @@ def _mx_arange(inputs, attrs): return _op.arange(**new_attrs) +def _mx_repeat(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["repeats"] = attrs.get_int("repeats") + new_attrs["axis"] = attrs.get_int("axis", 0) + return _op.repeat(inputs[0], **new_attrs) + + +def _mx_tile(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["reps"] = attrs.get_int_tuple("reps") + return _op.tile(inputs[0], **new_attrs) + + def _mx_roi_align(inputs, attrs): new_attrs = {} new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size") @@ -490,6 +509,9 @@ def _mx_proposal(inputs, attrs): "batch_dot" : _mx_batch_dot, "LeakyReLU" : _mx_leaky_relu, "_arange" : _mx_arange, + "repeat" : _mx_repeat, + "tile" : _mx_tile, + "BlockGrad" : _mx_BlockGrad, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, # vision diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 1389f96b8325..2b43c21f8e10 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -19,6 +19,8 @@ _reg.register_schedule("full", schedule_injective) _reg.register_schedule("full_like", schedule_injective) _reg.register_schedule("arange", schedule_injective) +_reg.register_schedule("repeat", schedule_broadcast) +_reg.register_schedule("tile", schedule_broadcast) _reg.register_schedule("cast", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("slice_like", schedule_injective) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 845ee02b0582..28d1409e8e3e 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -316,6 +316,52 @@ def stack(data, axis): return _make.stack(data, axis) +def repeat(data, repeats, axis): + """Repeats elements of an array. + By default, repeat flattens the input array into 1-D and then repeats the elements. + + repeats : int + The number of repetitions for each element. + + axis: int + The axis along which to repeat values. The negative numbers are interpreted + counting from the backward. By default, use the flattened input array, and + return a flat output array. + + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.repeat(data, repeats, axis) + + +def tile(data, reps): + """Repeats the whole array multiple times. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + reps : tuple of int + The number of times repeating the tensor a. + + .. note:: + Each dim size of reps must be a positive integer. If reps has length d, + the result will have dimension of max(d, a.ndim); If a.ndim < d, a is + promoted to be d-dimensional by prepending new axes. If a.ndim ? d, reps + is promoted to a.ndim by pre-pending 1's to it. + + Returns + ------- + ret : relay.Expr + The computed result. 
+ """ + + return _make.tile(data, reps) + + def where(condition, x, y): """Selecting elements from either x or y depending on the value of the condition. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index de3ac03977f4..d64eb2d857d0 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1035,6 +1035,175 @@ RELAY_REGISTER_OP("arange") .set_attr("FTVMCompute", ArangeCompute) .set_attr("TOpPattern", kInjective); +// repeat operator +TVM_REGISTER_NODE_TYPE(RepeatAttrs); + +bool RepeatRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "repeat: expect input type to be TensorType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + const int ndim = static_cast(data->shape.size()); + const int repeats = param->repeats; + const int axis = param->axis; + CHECK(repeats >= 1) + << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; + CHECK(-ndim - 1 <= axis && axis <= ndim) + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis + << ", and data.ndim = " << ndim; + const int pivot = axis < 0 ? ndim + axis : axis; + std::vector oshape; + oshape.reserve(ndim + repeats); + for (int i = 0; i < pivot; ++i) { + oshape.emplace_back(data->shape[i]); + } + oshape.emplace_back(data->shape[pivot] * repeats); + for (int i = pivot + 1; i < ndim; ++i) { + oshape.emplace_back(data->shape[i]); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +Array RepeatCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const RepeatAttrs *param = attrs.as(); + CHECK(param != nullptr); + return { topi::repeat(inputs[0], param->repeats, param->axis) }; +} + +Expr MakeRepeat(Expr data, + int repeats, + int axis) { + auto attrs = make_node(); + attrs->repeats = repeats; + attrs->axis = axis; + static const Op& op = Op::Get("repeat"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.repeat") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeRepeat, args, rv); +}); + +RELAY_REGISTER_OP("repeat") +.describe(R"code(Repeat elements of an array `repeats` times along axis `axis` + +- **data**: The input data to the operator. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.Repeat") +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(1) +.add_type_rel("Repeat", RepeatRel) +.set_attr("FTVMCompute", RepeatCompute) +.set_attr("TOpPattern", kBroadcast); + +// tile operator +TVM_REGISTER_NODE_TYPE(TileAttrs); + +bool TileRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "tile: expect input type to be TensorType but get " + << types[0]; + return false; + } + const auto* param = attrs.as(); + const int ndim = static_cast(data->shape.size()); + const Array& reps = param->reps; + // check dimension match + CHECK(!reps.defined()) + << "repetition array is not defined. 
data.ndim = " << ndim; + const int rndim = static_cast(reps.size()); + int tndim = (ndim > rndim) ? ndim : rndim; + // re-construct data shape or reps shape + std::vector data_shape; + std::vector reps_shape; + data_shape.reserve(tndim); + reps_shape.reserve(tndim); + if (ndim == rndim) { + for (int i = 0; i < tndim; ++i) { + data_shape.emplace_back(data->shape[i]); + reps_shape.emplace_back(reps[i]); + } + } else if (ndim > rndim) { + for (int i = 0; i < ndim; ++i) + data_shape.emplace_back(data->shape[i]); + for (int i = 0; i < (ndim - rndim); ++i) + reps_shape.emplace_back(1); + for (int i = 0; i < rndim; ++i) + reps_shape.emplace_back(reps[i]); + } else { + for (int i = 0; i < rndim; ++i) + reps_shape.emplace_back(reps[i]); + } + std::vector oshape; + oshape.reserve(tndim); + for (int i = 0; i < tndim; ++i) { + oshape.emplace_back(data_shape[i] * reps_shape[i]); + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +Array TileCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const TileAttrs *param = attrs.as(); + CHECK(param != nullptr); + return { topi::tile(inputs[0], param->reps) }; +} + +Expr MakeTile(Expr data, + Array reps) { + auto attrs = make_node(); + attrs->reps = reps; + static const Op& op = Op::Get("tile"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.tile") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeTile, args, rv); +}); + +RELAY_REGISTER_OP("tile") +.describe(R"code(Repeat the whole array multiple times. + +- **data**: The input data to the operator. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.Tile") +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(1) +.add_type_rel("Tile", TileRel) +.set_attr("FTVMCompute", TileCompute) +.set_attr("TOpPattern", kBroadcast); + // where operator bool WhereRel(const Array& types, int num_inputs, diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index fc686f88dba6..db7a434062c7 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -719,6 +719,115 @@ inline Tensor where(const Tensor& condition, return out; } +/*! 
+* \brief Creates an operation to repeat elements of an array +* +* \param x The input tensor +* \param repeats The number of repetitions for each element +* \param axis The axis along which to repeat values (allows +* negative indices as offsets from the last dimension) +* \param name The name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor whose op member is the repeat operation +*/ +inline Tensor repeat(const Tensor& x, + int repeats, + int axis, + std::string name = "tensor", + std::string tag = kBroadcast) { + int ndim = static_cast(x->shape.size()); + CHECK(-ndim - 1 <= axis && axis <= ndim) + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis + << ", and data.ndim = " << ndim; + CHECK(repeats >= 1) + << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; + if (axis < 0) { + // Calculate offset from last dimension + axis += ndim; + } + Array new_shape; + for (size_t i = 0; i < static_cast(axis); ++i) { + new_shape.push_back(x->shape[i]); + } + new_shape.push_back(repeats * x->shape[axis]); + for (size_t i = axis + 1; i < x->shape.size(); ++i) { + new_shape.push_back(x->shape[i]); + } + + return compute( + new_shape, [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + idx.push_back(indices[axis] / repeats); + for (size_t i = axis + 1; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } + return x(idx); + }, name, tag); +} + +/*! +* \brief Creates an operation to tile elements of an array +* +* \param x The input tensor +* \param reps The number of times for repeating the tensor +* \param name The name of the operation +* \param tag The tag to mark the operation +* +* \return A Tensor whose op member is the tile operation +*/ +inline Tensor tile(const Tensor& x, + Array reps, + std::string name = "tensor", + std::string tag = kBroadcast) { + int ndim = static_cast(x->shape.size()); + int rdim = static_cast(reps.size()); + int tdim = (ndim > rdim) ? ndim : rdim; + Array data_shape; + Array reps_shape; + Array new_shape; + if (ndim == rdim) { + for (size_t i = 0; i < ndim; ++i) { + data_shape.push_back(x->shape[i]); + reps_shape.push_back(reps[i]); + } + } else if (ndim > rdim) { + for (size_t i = 0; i < ndim; ++i) + data_shape.push_back(x->shape[i]); + for (size_t i = 0; i < ndim - rdim; ++i) + reps_shape.push_back(1); + for (size_t i = 0; i < rdim; ++i) + reps_shape.push_back(reps[i]); + } else { + for (size_t i = 0; i < rdim - ndim; ++i) + data_shape.push_back(1); + for (size_t i = 0; i < ndim; ++i) + data_shape.push_back(x->shape[i]); + for (size_t i = 0; i < rdim; ++i) + reps_shape.push_back(reps[i]); + } + for (size_t i = 0; i < tdim; ++i) + new_shape.push_back(data_shape[i] * reps_shape[i]); + + return compute( + new_shape, [&](const Array& indices) { + Array idx; + if (ndim >= rdim) { + for (size_t i = 0; i < ndim; ++i) + idx.push_back(indices[i] % x->shape[i]); + } else { + for (size_t i = 0; i < ndim; ++i) + idx.push_back(indices[rdim - ndim + i] % x->shape[i]); + } + return x(idx); + }, name, tag); +} + /*! * \brief Gather elements from a n-dimension array. 
* diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 2ddfee2806a5..063556852d26 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -339,6 +339,45 @@ def arange(start, stop=None, step=1, dtype="float32"): return cpp.arange(start, stop, step, dtype) +def repeat(a, repeats, axis): + """Repeats elements of an array. + + Parameters + ---------- + a : tvm.Tensor + The tensor to be repeated. + + repeats: int, required + Number of repetitions for each element + + axis: int, optional + The axis along which to repeat values + + Returns + ------- + ret : tvm.Tensor + """ + return cpp.repeat(a, repeats, axis) + + +def tile(a, reps): + """Repeats the whole array multiple times. + + Parameters + ---------- + a : tvm.Tensor + The tensor to be tiled. + + reps: tuple of ints, required + The number of times for repeating the tensor + + Returns + ------- + ret : tvm.Tensor + """ + return cpp.tile(a, reps) + + def layout_transform(array, src_layout, dst_layout): """Transform the layout according to src_layout and dst_layout diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 3630c4cf3b85..14f92460fd25 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -305,6 +305,16 @@ TVM_REGISTER_GLOBAL("topi.arange") *rv = arange(args[0], args[1], args[2], args[3]); }); +TVM_REGISTER_GLOBAL("topi.repeat") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = repeat(args[0], args[1], args[2]); +}); + +TVM_REGISTER_GLOBAL("topi.tile") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = tile(args[0], args[1]); +}); + TVM_REGISTER_GLOBAL("topi.gather_nd") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = gather_nd(args[0], args[1]); diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 66c75854193f..785da6fddbcf 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -359,6 +359,50 @@ def check_device(device): for device in get_all_backend(): check_device(device) +def verify_repeat(in_shape, repeats, axis): + A = tvm.placeholder(shape=in_shape, name="A") + B = topi.repeat(A, repeats, axis) + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_broadcast(B) + foo = tvm.build(s, [A, B], device, name="repeat") + data_npy = np.random.uniform(size=in_shape).astype(A.dtype) + out_npy = np.repeat(data_npy, repeats, axis) + data_nd = tvm.nd.array(data_npy, ctx) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx) + foo(data_nd, out_nd) + tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) + + for device in get_all_backend(): + check_device(device) + +def verify_tile(in_shape, reps): + A = tvm.placeholder(shape=in_shape, name="A") + B = topi.tile(A, reps) + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_broadcast(B) + foo = tvm.build(s, [A, B], device, name="tile") + data_npy = np.random.uniform(size=in_shape).astype(A.dtype) + out_npy = np.tile(data_npy, reps) + data_nd = tvm.nd.array(data_npy, ctx) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx) + foo(data_nd, out_nd) + tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy) + + for device in 
get_all_backend(): + check_device(device) + def test_strided_slice(): verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) @@ -481,6 +525,16 @@ def test_arange(): verify_arange(20, 1, -1) verify_arange(20, 1, -1.5) +def test_repeat(): + verify_repeat((2,), 1, 0) + verify_repeat((3, 2), 2, 0) + verify_repeat((3, 2, 4), 3, 1) + verify_repeat((1, 3, 2, 4), 4, -1) + +def test_tile(): + verify_tile((3, 2), (2, 3)) + verify_tile((3, 2, 5), (2,)) + verify_tile((3, ), (2, 3, 3)) def test_layout_transform(): in_shape = (1, 32, 8, 8) @@ -525,3 +579,5 @@ def check_device(device): test_gather_nd() test_arange() test_layout_transform() + test_repeat() + test_tile() From e67c8b0b4d2b0c430366df929e0b983bd85adaf8 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sun, 3 Mar 2019 18:00:56 -0800 Subject: [PATCH 2/6] fix pylint --- python/tvm/relay/frontend/mxnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index f05b579c36dd..e95d0455f1fb 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -166,7 +166,7 @@ def _mx_dropout(inputs, attrs): return _op.nn.dropout(inputs[0], rate=rate) -def _mx_BlockGrad(inputs, attrs): +def _mx_BlockGrad(inputs, attrs): #pylint: disable=unused-argument return inputs From 86d035e1c1851060d201327269b98d59d60c6c6d Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sun, 3 Mar 2019 18:15:57 -0800 Subject: [PATCH 3/6] fix make warnings --- topi/include/topi/transform.h | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index db7a434062c7..a01e4c7c5d4e 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -792,36 +792,36 @@ inline Tensor tile(const Tensor& x, Array reps_shape; Array new_shape; if (ndim == rdim) { - for (size_t i = 0; i < ndim; ++i) { + for (size_t i = 0; i < static_cast(ndim); ++i) { data_shape.push_back(x->shape[i]); reps_shape.push_back(reps[i]); } } else if (ndim > rdim) { - for (size_t i = 0; i < ndim; ++i) + for (size_t i = 0; i < static_cast(ndim); ++i) data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < ndim - rdim; ++i) + for (size_t i = 0; i < static_cast(ndim - rdim); ++i) reps_shape.push_back(1); - for (size_t i = 0; i < rdim; ++i) + for (size_t i = 0; i < static_cast(rdim); ++i) reps_shape.push_back(reps[i]); } else { - for (size_t i = 0; i < rdim - ndim; ++i) + for (size_t i = 0; i < static_cast(rdim - ndim); ++i) data_shape.push_back(1); - for (size_t i = 0; i < ndim; ++i) + for (size_t i = 0; i < static_cast(ndim); ++i) data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < rdim; ++i) + for (size_t i = 0; i < static_cast(rdim); ++i) reps_shape.push_back(reps[i]); } - for (size_t i = 0; i < tdim; ++i) + for (size_t i = 0; i < static_cast(tdim); ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); return compute( new_shape, [&](const Array& indices) { Array idx; if (ndim >= rdim) { - for (size_t i = 0; i < ndim; ++i) + for (size_t i = 0; i < static_cast(ndim); ++i) idx.push_back(indices[i] % x->shape[i]); } else { - for (size_t i = 0; i < ndim; ++i) + for (size_t i = 0; i < static_cast(ndim); ++i) idx.push_back(indices[rdim - ndim + i] % x->shape[i]); } return x(idx); From d50d61c19f3c51927bec2d7725de56bc4486e3f6 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Fri, 8 Mar 2019 14:52:44 -0800 Subject: 
[PATCH 4/6] comments addressed --- python/tvm/relay/op/transform.py | 31 +++++++++++++++++++++++++++---- src/relay/op/tensor/transform.cc | 24 ++++++++++++------------ 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 28d1409e8e3e..832bebb1909b 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -332,6 +332,16 @@ def repeat(data, repeats, axis): ------- ret : relay.Expr The computed result. + + Examples + -------- + .. code-block:: python + + x = [[1, 2], [3, 4]] + relay.repeat(x, repeats=2) = [1., 1., 2., 2., 3., 3., 4., 4.] + + relay.repeat(x, repeats=2, axis=1) = [[1., 1., 2., 2.], + [3., 3., 4., 4.]] """ return _make.repeat(data, repeats, axis) @@ -345,18 +355,31 @@ def tile(data, reps): The input data to the operator. reps : tuple of int - The number of times repeating the tensor a. + The number of times repeating the tensor data. .. note:: Each dim size of reps must be a positive integer. If reps has length d, - the result will have dimension of max(d, a.ndim); If a.ndim < d, a is - promoted to be d-dimensional by prepending new axes. If a.ndim ? d, reps - is promoted to a.ndim by pre-pending 1's to it. + the result will have dimension of max(d, data.ndim); If data.ndim < d, + data is promoted to be d-dimensional by prepending new axes. + If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it. Returns ------- ret : relay.Expr The computed result. + + Examples + -------- + .. code-block:: python + + x = [[1, 2], [3, 4]] + relay.tile(x, reps=(2,3)) = [[1., 2., 1., 2., 1., 2.], + [3., 4., 3., 4., 3., 4.], + [1., 2., 1., 2., 1., 2.], + [3., 4., 3., 4., 3., 4.]] + + relay.tile(x, reps=(2,)) = [[1., 2., 1., 2.], + [3., 4., 3., 4.]] """ return _make.tile(data, reps) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d64eb2d857d0..3566001769e1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1039,9 +1039,9 @@ RELAY_REGISTER_OP("arange") TVM_REGISTER_NODE_TYPE(RepeatAttrs); bool RepeatRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -1131,37 +1131,37 @@ bool TileRel(const Array& types, return false; } const auto* param = attrs.as(); - const int ndim = static_cast(data->shape.size()); + const size_t ndim = static_cast(data->shape.size()); const Array& reps = param->reps; // check dimension match CHECK(!reps.defined()) << "repetition array is not defined. data.ndim = " << ndim; - const int rndim = static_cast(reps.size()); - int tndim = (ndim > rndim) ? ndim : rndim; + const size_t rndim = static_cast(reps.size()); + size_t tndim = (ndim > rndim) ? 
ndim : rndim; // re-construct data shape or reps shape std::vector data_shape; std::vector reps_shape; data_shape.reserve(tndim); reps_shape.reserve(tndim); if (ndim == rndim) { - for (int i = 0; i < tndim; ++i) { + for (size_t i = 0; i < tndim; ++i) { data_shape.emplace_back(data->shape[i]); reps_shape.emplace_back(reps[i]); } } else if (ndim > rndim) { - for (int i = 0; i < ndim; ++i) + for (size_t i = 0; i < ndim; ++i) data_shape.emplace_back(data->shape[i]); - for (int i = 0; i < (ndim - rndim); ++i) + for (size_t i = 0; i < (ndim - rndim); ++i) reps_shape.emplace_back(1); - for (int i = 0; i < rndim; ++i) + for (size_t i = 0; i < rndim; ++i) reps_shape.emplace_back(reps[i]); } else { - for (int i = 0; i < rndim; ++i) + for (size_t i = 0; i < rndim; ++i) reps_shape.emplace_back(reps[i]); } std::vector oshape; oshape.reserve(tndim); - for (int i = 0; i < tndim; ++i) { + for (size_t i = 0; i < tndim; ++i) { oshape.emplace_back(data_shape[i] * reps_shape[i]); } reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); From 3c31cca2babedf4479258c6dbb2f8134fc8c6864 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Fri, 8 Mar 2019 14:56:25 -0800 Subject: [PATCH 5/6] fix lint error --- python/tvm/relay/op/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 832bebb1909b..75f1bdc60174 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -359,8 +359,8 @@ def tile(data, reps): .. note:: Each dim size of reps must be a positive integer. If reps has length d, - the result will have dimension of max(d, data.ndim); If data.ndim < d, - data is promoted to be d-dimensional by prepending new axes. + the result will have dimension of max(d, data.ndim); If data.ndim < d, + data is promoted to be d-dimensional by prepending new axes. If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it. Returns From 540195aee18ab9434a6a0c52051d11fdcc7ff4af Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Sat, 9 Mar 2019 14:19:19 -0800 Subject: [PATCH 6/6] comment addressed --- src/relay/op/tensor/transform.cc | 4 ++-- topi/include/topi/transform.h | 26 +++++++++++++------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 3566001769e1..142a16b0b307 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1131,12 +1131,12 @@ bool TileRel(const Array& types, return false; } const auto* param = attrs.as(); - const size_t ndim = static_cast(data->shape.size()); + const size_t ndim = data->shape.size(); const Array& reps = param->reps; // check dimension match CHECK(!reps.defined()) << "repetition array is not defined. data.ndim = " << ndim; - const size_t rndim = static_cast(reps.size()); + const size_t rndim = reps.size(); size_t tndim = (ndim > rndim) ? ndim : rndim; // re-construct data shape or reps shape std::vector data_shape; diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index a01e4c7c5d4e..06327dac69f4 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -785,43 +785,43 @@ inline Tensor tile(const Tensor& x, Array reps, std::string name = "tensor", std::string tag = kBroadcast) { - int ndim = static_cast(x->shape.size()); - int rdim = static_cast(reps.size()); - int tdim = (ndim > rdim) ? 
ndim : rdim; + size_t ndim = x->shape.size(); + size_t rdim = reps.size(); + size_t tdim = (ndim > rdim) ? ndim : rdim; Array data_shape; Array reps_shape; Array new_shape; if (ndim == rdim) { - for (size_t i = 0; i < static_cast(ndim); ++i) { + for (size_t i = 0; i < ndim; ++i) { data_shape.push_back(x->shape[i]); reps_shape.push_back(reps[i]); } } else if (ndim > rdim) { - for (size_t i = 0; i < static_cast(ndim); ++i) + for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < static_cast(ndim - rdim); ++i) + for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1); - for (size_t i = 0; i < static_cast(rdim); ++i) + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); } else { - for (size_t i = 0; i < static_cast(rdim - ndim); ++i) + for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1); - for (size_t i = 0; i < static_cast(ndim); ++i) + for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < static_cast(rdim); ++i) + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); } - for (size_t i = 0; i < static_cast(tdim); ++i) + for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); return compute( new_shape, [&](const Array& indices) { Array idx; if (ndim >= rdim) { - for (size_t i = 0; i < static_cast(ndim); ++i) + for (size_t i = 0; i < ndim; ++i) idx.push_back(indices[i] % x->shape[i]); } else { - for (size_t i = 0; i < static_cast(ndim); ++i) + for (size_t i = 0; i < ndim; ++i) idx.push_back(indices[rdim - ndim + i] % x->shape[i]); } return x(idx);