66
77
88class TestResizeConverter (DispatchTestCase ):
9+
10+ def compare_resized_tensors (self , tensor1 , tensor2 , input_shape , target_shape ):
11+ # Check if the sizes match
12+ if tensor1 .size () != tensor2 .size ():
13+ return False
14+
15+ # Flatten the tensors to ensure we are comparing the valid elements
16+ flat_tensor1 = tensor1 .flatten ()
17+ flat_tensor2 = tensor2 .flatten ()
18+
19+ # Calculate the number of valid elements to compare
20+ input_numel = torch .Size (input_shape ).numel ()
21+ target_numel = torch .Size (target_shape ).numel ()
22+ min_size = min (input_numel , target_numel )
23+
24+ # Compare only the valid elements
25+ return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
26+
927 @parameterized .expand (
1028 [
1129 ((3 ,),),
@@ -28,24 +46,7 @@ def forward(self, x):
2846 input_shape = (5 ,)
2947 inputs = [torch .randn (input_shape )]
3048
31- def compare_resized_tensors (tensor1 , tensor2 , input_shape , target_shape ):
32- # Check if the sizes match
33- if tensor1 .size () != tensor2 .size ():
34- return False
35-
36- # Flatten the tensors to ensure we are comparing the valid elements
37- flat_tensor1 = tensor1 .flatten ()
38- flat_tensor2 = tensor2 .flatten ()
39-
40- # Calculate the number of valid elements to compare
41- input_numel = torch .Size (input_shape ).numel ()
42- target_numel = torch .Size (target_shape ).numel ()
43- min_size = min (input_numel , target_numel )
44-
45- # Compare only the valid elements
46- return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
47-
48- comparators = [(compare_resized_tensors , [input_shape , target_shape ])]
49+ comparators = [(self .compare_resized_tensors , [input_shape , target_shape ])]
4950
5051 self .run_test_compare_tensor_attributes_only (
5152 Resize (),
@@ -76,24 +77,7 @@ def forward(self, x):
7677 input_shape = (5 ,)
7778 inputs = [torch .randint (1 , 5 , input_shape )]
7879
79- def compare_resized_tensors (tensor1 , tensor2 , input_shape , target_shape ):
80- # Check if the sizes match
81- if tensor1 .size () != tensor2 .size ():
82- return False
83-
84- # Flatten the tensors to ensure we are comparing the valid elements
85- flat_tensor1 = tensor1 .flatten ()
86- flat_tensor2 = tensor2 .flatten ()
87-
88- # Calculate the number of valid elements to compare
89- input_numel = torch .Size (input_shape ).numel ()
90- target_numel = torch .Size (target_shape ).numel ()
91- min_size = min (input_numel , target_numel )
92-
93- # Compare only the valid elements
94- return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
95-
96- comparators = [(compare_resized_tensors , [input_shape , target_shape ])]
80+ comparators = [(self .compare_resized_tensors , [input_shape , target_shape ])]
9781
9882 self .run_test_compare_tensor_attributes_only (
9983 Resize (),
@@ -124,24 +108,7 @@ def forward(self, x):
124108 input_shape = (4 , 4 )
125109 inputs = [torch .randint (1 , 10 , input_shape )]
126110
127- def compare_resized_tensors (tensor1 , tensor2 , input_shape , target_shape ):
128- # Check if the sizes match
129- if tensor1 .size () != tensor2 .size ():
130- return False
131-
132- # Flatten the tensors to ensure we are comparing the valid elements
133- flat_tensor1 = tensor1 .flatten ()
134- flat_tensor2 = tensor2 .flatten ()
135-
136- # Calculate the number of valid elements to compare
137- input_numel = torch .Size (input_shape ).numel ()
138- target_numel = torch .Size (target_shape ).numel ()
139- min_size = min (input_numel , target_numel )
140-
141- # Compare only the valid elements
142- return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
143-
144- comparators = [(compare_resized_tensors , [input_shape , target_shape ])]
111+ comparators = [(self .compare_resized_tensors , [input_shape , target_shape ])]
145112
146113 self .run_test_compare_tensor_attributes_only (
147114 Resize (),
@@ -171,24 +138,7 @@ def forward(self, x):
171138 input_shape = (4 , 4 )
172139 inputs = [torch .randint (1 , 10 , input_shape )]
173140
174- def compare_resized_tensors (tensor1 , tensor2 , input_shape , target_shape ):
175- # Check if the sizes match
176- if tensor1 .size () != tensor2 .size ():
177- return False
178-
179- # Flatten the tensors to ensure we are comparing the valid elements
180- flat_tensor1 = tensor1 .flatten ()
181- flat_tensor2 = tensor2 .flatten ()
182-
183- # Calculate the number of valid elements to compare
184- input_numel = torch .Size (input_shape ).numel ()
185- target_numel = torch .Size (target_shape ).numel ()
186- min_size = min (input_numel , target_numel )
187-
188- # Compare only the valid elements
189- return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
190-
191- comparators = [(compare_resized_tensors , [input_shape , target_shape ])]
141+ comparators = [(self .compare_resized_tensors , [input_shape , target_shape ])]
192142
193143 self .run_test_compare_tensor_attributes_only (
194144 Resize (),
0 commit comments