Skip to content

Commit be6f398

Browse files
Lara Haidarfmassa
authored andcommitted
Enable ONNX Test for FasterRcnn (#1555)
* enable faster rcnn test * flake8 * smaller image size * set min/max
1 parent af225a8 commit be6f398

File tree

4 files changed

+18
-31
lines changed

4 files changed

+18
-31
lines changed

test/test_onnx.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
onnxruntime = None
2020

2121
import unittest
22+
from torchvision.ops._register_onnx_ops import _onnx_opset_version
2223

2324

2425
@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
@@ -32,7 +33,8 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False):
3233

3334
onnx_io = io.BytesIO()
3435
# export to onnx with the first input
35-
torch.onnx.export(model, inputs_list[0], onnx_io, do_constant_folding=True, opset_version=10)
36+
torch.onnx.export(model, inputs_list[0], onnx_io,
37+
do_constant_folding=True, opset_version=_onnx_opset_version)
3638

3739
# validate the exported model with onnx runtime
3840
for test_inputs in inputs_list:
@@ -97,7 +99,6 @@ def test_roi_pool(self):
9799
model = ops.RoIPool((pool_h, pool_w), 2)
98100
self.run_model(model, [(x, rois)])
99101

100-
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
101102
def test_transform_images(self):
102103

103104
class TransformModule(torch.nn.Module):
@@ -108,13 +109,13 @@ def __init__(self_module):
108109
def forward(self_module, images):
109110
return self_module.transform(images)[0].tensors
110111

111-
input = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
112-
input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
112+
input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
113+
input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
113114
self.run_model(TransformModule(), [input, input_test])
114115

115116
def _init_test_generalized_rcnn_transform(self):
116-
min_size = 800
117-
max_size = 1333
117+
min_size = 100
118+
max_size = 200
118119
image_mean = [0.485, 0.456, 0.406]
119120
image_std = [0.229, 0.224, 0.225]
120121
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
@@ -234,7 +235,6 @@ def forward(self, input, boxes):
234235

235236
self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
236237

237-
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
238238
def test_roi_heads(self):
239239
class RoiHeadsModule(torch.nn.Module):
240240
def __init__(self_module, images):
@@ -271,7 +271,7 @@ def get_image_from_url(self, url):
271271

272272
data = requests.get(url)
273273
image = Image.open(BytesIO(data.content)).convert("RGB")
274-
image = image.resize((800, 1280), Image.BILINEAR)
274+
image = image.resize((300, 200), Image.BILINEAR)
275275

276276
to_tensor = transforms.ToTensor()
277277
return to_tensor(image)
@@ -285,12 +285,12 @@ def get_test_images(self):
285285
test_images = [image2]
286286
return images, test_images
287287

288-
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
289-
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
290288
def test_faster_rcnn(self):
291289
images, test_images = self.get_test_images()
292290

293-
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True)
291+
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True,
292+
min_size=200,
293+
max_size=300)
294294
model.eval()
295295
model(images)
296296
self.run_model(model, [(images,), (test_images,)])

torchvision/models/detection/rpn.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,7 @@ def grid_anchors(self, grid_sizes, strides):
110110
shifts_y = torch.arange(
111111
0, grid_height, dtype=torch.float32, device=device
112112
) * stride_height
113-
# TODO: remove tracing pass when exporting torch.meshgrid()
114-
# is suported in ONNX
115-
if torchvision._is_tracing():
116-
shift_y = shifts_y.view(-1, 1).expand(grid_height, grid_width)
117-
shift_x = shifts_x.view(1, -1).expand(grid_height, grid_width)
118-
else:
119-
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
113+
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
120114
shift_x = shift_x.reshape(-1)
121115
shift_y = shift_y.reshape(-1)
122116
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

torchvision/models/detection/transform.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,6 @@ def resize(self, image, target):
8989
target["keypoints"] = keypoints
9090
return image, target
9191

92-
# _onnx_dynamic_img_pad() creates a dynamic padding
93-
# for an image supported in ONNx tracing.
94-
# it is used to process the images in _onnx_batch_images().
95-
def _onnx_dynamic_img_pad(self, img, padding):
96-
concat_0 = torch.cat((img, torch.zeros(padding[0], img.shape[1], img.shape[2])), 0)
97-
concat_1 = torch.cat((concat_0, torch.zeros(concat_0.shape[0], padding[1], concat_0.shape[2])), 1)
98-
padded_img = torch.cat((concat_1, torch.zeros(concat_1.shape[0], concat_1.shape[1], padding[2])), 2)
99-
return padded_img
100-
10192
# _onnx_batch_images() is an implementation of
10293
# batch_images() that is supported by ONNX tracing.
10394
def _onnx_batch_images(self, images, size_divisible=32):
@@ -116,7 +107,7 @@ def _onnx_batch_images(self, images, size_divisible=32):
116107
padded_imgs = []
117108
for img in images:
118109
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
119-
padded_img = self._onnx_dynamic_img_pad(img, padding)
110+
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
120111
padded_imgs.append(padded_img)
121112

122113
return torch.stack(padded_imgs)

torchvision/ops/_register_onnx_ops.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import sys
22
import torch
33

4+
_onnx_opset_version = 11
5+
46

57
def _register_custom_op():
68
from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx
@@ -30,6 +32,6 @@ def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
3032
return roi_pool, None
3133

3234
from torch.onnx import register_custom_op_symbolic
33-
register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, 10)
34-
register_custom_op_symbolic('torchvision::roi_align', roi_align, 10)
35-
register_custom_op_symbolic('torchvision::roi_pool', roi_pool, 10)
35+
register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, _onnx_opset_version)
36+
register_custom_op_symbolic('torchvision::roi_align', roi_align, _onnx_opset_version)
37+
register_custom_op_symbolic('torchvision::roi_pool', roi_pool, _onnx_opset_version)

0 commit comments

Comments
 (0)