Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,44 @@ struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
}
};

/*!
 * \brief Attributes used in proposal operators.
 *
 * Each field is declared in TVM_DECLARE_ATTRS below together with its default
 * value and the description string shown to users.
 */
struct ProposalAttrs : public tvm::AttrsNode<ProposalAttrs> {
  /*! \brief Scales enumerated when generating anchor windows. */
  Array<IndexExpr> scales;
  /*! \brief Ratios enumerated when generating anchor windows. */
  Array<IndexExpr> ratios;
  /*! \brief Receptive-field size of one unit of the RPN convolution layer. */
  int feature_stride;
  /*! \brief IoU threshold used by non-maximum suppression. */
  double threshold;
  /*! \brief Number of top scoring boxes fed to NMS; -1 uses all boxes. */
  int rpn_pre_nms_top_n;
  /*! \brief Number of top scoring boxes kept after NMS. */
  int rpn_post_nms_top_n;
  /*! \brief Minimum height or width of a proposal. */
  int rpn_min_size;
  /*! \brief Whether IoU loss is used. */
  bool iou_loss;

  TVM_DECLARE_ATTRS(ProposalAttrs, "relay.attrs.ProposalAttrs") {
    TVM_ATTR_FIELD(scales)
        .set_default(Array<IndexExpr>({4.0f, 8.0f, 16.0f, 32.0f}))
        .describe("Used to generate anchor windows by enumerating scales");
    TVM_ATTR_FIELD(ratios)
        .set_default(Array<IndexExpr>({0.5f, 1.0f, 2.0f}))
        .describe("Used to generate anchor windows by enumerating ratios");
    // Adjacent string literals are concatenated by the compiler, so the first
    // literal has to end with a space to keep the two sentences apart.
    TVM_ATTR_FIELD(feature_stride)
        .set_default(16)
        .describe(
            "The size of the receptive field each unit in the convolution layer of the rpn, "
            "for example the product of all stride's prior to this layer.");
    TVM_ATTR_FIELD(threshold)
        .set_default(0.7)
        .describe(
            "IoU threshold of non-maximum suppression (suppress boxes with IoU >= this threshold)");
    TVM_ATTR_FIELD(rpn_pre_nms_top_n)
        .set_default(6000)
        .describe("Number of top scoring boxes to apply NMS. -1 to use all boxes");
    TVM_ATTR_FIELD(rpn_post_nms_top_n)
        .set_default(300)
        .describe("Number of top scoring boxes to keep after applying NMS to RPN proposals");
    TVM_ATTR_FIELD(rpn_min_size).set_default(16).describe("Minimum height or width in proposal");
    TVM_ATTR_FIELD(iou_loss).set_default(false).describe("Usage of IoU Loss");
  }
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,20 @@ def _mx_roi_align(inputs, attrs):
return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs)


def _mx_proposal(inputs, attrs):
    """Convert an MXNet proposal operator into ``vision.proposal``.

    Parameters
    ----------
    inputs : list
        The three operator inputs; they are forwarded positionally as
        ``inputs[0]``, ``inputs[1]`` and ``inputs[2]``.

    attrs : object
        Attribute accessor providing ``get_float_tuple``, ``get_int``,
        ``get_float`` and ``get_bool``, each taking ``(key, default)``.

    Returns
    -------
    The result of ``_op.vision.proposal`` called with the converted attributes.

    Raises
    ------
    AssertionError
        If the ``output_score`` attribute is set, which is not supported.
    """
    new_attrs = {
        "scales": attrs.get_float_tuple("scales", (4.0, 8.0, 16.0, 32.0)),
        "ratios": attrs.get_float_tuple("ratios", (0.5, 1.0, 2.0)),
        "feature_stride": attrs.get_int("feature_stride", 16),
        "threshold": attrs.get_float("threshold", 0.7),
        "rpn_pre_nms_top_n": attrs.get_int("rpn_pre_nms_top_n", 6000),
        "rpn_post_nms_top_n": attrs.get_int("rpn_post_nms_top_n", 300),
        "rpn_min_size": attrs.get_int("rpn_min_size", 16),
        "iou_loss": attrs.get_bool("iou_loss", False),
    }
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed under ``python -O``, which would silently accept the unsupported
    # option. The exception type is kept so existing callers see no change.
    if attrs.get_bool("output_score", False):
        raise AssertionError("proposal doesn't support output score")
    return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -397,6 +411,8 @@ def _mx_roi_align(inputs, attrs):
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
"_contrib_MultiBoxDetection" : _mx_multibox_detection,
"_contrib_ROIAlign" : _mx_roi_align,
"_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
Expand Down
28 changes: 27 additions & 1 deletion python/tvm/relay/op/vision/_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=invalid-name, unused-argument
"""Faster R-CNN and Mask R-CNN operations."""
import topi
from topi.util import get_const_tuple
from topi.util import get_const_tuple, get_float_tuple, get_const_int
from .. import op as reg
from ..op import OpPattern

