Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,109 @@ def _impl(inputs, attr, params, mod):
return _impl


def _combined_nms():
def _impl(inputs, attr, params, mod):
# Get parameter values
boxes = inputs[0]
scores = inputs[1]
try:
max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0])
except Exception:
try:
max_output_size = (
_infer_value(inputs[2], params, mod).asnumpy().astype("int64").tolist()[0]
)
except Exception:
max_output_size = inputs[2]
max_total_size = inputs[3]
iou_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0]
score_threshold = np.atleast_1d(inputs[5].data.asnumpy())[0]
if attr["pad_per_class"]:
raise tvm.error.OpAttributeUnImplemented(
"pad_per_class for CombinedNonMaxSuppression is not supported"
)
boxes_shape = _infer_shape(inputs[0], mod)
scores_shape = _infer_shape(inputs[1], mod)
batch_size = boxes_shape[0]
num_anchors = boxes_shape[1]
q = boxes_shape[2]
num_classes = scores_shape[2]

if q != num_classes:
# When q is 1, it means same box coords are used for all classes.
boxes = _op.broadcast_to(boxes, (batch_size, num_anchors, num_classes, 4))
boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4])
scores = _op.reshape(scores, newshape=[batch_size, num_anchors * num_classes, 1])

# In TF, class is specified by memory layout only.
ids = _op.arange(_op.const(num_classes, dtype="float32"))
ids = _op.broadcast_to(ids, (batch_size, num_anchors, num_classes))
ids = _op.reshape(ids, newshape=[batch_size, num_anchors * num_classes, 1])

data = _op.concatenate([ids, scores, boxes], -1)
ct, data, indices = _op.vision.get_valid_counts(
data, score_threshold=score_threshold, id_index=0, score_index=1
)
nms_ret = _op.vision.non_max_suppression(
data=data,
valid_count=ct,
indices=indices,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
force_suppress=False,
top_k=-1,
coord_start=2,
score_index=1,
id_index=0,
return_indices=False,
invalid_to_bottom=True,
)
# Dynamic slice to max_total_size
neg_one = _expr.const([-1])
slice_end = _op.concatenate(
[neg_one, _op.expand_dims(max_total_size, axis=0), neg_one], axis=0
)
nms_ret = _op.strided_slice(
nms_ret, begin=[0, 0, 0], end=slice_end, strides=[1, 1, 1], slice_mode="size"
)

# Slice output into boxes, scores, classes
nmsed_boxes = _op.strided_slice(
nms_ret, begin=[0, 0, 2], end=[-1, -1, 4], slice_mode="size"
)
if attr["clip_boxes"]:
nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32"))
nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32"))
nmsed_scores = _op.strided_slice(
nms_ret, begin=[0, 0, 1], end=[-1, -1, 1], slice_mode="size"
)
nmsed_scores = _op.squeeze(nmsed_scores, axis=[2])
nmsed_classes = _op.strided_slice(
nms_ret, begin=[0, 0, 0], end=[-1, -1, 1], slice_mode="size"
)
nmsed_classes = _op.squeeze(nmsed_classes, axis=[2])
# Get number of valid boxes
nms_count = _op.sum(
_op.cast(_op.greater(nmsed_scores, _expr.const(0, dtype="float32")), "int32"), axis=1
)

# TVM uses -1 for invalid outputs while TF uses 0
box_range = _op.arange(_expr.const(0, dtype="int32"), max_total_size, dtype="int32")
shape = _op.strided_slice(_op.shape_of(nmsed_boxes), begin=[0], end=[2])
box_range = _op.broadcast_to(box_range, shape)
valid_mask = _op.cast(_op.less(box_range, _op.expand_dims(nms_count, axis=1)), "float32")
nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2)
# Could instead use mask for scores, classes if negative values are possible.
nmsed_scores = _op.maximum(nmsed_scores, _expr.const(0, dtype="float32"))
nmsed_classes = _op.maximum(nmsed_classes, _expr.const(0, dtype="float32"))

return _expr.TupleWrapper(
_expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, nms_count]), 4
)

return _impl


def _decode_image():
def _impl(inputs, attr, params, mod):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
Expand Down Expand Up @@ -2473,6 +2576,7 @@ def _impl(inputs, attr, params, mod):
"NonMaxSuppressionV3": _nms(),
"NonMaxSuppressionV4": _nms(),
"NonMaxSuppressionV5": _nms(True),
"CombinedNonMaxSuppression": _combined_nms(),
"NoOp": _no_op(),
"NotEqual": _broadcast("not_equal"),
"OneHot": _one_hot(),
Expand Down
49 changes: 49 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2837,6 +2837,55 @@ def test_forward_nms():
_test_forward_nms((2000, 4), (2000,), 0.4, 0.6, 7)


def _test_forward_combined_nms(
bx_shape,
score_shape,
iou_threshold,
score_threshold,
out_size,
total_size,
clip_boxes=False,
dtype="float32",
):
boxes = np.random.uniform(-1, 2, size=bx_shape).astype(dtype)
scores = np.random.uniform(size=score_shape).astype(dtype)
max_output_size = np.int32(out_size)
tf.reset_default_graph()
in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1")
in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2")
in_data_3 = tf.placeholder(tf.int32, name="in_data_3")
tf.image.combined_non_max_suppression(
boxes=in_data_1,
scores=in_data_2,
max_output_size_per_class=in_data_3,
max_total_size=total_size,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pad_per_class=False,
clip_boxes=clip_boxes,
name="nms",
)
compare_tf_with_tvm(
[boxes, scores, max_output_size],
["in_data_1:0", "in_data_2:0", "in_data_3:0"],
[
"nms/CombinedNonMaxSuppression:0",
"nms/CombinedNonMaxSuppression:1",
"nms/CombinedNonMaxSuppression:2",
"nms/CombinedNonMaxSuppression:3",
],
mode="vm",
)


def test_forward_combined_nms():
""" CombinedNonMaxSuppression """
_test_forward_combined_nms((1, 64, 1, 4), (1, 64, 1), 0.7, 0.5, 64, 64)
_test_forward_combined_nms((1, 64, 1, 4), (1, 64, 20), 0.7, 0.5, 64, 10)
_test_forward_combined_nms((1, 64, 20, 4), (1, 64, 20), 0.7, 0.5, 64, 64, clip_boxes=True)
_test_forward_combined_nms((2, 200, 1, 4), (2, 200, 1), 0.4, 0.6, 100, 100)


#######################################################################
# LSTM
# ----
Expand Down