
Commit fc0d7eb

icemelon authored and wweic committed
[Relay/TOPI][Op] Add shape op in Relay and TOPI (apache#2749)
* Add shapeof op in topi
* Add relay shape_of op
* Add constant folding for shape_of
* Allow shape op to specify dtype
* Add mxnet converter for shape_array
* lint
* lint
* Add doc
1 parent 12c75a5 commit fc0d7eb

15 files changed, +278 −1 lines changed

docs/api/python/topi.rst

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,7 @@ List of operators
    topi.stack
    topi.repeat
    topi.tile
+   topi.shape
    topi.layout_transform
    topi.image.resize
 
@@ -136,6 +137,7 @@ topi
 .. autofunction:: topi.stack
 .. autofunction:: topi.repeat
 .. autofunction:: topi.tile
+.. autofunction:: topi.shape
 .. autofunction:: topi.layout_transform
 
 topi.nn
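
For reference, a minimal sketch of exercising the new topi.shape operator directly from TOPI, assuming the TVM 0.5-era tvm.placeholder / topi.generic.schedule_injective / tvm.build workflow (the placeholder name "A" and the (3, 4, 5) shape are illustrative, not part of the commit):

    import numpy as np
    import tvm
    import topi

    # Symbolic 3-D input; topi.shape produces a 1-D tensor holding its extents.
    A = tvm.placeholder((3, 4, 5), name="A")
    S = topi.shape(A, dtype="int32")

    # shape only reads metadata and is treated as injective, so the generic
    # injective schedule applies.
    with tvm.target.create("llvm"):
        sched = topi.generic.schedule_injective(S)
    f = tvm.build(sched, [A, S], "llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.rand(3, 4, 5).astype("float32"), ctx)
    s = tvm.nd.array(np.zeros(3, dtype="int32"), ctx)
    f(a, s)
    print(s.asnumpy())  # expected: [3 4 5]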

docs/langref/relay_op.rst

Lines changed: 2 additions & 0 deletions
@@ -155,6 +155,7 @@ This level support backpropagation of broadcast operators. It is temporary.
    tvm.relay.broadcast_to_like
    tvm.relay.collapse_sum_like
    tvm.relay.slice_like
+   tvm.relay.shape_of
    tvm.relay.layout_transform
    tvm.relay.device_copy
    tvm.relay.annotation.on_device
@@ -275,6 +276,7 @@ Level 10 Definitions
 .. autofunction:: tvm.relay.broadcast_to_like
 .. autofunction:: tvm.relay.collapse_sum_like
 .. autofunction:: tvm.relay.slice_like
+.. autofunction:: tvm.relay.shape_of
 .. autofunction:: tvm.relay.layout_transform
 .. autofunction:: tvm.relay.device_copy
 .. autofunction:: tvm.relay.annotation.on_device

include/tvm/relay/attrs/transform.h

Lines changed: 12 additions & 0 deletions
@@ -226,6 +226,7 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
   }
 };
 
+/*! \brief Attributes for LayoutTransform operator */
 struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   std::string src_layout;
   std::string dst_layout;
@@ -238,6 +239,17 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   }
 };
 
+/*! \brief Attributes for ShapeOf operator */
+struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") {
+    TVM_ATTR_FIELD(dtype)
+        .describe("Target data type")
+        .set_default(NullValue<DataType>());
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_

python/tvm/relay/frontend/mxnet.py

Lines changed: 14 additions & 0 deletions
@@ -495,6 +495,19 @@ def _mx_l2_normalize(inputs, attrs):
     return _op.nn.l2_normalize(inputs[0], **new_attrs)
 
 
+def _mx_shape_array(inputs, attrs):
+    assert len(inputs) == 1
+    if attrs.get_int("lhs_begin", None) is not None:
+        raise RuntimeError("shape_array doesn't support lhs_begin")
+    if attrs.get_int("lhs_end", None) is not None:
+        raise RuntimeError("shape_array doesn't support lhs_end")
+    if attrs.get_int("rhs_begin", None) is not None:
+        raise RuntimeError("shape_array doesn't support rhs_begin")
+    if attrs.get_int("rhs_end", None) is not None:
+        raise RuntimeError("shape_array doesn't support rhs_end")
+    return _op.shape_of(inputs[0], dtype='int64')
+
+
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
 _identity_list = [
@@ -621,6 +634,7 @@ def _mx_l2_normalize(inputs, attrs):
     "tile"      : _mx_tile,
     "reverse"   : _mx_reverse,
     "BlockGrad" : _mx_BlockGrad,
+    "shape_array" : _mx_shape_array,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     # vision
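
A minimal sketch of how this converter is reached through the regular MXNet frontend flow, mirroring the new frontend test further below (the variable name "x" and the (3, 4, 5) shape are illustrative):

    import numpy as np
    import mxnet as mx
    import tvm
    from tvm import relay

    # An MXNet graph ending in shape_array converts to relay.shape_of(..., dtype='int64').
    mx_sym = mx.sym.shape_array(mx.sym.var("x"))
    net, params = relay.frontend.from_mxnet(mx_sym, {"x": (3, 4, 5)})

    # Evaluate with the debug interpreter, as the frontend test does.
    intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
    x_np = np.random.uniform(size=(3, 4, 5)).astype("float32")
    out = intrp.evaluate(net)(x_np)
    print(out.asnumpy())  # expected: [3 4 5]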

python/tvm/relay/op/_tensor.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@
 register_schedule("minimum", schedule_injective)
 register_schedule("right_shift", schedule_injective)
 register_schedule("left_shift", schedule_injective)
+register_schedule("shape_of", schedule_injective)
 
 # zeros
 @register_compute("zeros")

python/tvm/relay/op/tensor.py

Lines changed: 19 additions & 0 deletions
@@ -713,3 +713,22 @@ def device_copy(data, src_dev, dst_dev):
         raise ValueError("dst_dev is expected to be the type of TVMContext or "
                          "str, but received %s" % (type(dst_dev)))
     return _make.device_copy(data, src_dev, dst_dev)
+
+
+def shape_of(data, dtype="int32"):
+    """Get shape of a tensor.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor.
+
+    dtype : str, optional
+        The target data type.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The shape tensor.
+    """
+    return _make.shape_of(data, dtype)
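
A minimal usage sketch of the Python API defined above, mirroring the level-10 test later in this commit; the debug interpreter is used because the graph runtime would constant-fold the op away (the (10, 5, 12) shape and the "llvm" target are illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10, 5, 12), dtype="float32")
    func = relay.Function([x], relay.shape_of(x, dtype="int32"))
    func = relay.ir_pass.infer_type(func)

    intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
    x_data = np.random.rand(10, 5, 12).astype("float32")
    res = intrp.evaluate(func)(x_data)
    print(res.asnumpy())  # expected: [10  5 12]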

src/relay/op/tensor/unary.cc

Lines changed: 52 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/transform.h>
 #include <topi/elemwise.h>
+#include <topi/transform.h>
 #include "../type_relations.h"
 #include "../op_common.h"
 
@@ -189,5 +190,56 @@ RELAY_REGISTER_UNARY_OP("logical_not")
 .set_support_level(4)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not));
 
+
+// shape_of
+TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
+
+bool ShapeOfRel(const Array<Type>& types,
+                int num_inputs,
+                const Attrs& attrs,
+                const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);
+  auto tt = types[0].as<TensorTypeNode>();
+  CHECK(tt != nullptr);
+  const auto* param = attrs.as<ShapeOfAttrs>();
+  CHECK(param != nullptr);
+  auto vector_out = tvm::Integer(tt->shape.size());
+  reporter->Assign(types[1], TensorTypeNode::make({ vector_out }, param->dtype));
+  return true;
+}
+
+Array<Tensor> ShapeOfCompute(const Attrs& attrs,
+                             const Array<Tensor>& inputs,
+                             const Type& out_type,
+                             const Target& target) {
+  CHECK_EQ(inputs.size(), 1);
+  const auto* param = attrs.as<ShapeOfAttrs>();
+  CHECK(param != nullptr);
+  return {topi::shape(inputs[0], param->dtype)};
+}
+
+TVM_REGISTER_API("relay.op._make.shape_of")
+.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) {
+  auto attrs = make_node<ShapeOfAttrs>();
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("shape_of");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_OP("shape_of")
+.describe(R"code(Returns a tensor representing the shape of a tensor.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ShapeOfAttrs")
+.add_argument("data", "Tensor", "The input tensor.")
+.add_type_rel("ShapeOf", ShapeOfRel)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               ElemwiseArbitraryLayout)
+.set_support_level(10)
+.set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);
+
 }  // namespace relay
 }  // namespace tvm
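
The ShapeOfRel relation above types the result as a rank-1 tensor whose length equals the input's rank, in the requested dtype. A small sketch of observing that through type inference (an illustration, not part of the commit):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 3, 4), dtype="float32")
    y = relay.shape_of(x, dtype="int64")
    yy = relay.ir_pass.infer_type(y)
    # Rank-3 input, so the inferred type is a length-3 vector of the requested dtype.
    print(yy.checked_type)  # expected: Tensor[(3,), int64]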

src/relay/pass/fold_constant.cc

Lines changed: 42 additions & 0 deletions
@@ -6,6 +6,7 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/interpreter.h>
+#include <tvm/relay/attrs/transform.h>
 
 namespace tvm {
 namespace relay {
@@ -71,6 +72,7 @@ class ConstantFolder : public ExprMutator {
 
   Expr VisitExpr_(const CallNode* call) final {
     static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+    auto origin_args = call->args;
     Expr res = ExprMutator::VisitExpr_(call);
     call = res.as<CallNode>();
     // We don't constant fold function with zero arguments.
@@ -81,6 +83,10 @@ class ConstantFolder : public ExprMutator {
     if (op == nullptr) return res;
     // skip stateful ops.
     if (op_stateful.get(GetRef<Op>(op), false)) return res;
+    // Try to evaluate shape_of op
+    if (call->op.same_as(Op::Get("shape_of"))) {
+      return EvaluateShapeOf(res, origin_args, call->attrs);
+    }
     bool all_const_args = true;
     for (Expr arg : call->args) {
       if (!checker_.Check(arg)) {
@@ -132,6 +138,42 @@ class ConstantFolder : public ExprMutator {
     expr = InferType(expr, Module(nullptr));
     return ValueToExpr(executor_(expr));
   }
+  // Evaluate shape_of op
+  Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) {
+    Expr input = args[0];
+    const auto* param = attrs.as<ShapeOfAttrs>();
+    CHECK(param != nullptr);
+    tvm::Array<IndexExpr> ishape;
+    if (const ConstantNode* op = input.as<ConstantNode>()) {
+      ishape = op->tensor_type()->shape;
+    } else if (input->checked_type_.defined()) {
+      ishape = input->checked_type().as<TensorTypeNode>()->shape;
+    } else {
+      return expr;
+    }
+    // Get the constant shape
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    auto val = runtime::NDArray::Empty(
+        {(int64_t)ishape.size()}, Type2TVMType(Int(32)), ctx);
+    int32_t* dims = static_cast<int32_t*>(val->data);
+    using ::tvm::ir::IntImm;
+    for (size_t i = 0; i < ishape.size(); ++i) {
+      if (const IntImm* dim = ishape[i].as<IntImm>()) {
+        dims[i] = dim->value;
+      } else {
+        return expr;
+      }
+    }
+    Expr shape = ValueToExpr(TensorValueNode::make(val));
+    // Cast the constant into correct dtype
+    auto cast_attrs = make_node<CastAttrs>();
+    cast_attrs->dtype = param->dtype;
+    static const Op& cast_op = Op::Get("cast");
+    Expr ret = CallNode::make(cast_op, {shape}, Attrs(cast_attrs), {});
+    return ConstEvaluate(ret);
+  }
 };
 
 
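
A minimal sketch of the folding behaviour this enables, assuming the relay.ir_pass.fold_constant entry point of the same era: once the input's type is known and fully static, shape_of is replaced by a constant (already cast to the requested dtype), so later passes and the graph runtime never see the op (the (4, 8) shape is illustrative):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4, 8), dtype="float32")
    func = relay.Function([x], relay.shape_of(x, dtype="int32"))
    func = relay.ir_pass.infer_type(func)

    folded = relay.ir_pass.fold_constant(func)
    # With a fully static (4, 8) input type, the body should now be the
    # constant [4, 8] instead of a call to shape_of.
    print(folded)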

tests/python/frontend/mxnet/test_forward.py

Lines changed: 17 additions & 0 deletions
@@ -380,6 +380,22 @@ def test_forward_l2_normalize():
     verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5))
 
 
+def test_forward_shape_array():
+    def verify(shape):
+        x_np = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.shape_array(mx.nd.array(x_np))
+        mx_sym = mx.sym.shape_array(mx.sym.var("x"))
+        new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
+        for target, ctx in ctx_list():
+            for kind in ["debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(new_sym)(x_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+    verify((1,))
+    verify((3, 4, 5))
+    verify((3, 4, 5, 6))
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -409,3 +425,4 @@ def test_forward_l2_normalize():
     test_forward_slice_like()
     test_forward_slice_axis()
     test_forward_l2_normalize()
+    test_forward_shape_array()

tests/python/relay/test_op_level10.py

Lines changed: 15 additions & 0 deletions
@@ -177,10 +177,25 @@ def test_batch_matmul():
     verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
     verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
 
+def test_shape_of():
+    shape = (10, 5, 12)
+    x = relay.var("x", shape=shape)
+    func = relay.Function([x], relay.op.shape_of(x))
+    func = relay.ir_pass.infer_type(func)
+    x_data = np.random.rand(*shape).astype('float32')
+    for target, ctx in ctx_list():
+        # Because using graph executor, this op will be optimized after
+        # constant folding pass, here we only test with interpreter
+        for kind in ["debug"]:
+            intrp = relay.create_executor(kind, ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(x_data)
+            tvm.testing.assert_allclose(op_res.asnumpy(),
+                                        np.array(shape).astype('int32'))
 
 if __name__ == "__main__":
     test_collapse_sum_like()
     test_broadcast_to_like()
     test_slice_like()
     test_reverse_reshape()
     test_batch_matmul()
+    test_shape_of()