Expand All @@ -21,3 +21,29 @@ def schedule_roi_align(_, outs, target):
return topi.generic.vision.schedule_roi_align(outs)

reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("vision.proposal")
def compute_proposal(attrs, inputs, _, target):
    """Build the TOPI compute for ``vision.proposal``.

    The operator attributes are unpacked first, then the TOPI proposal
    compute is invoked inside the ``target`` context with the three inputs.
    The result is returned as a single-element list.
    """
    # Array-valued and boolean attributes are converted to plain Python values.
    anchor_scales = get_float_tuple(attrs.scales)
    anchor_ratios = get_float_tuple(attrs.ratios)
    stride = attrs.feature_stride
    nms_threshold = attrs.threshold
    pre_nms_top_n = attrs.rpn_pre_nms_top_n
    post_nms_top_n = attrs.rpn_post_nms_top_n
    min_size = attrs.rpn_min_size
    use_iou_loss = bool(get_const_int(attrs.iou_loss))

    with target:
        rois = topi.vision.rcnn.proposal(
            inputs[0], inputs[1], inputs[2],
            anchor_scales, anchor_ratios, stride, nms_threshold,
            pre_nms_top_n, post_nms_top_n, min_size, use_iou_loss)
    return [rois]

@reg.register_schedule("vision.proposal")
def schedule_proposal(_, outs, target):
    """Return the generic TOPI proposal schedule for ``outs`` under ``target``."""
    with target:
        sched = topi.generic.schedule_proposal(outs)
    return sched

