|
7 | 7 |
|
8 | 8 | from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor |
9 | 9 | from detectron2.structures import Boxes |
| 10 | +from detectron2.utils.tracing import assert_fx_safe |
10 | 11 |
|
11 | 12 | """ |
12 | 13 | To export ROIPooler to torchscript, in this file, variables that should be annotated with |
@@ -219,19 +220,20 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]): |
219 | 220 | """ |
220 | 221 | num_level_assignments = len(self.level_poolers) |
221 | 222 |
|
222 | | - assert isinstance(x, list) and isinstance( |
223 | | - box_lists, list |
224 | | - ), "Arguments to pooler must be lists" |
225 | | - assert ( |
226 | | - len(x) == num_level_assignments |
227 | | - ), "unequal value, num_level_assignments={}, but x is list of {} Tensors".format( |
228 | | - num_level_assignments, len(x) |
| 223 | + assert_fx_safe( |
| 224 | + isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists" |
229 | 225 | ) |
230 | | - |
231 | | - assert len(box_lists) == x[0].size( |
232 | | - 0 |
233 | | - ), "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format( |
234 | | - x[0].size(0), len(box_lists) |
| 226 | + assert_fx_safe( |
| 227 | + len(x) == num_level_assignments, |
| 228 | + "unequal value, num_level_assignments={}, but x is list of {} Tensors".format( |
| 229 | + num_level_assignments, len(x) |
| 230 | + ), |
| 231 | + ) |
| 232 | + assert_fx_safe( |
| 233 | + len(box_lists) == x[0].size(0), |
| 234 | + "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format( |
| 235 | + x[0].size(0), len(box_lists) |
| 236 | + ), |
235 | 237 | ) |
236 | 238 | if len(box_lists) == 0: |
237 | 239 | return _create_zeros(None, x[0].shape[1], *self.output_size, x[0]) |
|
0 commit comments