
Commit c6f7758

tile and repeat operators added in relay

1 parent c8373ec, commit c6f7758

11 files changed: +487 −0 lines

docs/api/python/topi.rst

Lines changed: 4 additions & 0 deletions
@@ -72,6 +72,8 @@ List of operators
    topi.logical_or
    topi.logical_not
    topi.arange
+   topi.repeat
+   topi.tile
    topi.layout_transform
    topi.image.resize

@@ -130,6 +132,8 @@ topi
 .. autofunction:: topi.greater
 .. autofunction:: topi.less
 .. autofunction:: topi.arange
+.. autofunction:: topi.repeat
+.. autofunction:: topi.tile
 .. autofunction:: topi.layout_transform

 topi.nn
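
The same entries can be exercised at the TOPI compute level. A minimal sketch under the topi.repeat(data, repeats, axis) and topi.tile(data, reps) signatures documented above (the placeholder name and shapes are illustrative, not from this commit):

    import tvm
    import topi

    # Declare an input tensor and the two new compute declarations.
    A = tvm.placeholder((2, 3), name="A")
    B = topi.repeat(A, 2, 0)  # repeats=2 along axis 0 -> shape (4, 3)
    C = topi.tile(A, (2, 2))  # reps=(2, 2)            -> shape (4, 6)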

docs/langref/relay_op.rst

Lines changed: 4 additions & 0 deletions
@@ -96,6 +96,8 @@ This level enables additional math and transform operators.
    tvm.relay.cast
    tvm.relay.split
    tvm.relay.arange
+   tvm.relay.repeat
+   tvm.relay.tile


 **Level 4: Broadcast and Reductions**
@@ -220,6 +222,8 @@ Level 3 Definitions
 .. autofunction:: tvm.relay.cast
 .. autofunction:: tvm.relay.split
 .. autofunction:: tvm.relay.arange
+.. autofunction:: tvm.relay.repeat
+.. autofunction:: tvm.relay.tile


 Level 4 Definitions

include/tvm/relay/attrs/transform.h

Lines changed: 22 additions & 0 deletions
@@ -115,6 +115,28 @@ struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
   }
 };  // struct ArangeAttrs

+/*! \brief Attributes used in repeat operators */
+struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
+  Integer repeats;
+  Integer axis;
+  TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
+    TVM_ATTR_FIELD(repeats)
+        .describe("The number of repetitions for each element.");
+    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+        .describe("The axis along which to repeat values.");
+  }
+};  // struct RepeatAttrs
+
+/*! \brief Attributes used in tile operators */
+struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
+  Array<Integer> reps;
+  TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
+    TVM_ATTR_FIELD(reps)
+        .describe("The number of times to repeat the input tensor. "
+                  "Each dim size of reps must be a positive integer.");
+  }
+};  // struct TileAttrs
+
 /*! \brief Attributes used in squeeze operators */
 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   // use axis to make the name numpy compatible.
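
For context, these attributes carry the same parameters as numpy.repeat and numpy.tile; a quick NumPy illustration of what the fields control (not part of the commit):

    import numpy as np

    a = np.array([[1, 2], [3, 4]])
    np.repeat(a, 2, axis=1)  # repeats=2, axis=1 -> [[1, 1, 2, 2], [3, 3, 4, 4]]
    np.tile(a, (2, 3))       # reps=(2, 3)       -> shape (4, 6)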

python/tvm/relay/frontend/mxnet.py

Lines changed: 22 additions & 0 deletions
@@ -166,6 +166,10 @@ def _mx_dropout(inputs, attrs):
     return _op.nn.dropout(inputs[0], rate=rate)


+def _mx_BlockGrad(inputs, attrs):
+    return inputs
+
+
 def _mx_batch_norm(inputs, attrs):
     if attrs.get_bool("output_mean_var", False):
         raise RuntimeError("batch_norm do not support output_mean_var")
@@ -314,6 +318,21 @@ def _mx_arange(inputs, attrs):
     return _op.arange(**new_attrs)


+def _mx_repeat(inputs, attrs):
+    assert len(inputs) == 1
+    new_attrs = {}
+    new_attrs["repeats"] = attrs.get_int("repeats")
+    new_attrs["axis"] = attrs.get_int("axis", 0)
+    return _op.repeat(inputs[0], **new_attrs)
+
+
+def _mx_tile(inputs, attrs):
+    assert len(inputs) == 1
+    new_attrs = {}
+    new_attrs["reps"] = attrs.get_int_tuple("reps")
+    return _op.tile(inputs[0], **new_attrs)
+
+
 def _mx_roi_align(inputs, attrs):
     new_attrs = {}
     new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
@@ -431,6 +450,9 @@ def _mx_roi_align(inputs, attrs):
     "batch_dot" : _mx_batch_dot,
     "LeakyReLU" : _mx_leaky_relu,
     "_arange" : _mx_arange,
+    "repeat" : _mx_repeat,
+    "tile" : _mx_tile,
+    "BlockGrad" : _mx_BlockGrad,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     # vision
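
With the converter map extended, an MXNet graph that uses repeat or tile should import end to end. A minimal sketch, assuming an MXNet installation and that relay.frontend.from_mxnet in this TVM generation returns a Relay function plus a parameter dict:

    import mxnet as mx
    from tvm import relay

    # Toy symbol that exercises both new converters.
    x = mx.sym.Variable("x")
    y = mx.sym.tile(mx.sym.repeat(x, repeats=2, axis=0), reps=(1, 3))

    func, params = relay.frontend.from_mxnet(y, shape={"x": (2, 4)})
    print(func)  # the body should contain calls to the relay repeat and tile ops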

python/tvm/relay/op/_transform.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@
 _reg.register_schedule("full", schedule_injective)
 _reg.register_schedule("full_like", schedule_injective)
 _reg.register_schedule("arange", schedule_injective)
+_reg.register_schedule("repeat", schedule_broadcast)
+_reg.register_schedule("tile", schedule_broadcast)
 _reg.register_schedule("cast", schedule_injective)
 _reg.register_schedule("strided_slice", schedule_injective)
 _reg.register_schedule("slice_like", schedule_injective)

python/tvm/relay/op/transform.py

Lines changed: 50 additions & 0 deletions
@@ -294,6 +294,56 @@ def arange(start, stop=None, step=1, dtype="float32"):
     return _make.arange(start, stop, step, dtype)


+def repeat(data, repeats, axis):
+    """Repeats elements of an array along the given axis.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    repeats : int
+        The number of repetitions for each element.
+
+    axis : int
+        The axis along which to repeat values. Negative values count
+        back from the last axis.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.repeat(data, repeats, axis)
+
+
+def tile(data, reps):
+    """Repeats the whole array multiple times.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    reps : tuple of int
+        The number of times to repeat the input along each axis.
+
+    .. note::
+        Each dim size of reps must be a positive integer. If reps has
+        length d, the result will have dimension max(d, data.ndim); if
+        data.ndim < d, data is promoted to be d-dimensional by prepending
+        new axes; if data.ndim > d, reps is promoted to data.ndim by
+        prepending 1's to it.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.tile(data, reps)
+
+
 def where(condition, x, y):
     """Selecting elements from either x or y depending on the value of the
     condition.
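
Taken together, the two wrappers follow their NumPy counterparts; a short sketch of the intended call pattern (output shapes follow the type relations registered in transform.cc below):

    from tvm import relay

    x = relay.var("x", shape=(2, 3), dtype="float32")
    r = relay.repeat(x, 2, 0)       # repeats=2 along axis 0 -> shape (4, 3)
    t = relay.tile(x, reps=(2, 2))  # tile twice along both axes -> shape (4, 6)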

src/relay/op/tensor/transform.cc

Lines changed: 169 additions & 0 deletions
@@ -937,6 +937,175 @@ RELAY_REGISTER_OP("arange")
 .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);

+// repeat operator
+TVM_REGISTER_NODE_TYPE(RepeatAttrs);
+
+bool RepeatRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "repeat: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<RepeatAttrs>();
+  const int ndim = static_cast<int>(data->shape.size());
+  const int repeats = param->repeats;
+  const int axis = param->axis;
+  CHECK(repeats >= 1)
+      << "repeat only accepts `repeats >= 1`"
+      << ", but got repeats = " << repeats;
+  CHECK(-ndim - 1 <= axis && axis <= ndim)
+      << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
+      << ", but got axis = " << axis
+      << ", and data.ndim = " << ndim;
+  const int pivot = axis < 0 ? ndim + axis : axis;
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(ndim);
+  for (int i = 0; i < pivot; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  oshape.emplace_back(data->shape[pivot] * repeats);
+  for (int i = pivot + 1; i < ndim; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Array<Tensor> RepeatCompute(const Attrs& attrs,
+                            const Array<Tensor>& inputs,
+                            const Type& out_type,
+                            const Target& target) {
+  const RepeatAttrs* param = attrs.as<RepeatAttrs>();
+  CHECK(param != nullptr);
+  return { topi::repeat(inputs[0], param->repeats, param->axis) };
+}
+
+Expr MakeRepeat(Expr data,
+                int repeats,
+                int axis) {
+  auto attrs = make_node<RepeatAttrs>();
+  attrs->repeats = repeats;
+  attrs->axis = axis;
+  static const Op& op = Op::Get("repeat");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.repeat")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
+  });
+
+RELAY_REGISTER_OP("repeat")
+.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`.
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.RepeatAttrs")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Repeat", RepeatRel)
+.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
+.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+
+// tile operator
+TVM_REGISTER_NODE_TYPE(TileAttrs);
+
+bool TileRel(const Array<Type>& types,
+             int num_inputs,
+             const Attrs& attrs,
+             const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "tile: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<TileAttrs>();
+  const int ndim = static_cast<int>(data->shape.size());
+  const Array<Integer>& reps = param->reps;
+  // check dimension match
+  CHECK(reps.defined())
+      << "repetition array is not defined. data.ndim = " << ndim;
+  const int rndim = static_cast<int>(reps.size());
+  int tndim = (ndim > rndim) ? ndim : rndim;
+  // re-construct data shape or reps shape
+  std::vector<IndexExpr> data_shape;
+  std::vector<IndexExpr> reps_shape;
+  data_shape.reserve(tndim);
+  reps_shape.reserve(tndim);
+  if (ndim == rndim) {
+    for (int i = 0; i < tndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+      reps_shape.emplace_back(reps[i]);
+    }
+  } else if (ndim > rndim) {
+    for (int i = 0; i < ndim; ++i)
+      data_shape.emplace_back(data->shape[i]);
+    for (int i = 0; i < (ndim - rndim); ++i)
+      reps_shape.emplace_back(1);
+    for (int i = 0; i < rndim; ++i)
+      reps_shape.emplace_back(reps[i]);
+  } else {
+    // ndim < rndim: promote the data shape by prepending 1's so that
+    // both shapes have tndim entries.
+    for (int i = 0; i < (rndim - ndim); ++i)
+      data_shape.emplace_back(1);
+    for (int i = 0; i < ndim; ++i)
+      data_shape.emplace_back(data->shape[i]);
+    for (int i = 0; i < rndim; ++i)
+      reps_shape.emplace_back(reps[i]);
+  }
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(tndim);
+  for (int i = 0; i < tndim; ++i) {
+    oshape.emplace_back(data_shape[i] * reps_shape[i]);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Array<Tensor> TileCompute(const Attrs& attrs,
+                          const Array<Tensor>& inputs,
+                          const Type& out_type,
+                          const Target& target) {
+  const TileAttrs* param = attrs.as<TileAttrs>();
+  CHECK(param != nullptr);
+  return { topi::tile(inputs[0], param->reps) };
+}
+
+Expr MakeTile(Expr data,
+              Array<Integer> reps) {
+  auto attrs = make_node<TileAttrs>();
+  attrs->reps = reps;
+  static const Op& op = Op::Get("tile");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.tile")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
+  });
+
+RELAY_REGISTER_OP("tile")
+.describe(R"code(Repeat the whole array multiple times.
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.TileAttrs")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Tile", TileRel)
+.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
+.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+
 // where operator
 bool WhereRel(const Array<Type>& types,
               int num_inputs,
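
A quick way to sanity-check RepeatRel and TileRel is to run type inference over small programs; a sketch, assuming the relay.ir_pass.infer_type API of this TVM generation:

    from tvm import relay
    from tvm.relay import ir_pass

    x = relay.var("x", shape=(2, 3), dtype="float32")

    f = ir_pass.infer_type(relay.Function([x], relay.repeat(x, 2, 0)))
    print(f.body.checked_type)  # expect Tensor[(4, 3), float32]

    g = ir_pass.infer_type(relay.Function([x], relay.tile(x, reps=(2,))))
    print(g.body.checked_type)  # reps promoted to (1, 2) -> Tensor[(2, 6), float32]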