# Register vision.proposal with the OPAQUE operator pattern.
# NOTE(review): presumably this keeps the op out of operator fusion - confirm
# against the OpPattern definition.
reg.register_pattern("vision.proposal", OpPattern.OPAQUE)
60 changes: 60 additions & 0 deletions python/tvm/relay/op/vision/rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,63 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='N
4-D tensor with shape [num_roi, channel, pooled_size, pooled_size]
"""
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)


def proposal(cls_prob, bbox_pred, im_info,
             scales, ratios, feature_stride, threshold,
             rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
    """Proposal operator.

    Parameters
    ----------
    cls_prob : relay.Expr
        4-D tensor with shape [batch, 2 * num_anchors, height, width].

    bbox_pred : relay.Expr
        4-D tensor with shape [batch, 4 * num_anchors, height, width].

    im_info : relay.Expr
        2-D tensor with shape [batch, 3]. The last dimension is laid out as
        [im_height, im_width, im_scale].

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of
        the rpn, for example the product of all strides prior to this layer.

    threshold : float
        Non-maximum suppression threshold.

    rpn_pre_nms_top_n : int
        Number of top scoring boxes to apply NMS. -1 to use all boxes.

    rpn_post_nms_top_n : int
        Number of top scoring boxes to keep after applying NMS to RPN
        proposals.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    output : relay.Expr
        2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last
        dimension is laid out as [batch_index, w_start, h_start, w_end, h_end].
    """
    # Attribute values are forwarded positionally after the three inputs.
    op_attrs = (scales, ratios, feature_stride, threshold,
                rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss)
    return _make.proposal(cls_prob, bbox_pred, im_info, *op_attrs)
67 changes: 67 additions & 0 deletions src/relay/op/vision/rcnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,72 @@ RELAY_REGISTER_OP("vision.roi_align")
.set_support_level(5)
.add_type_rel("ROIAlign", ROIAlignRel);

// Register ProposalAttrs as a node type.
TVM_REGISTER_NODE_TYPE(ProposalAttrs);

/*!
 * \brief Type relation of vision.proposal.
 *
 * \param types Types of the three inputs (cls_prob, bbox_pred, im_info)
 *        followed by the output type; exactly four entries are expected.
 * \param num_inputs Number of inputs (unused here).
 * \param attrs Operator attributes; must hold a ProposalAttrs node.
 * \param reporter Reporter used to assert constraints and assign the output.
 * \return false while any input is not yet a tensor type, true once the
 *         output type [batch * rpn_post_nms_top_n, 5] has been assigned.
 */
bool ProposalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  auto proposal_attrs = attrs.as<ProposalAttrs>();
  // as<>() yields nullptr when the attrs hold another node type; fail with a
  // message here instead of dereferencing a null pointer further down.
  CHECK(proposal_attrs != nullptr) << "vision.proposal expects ProposalAttrs";
  CHECK_EQ(types.size(), 4);
  const auto* cls_prob = types[0].as<TensorTypeNode>();
  const auto* bbox_pred = types[1].as<TensorTypeNode>();
  const auto* im_info = types[2].as<TensorTypeNode>();

  // Some input is not a tensor type (yet), so the relation cannot be resolved.
  if (!cls_prob || !bbox_pred || !im_info) {
    return false;
  }

  CHECK_EQ(cls_prob->shape.size(), 4U)
      << "The dimension of class probability should be 4, but received " << cls_prob->shape.size();
  CHECK_EQ(bbox_pred->shape.size(), 4U)
      << "The dimension of box prediction should be 4, but received " << bbox_pred->shape.size();
  CHECK_EQ(im_info->shape.size(), 2U)
      << "The dimension of image info should be 2, but received " << im_info->shape.size();
  // The second dimension of im_info must have extent 3.
  CHECK(reporter->AssertEQ(im_info->shape[1], 3));

  auto batch = cls_prob->shape[0];

  // One 5-element row per kept proposal; the dtype follows cls_prob.
  std::vector<IndexExpr> oshape(
      {batch * proposal_attrs->rpn_post_nms_top_n, 5});
  reporter->Assign(types[3], TensorTypeNode::make(oshape, cls_prob->dtype));
  return true;
}

/*!
 * \brief Build a call to vision.proposal.
 *
 * Copies every attribute argument into a fresh ProposalAttrs node and returns
 * a call of the operator on (cls_prob, bbox_pred, im_info).
 */
Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr> scales,
                  Array<IndexExpr> ratios, int feature_stride, double threshold,
                  int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size,
                  bool iou_loss) {
  // The operator handle is looked up once and reused by later calls.
  static const Op& op = Op::Get("vision.proposal");

  auto node = make_node<ProposalAttrs>();
  node->scales = scales;
  node->ratios = ratios;
  node->feature_stride = feature_stride;
  node->threshold = threshold;
  node->rpn_pre_nms_top_n = rpn_pre_nms_top_n;
  node->rpn_post_nms_top_n = rpn_post_nms_top_n;
  node->rpn_min_size = rpn_min_size;
  node->iou_loss = iou_loss;

  Array<Expr> call_args({cls_prob, bbox_pred, im_info});
  return CallNode::make(op, call_args, Attrs(node), {});
}

// Expose MakeProposal as the packed function "relay.op.vision._make.proposal".
// The arity 11 is MakeProposal's parameter count: three input expressions
// followed by eight attribute values.
TVM_REGISTER_API("relay.op.vision._make.proposal")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 11>(MakeProposal, args, rv);
});

// Operator registration for vision.proposal: three tensor inputs
// (cls_prob, bbox_pred, im_info), support level 5, typed by ProposalRel.
RELAY_REGISTER_OP("vision.proposal")
.describe(R"code(Generate region proposals via RPN.

- **cls_prob**: 4-D with shape [batch, 2 * num_anchors, height, width].
- **bbox_pred**: 4-D with shape [batch, 4 * num_anchors, height, width].
- **im_info**: 2-D with shape [batch, 3].
- **out**: 2-D with shape [batch * rpn_post_nms_top_n, 5].
)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("cls_prob", "Tensor", "Score of how likely proposal is object")
.add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals")
.add_argument("im_info", "Tensor", "Image size and scale")
.set_support_level(5)
.add_type_rel("Proposal", ProposalRel);

} // namespace relay
} // namespace tvm
67 changes: 67 additions & 0 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,72 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_
verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2)


def test_proposal():
    """Check type inference and CUDA results of relay ``vision.proposal``."""
    def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
        """Compare the relay proposal op against ``np_out`` for one attr set."""
        def float_var(name, array):
            # A float32 relay variable shaped like the given numpy array.
            return relay.var(name, relay.ty.TensorType(array.shape, "float32"))

        cls_prob = float_var("cls_prob", np_cls_prob)
        bbox_pred = float_var("bbox_pred", np_bbox_pred)
        im_info = float_var("im_info", np_im_info)
        call = relay.vision.proposal(cls_prob, bbox_pred, im_info, **attrs)

        # The inferred type must match the reference output's shape.
        inferred = relay.ir_pass.infer_type(call)
        assert inferred.checked_type == relay.ty.TensorType(np_out.shape, "float32")

        func = relay.ir_pass.infer_type(
            relay.Function([cls_prob, bbox_pred, im_info], call))
        np_inputs = (np_cls_prob, np_bbox_pred, np_im_info)
        for target in ['cuda']:
            if not tvm.module.enabled(target):
                print("Skip test because %s is not enabled." % target)
                continue
            ctx = tvm.context(target, 0)
            # Run the graph executor first, then the debug executor.
            for executor_kind in ("graph", "debug"):
                intrp = relay.create_executor(executor_kind, ctx=ctx, target=target)
                result = intrp.evaluate(func)(*np_inputs)
                tvm.testing.assert_allclose(result.asnumpy(), np_out, rtol=1e-4)

    attrs = dict(
        scales=(0.5,),
        ratios=(0.5,),
        feature_stride=16,
        iou_loss=False,
        rpn_min_size=16,
        threshold=0.7,
        rpn_pre_nms_top_n=200,
        rpn_post_nms_top_n=4,
    )

    np_cls_prob = np.array([[
        [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
        [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]],
    ]], dtype='float32')
    np_bbox_pred = np.array([[
        [[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]],
        [[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]],
        [[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
        [[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
    ]], dtype='float32')
    np_im_info = np.array([[48., 48., 1.]], dtype='float32')

    # Reference output with iou_loss disabled.
    np_out = np.array([
        [0., 0., 2.8451548, 28.38012, 18.154846],
        [0., 0., 15.354933, 41.96971, 41.245064],
        [0., 18.019852, 1.0538368, 51.98015, 25.946163],
        [0., 27.320923, -1.266357, 55., 24.666357],
    ], dtype='float32')
    verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)

    # Same inputs with iou_loss enabled.
    np_out = np.array([
        [0., -5.25, -2.5, 21.75, 19.],
        [0., 11.25, -2., 37.25, 18.5],
        [0., 26.849998, -2.3000002, 53.45, 18.6],
        [0., -4.95, 13.799999, 22.25, 35.5],
    ], dtype='float32')
    attrs['iou_loss'] = True
    verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)


def test_yolo_reorg_infer_shape():
def verify_yolo_reorg(shape, stride, out_shape):
x = relay.var("x", relay.TensorType(shape, "float32"))
Expand Down Expand Up @@ -347,5 +413,6 @@ def verify_yolo_reorg(shape, stride):
test_multibox_transform_loc()
test_nms()
test_roi_align()
test_proposal()
test_yolo_reorg_infer_shape()
test_yolo_reorg()
4 changes: 2 additions & 2 deletions topi/tests/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_roi_align():
def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
cls_prob = tvm.placeholder(np_cls_prob.shape)
bbox_pred = tvm.placeholder(np_bbox_pred.shape)
im_info = tvm.placeholder(np_im_info.shape, dtype='int32')
im_info = tvm.placeholder(np_im_info.shape)

def check_device(device):
ctx = tvm.context(device, 0)
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_proposal():
[[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
[[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
]], dtype='float32')
np_im_info = np.array([[48, 48, 1]], dtype='int32')
np_im_info = np.array([[48., 48., 1.]], dtype='float32')
np_out = np.array([
[0., 0., 2.8451548,28.38012, 18.154846],
[0., 0., 15.354933, 41.96971, 41.245064],
Expand Down