Skip to content

Commit 36a65a0

Browse files
Simon Hollisfacebook-github-bot
authored andcommitted
Enable torch tracing by changing assertions in d2go forwards to allow for torch.fx.proxy.Proxy type.
Summary: Pull Request resolved: #4227 X-link: facebookresearch/d2go#241 Torch FX tracing propagates a type of `torch.fx.proxy.Proxy` through the graph. Existing type assertions in the d2go code base trigger during torch FX tracing, causing tracing to fail. This adds a check for FX tracing in progress and adds a helper function `assert_fx_safe()`, that can be used in place of a standard assertion. This function only applies the assertion if one is not tracing, allowing d2go assertion tests to be compatible with FX tracing. Reviewed By: wat3rBro Differential Revision: D35518556 fbshipit-source-id: a9b5d3d580518ca74948544973ae89f8b9de3282
1 parent 5aeb252 commit 36a65a0

File tree

2 files changed

+65
-12
lines changed

2 files changed

+65
-12
lines changed

detectron2/modeling/poolers.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor
99
from detectron2.structures import Boxes
10+
from detectron2.utils.tracing import assert_fx_safe
1011

1112
"""
1213
To export ROIPooler to torchscript, in this file, variables that should be annotated with
@@ -219,19 +220,20 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
219220
"""
220221
num_level_assignments = len(self.level_poolers)
221222

222-
assert isinstance(x, list) and isinstance(
223-
box_lists, list
224-
), "Arguments to pooler must be lists"
225-
assert (
226-
len(x) == num_level_assignments
227-
), "unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
228-
num_level_assignments, len(x)
223+
assert_fx_safe(
224+
isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists"
229225
)
230-
231-
assert len(box_lists) == x[0].size(
232-
0
233-
), "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
234-
x[0].size(0), len(box_lists)
226+
assert_fx_safe(
227+
len(x) == num_level_assignments,
228+
"unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
229+
num_level_assignments, len(x)
230+
),
231+
)
232+
assert_fx_safe(
233+
len(box_lists) == x[0].size(0),
234+
"unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
235+
x[0].size(0), len(box_lists)
236+
),
235237
)
236238
if len(box_lists) == 0:
237239
return _create_zeros(None, x[0].shape[1], *self.output_size, x[0])

detectron2/utils/tracing.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import inspect
2+
from typing import Union
3+
import torch
4+
from torch.fx._symbolic_trace import _orig_module_call
5+
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current
6+
7+
from detectron2.utils.env import TORCH_VERSION
8+
9+
10+
@torch.jit.ignore
11+
def is_fx_tracing_legacy() -> bool:
12+
"""
13+
Returns a bool indicating whether torch.fx is currently symbolically tracing a module.
14+
Can be useful for gating module logic that is incompatible with symbolic tracing.
15+
"""
16+
return torch.nn.Module.__call__ is not _orig_module_call
17+
18+
19+
@torch.jit.ignore
20+
def is_fx_tracing() -> bool:
21+
"""Returns whether execution is currently in
22+
Torch FX tracing mode"""
23+
if TORCH_VERSION >= (1, 10):
24+
return is_fx_tracing_current()
25+
else:
26+
return is_fx_tracing_legacy()
27+
28+
29+
@torch.jit.ignore
30+
def assert_fx_safe(condition: Union[bool, str], message: str):
31+
"""An FX-tracing safe version of assert.
32+
Avoids erroneous type assertion triggering when types are masked inside
33+
an fx.proxy.Proxy object during tracing.
34+
Args: condition - either a boolean expression or a string representing
35+
the condition to test. If this assert triggers an exception when tracing
36+
due to dynamic control flow, try encasing the expression in quotation
37+
marks and supplying it as a string."""
38+
if not is_fx_tracing():
39+
try:
40+
if isinstance(condition, str):
41+
caller_frame = inspect.currentframe().f_back
42+
torch._assert(
43+
eval(condition, caller_frame.f_globals, caller_frame.f_locals), message
44+
)
45+
else:
46+
torch._assert(condition, message)
47+
except torch.fx.proxy.TraceError as e:
48+
print(
49+
"Found a non-FX compatible assertion. Skipping the check. Failure is shown below"
50+
+ str(e)
51+
)

0 commit comments

Comments
 (0)