1919 onnxruntime = None
2020
2121import unittest
22+ from torchvision .ops ._register_onnx_ops import _onnx_opset_version
2223
2324
2425@unittest .skipIf (onnxruntime is None , 'ONNX Runtime unavailable' )
@@ -32,7 +33,8 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False):
3233
3334 onnx_io = io .BytesIO ()
3435 # export to onnx with the first input
35- torch .onnx .export (model , inputs_list [0 ], onnx_io , do_constant_folding = True , opset_version = 10 )
36+ torch .onnx .export (model , inputs_list [0 ], onnx_io ,
37+ do_constant_folding = True , opset_version = _onnx_opset_version )
3638
3739 # validate the exported model with onnx runtime
3840 for test_inputs in inputs_list :
@@ -97,7 +99,6 @@ def test_roi_pool(self):
9799 model = ops .RoIPool ((pool_h , pool_w ), 2 )
98100 self .run_model (model , [(x , rois )])
99101
100- @unittest .skip ("Disable test until Resize opset 11 is implemented in ONNX Runtime" )
101102 def test_transform_images (self ):
102103
103104 class TransformModule (torch .nn .Module ):
@@ -108,13 +109,13 @@ def __init__(self_module):
108109 def forward (self_module , images ):
109110 return self_module .transform (images )[0 ].tensors
110111
111- input = [torch .rand (3 , 800 , 1280 ), torch .rand (3 , 800 , 800 )]
112- input_test = [torch .rand (3 , 800 , 1280 ), torch .rand (3 , 800 , 800 )]
112+ input = [torch .rand (3 , 100 , 200 ), torch .rand (3 , 200 , 200 )]
113+ input_test = [torch .rand (3 , 100 , 200 ), torch .rand (3 , 200 , 200 )]
113114 self .run_model (TransformModule (), [input , input_test ])
114115
115116 def _init_test_generalized_rcnn_transform (self ):
116- min_size = 800
117- max_size = 1333
117+ min_size = 100
118+ max_size = 200
118119 image_mean = [0.485 , 0.456 , 0.406 ]
119120 image_std = [0.229 , 0.224 , 0.225 ]
120121 transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
@@ -234,7 +235,6 @@ def forward(self, input, boxes):
234235
235236 self .run_model (TransformModule (), [(i , [boxes ],), (i1 , [boxes1 ],)])
236237
237- @unittest .skipIf (torch .__version__ < "1.4." , "Disable test if torch version is less than 1.4" )
238238 def test_roi_heads (self ):
239239 class RoiHeadsModule (torch .nn .Module ):
240240 def __init__ (self_module , images ):
@@ -271,7 +271,7 @@ def get_image_from_url(self, url):
271271
272272 data = requests .get (url )
273273 image = Image .open (BytesIO (data .content )).convert ("RGB" )
274- image = image .resize ((800 , 1280 ), Image .BILINEAR )
274+ image = image .resize ((300 , 200 ), Image .BILINEAR )
275275
276276 to_tensor = transforms .ToTensor ()
277277 return to_tensor (image )
@@ -285,12 +285,12 @@ def get_test_images(self):
285285 test_images = [image2 ]
286286 return images , test_images
287287
288- @unittest .skip ("Disable test until Resize opset 11 is implemented in ONNX Runtime" )
289- @unittest .skipIf (torch .__version__ < "1.4." , "Disable test if torch version is less than 1.4" )
290288 def test_faster_rcnn (self ):
291289 images , test_images = self .get_test_images ()
292290
293- model = models .detection .faster_rcnn .fasterrcnn_resnet50_fpn (pretrained = True )
291+ model = models .detection .faster_rcnn .fasterrcnn_resnet50_fpn (pretrained = True ,
292+ min_size = 200 ,
293+ max_size = 300 )
294294 model .eval ()
295295 model (images )
296296 self .run_model (model , [(images ,), (test_images ,)])
0 commit comments