@@ -148,7 +148,6 @@ def check_correctness(
148148 tvm_num_outputs = 1
149149
150150 # Check that number of outputs match.
151-
152151 assert tvm_num_outputs == len (ort_output ), "Unequal number of outputs"
153152
154153 for (tvm_out , ort_out ) in zip (tvm_output , ort_output ):
@@ -435,6 +434,22 @@ def test_unsqueeze():
435434 check_correctness (model )
436435
437436
437+ def test_unsqueeze_v1 ():
438+ # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1
439+ unsqueeze_node = helper .make_node ("Unsqueeze" , ["a" ], ["b" ], axes = [0 , 2 , 3 ])
440+ graph = helper .make_graph (
441+ [unsqueeze_node ],
442+ "unsqueeze_v1" ,
443+ inputs = [helper .make_tensor_value_info ("a" , TensorProto .FLOAT , [32 , 32 ])],
444+ outputs = [helper .make_tensor_value_info ("b" , TensorProto .FLOAT , [1 , 32 , 1 , 1 , 32 ])],
445+ )
446+
447+ model = helper .make_model (
448+ graph , producer_name = "unsqueeze_v1_test" , opset_imports = [helper .make_opsetid ("" , 6 )]
449+ )
450+ check_correctness (model , opset = 10 )
451+
452+
438453def test_gelu ():
439454 verify_unary ("Gelu" , [32 , 32 ], domain = "com.microsoft" )
440455
@@ -490,6 +505,25 @@ def test_clip(min, max):
490505 check_correctness (model )
491506
492507
508+ @pytest .mark .parametrize ("min" , [- 6.0 , 0.0 ])
509+ @pytest .mark .parametrize ("max" , [6.0 ])
510+ def test_clip_v6 (max , min ):
511+ # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Clip-6
512+ clip_node = helper .make_node ("Clip" , ["input" ], ["output" ], max = max , min = min )
513+ inputs = [helper .make_tensor_value_info ("input" , TensorProto .FLOAT , [32 , 64 ])]
514+ graph = helper .make_graph (
515+ [clip_node ],
516+ "clip_v6_test" ,
517+ inputs = inputs ,
518+ outputs = [helper .make_tensor_value_info ("output" , TensorProto .FLOAT , [32 , 64 ])],
519+ )
520+ model = helper .make_model (
521+ graph , producer_name = "clip_v6_test" , opset_imports = [helper .make_opsetid ("" , 6 )]
522+ )
523+ onnx .save (model , "a.onnx" )
524+ check_correctness (model , opset = 10 )
525+
526+
493527def test_equal ():
494528 equal_node = helper .make_node ("Equal" , ["a" , "b" ], ["output" ])
495529
0 commit comments