11from  dataclasses  import  dataclass 
2- from  typing  import  Any , Callable , Dict , Type 
2+ from  typing  import  Any , Callable , Dict , Optional ,  Type ,  Union 
33import  torch 
44import  logging 
55
88
99
1010@dataclass (frozen = True ) 
11- class  ModuleReplacement :
11+ class  Substitution :
1212    """Class to store key functionality for module replacement""" 
1313
1414    # torch.ops.___ name for replacement function for module 
1515    new_operator : torch ._ops .OpOverload 
1616
17-     # Function taking a containing graph, a submodule, and a 'call_module' node and returning 
18-     # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected 
17+     # Function taking a containing graph, a node, and optionally a submodule (if replacing a module) 
18+     # and returning a replacement node, with type 'call_function', or raising an Error if 
19+     # incompatibility is detected 
1920    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph 
2021    subgraph_insertion_fn : Callable [
21-         [torch .fx .GraphModule , torch .nn . Module ,  torch .fx . Node ], torch .fx .Node 
22+         [torch .fx .GraphModule , torch .fx . Node ,  Optional [ torch .nn . Module ] ], torch .fx .Node 
2223    ]
2324
2425
25- # Dictionary mapping module to ModuleReplacement instance 
26- MODULE_SUBSTITUTION_REGISTRY : Dict [Type [torch .nn .Module ], ModuleReplacement ] =  dict ()
26+ # Dictionary mapping module to Substitution instance 
27+ SUBSTITUTION_REGISTRY : Dict [
28+     Union [Type [torch .nn .Module ], Callable ], Substitution 
29+ ] =  dict ()
2730
2831
29- def  module_substitution (
30-     module_to_replace :  Type [torch .nn .Module ],
32+ def  register_substitution (
33+     module_or_function_to_replace :  Union [ Type [torch .nn .Module ],  Callable ],
3134    new_operator : torch ._ops .OpOverload ,
3235    enabled : bool  =  True ,
3336) ->  Callable [[Any ], Any ]:
3437    """Decorator to register subgraph insertion functions 
3538
3639    Args: 
37-         module_to_replace : nn.Module to replace 
40+         module_or_function_to_replace : nn.Module or node target Callable  to replace 
3841        new_operator: Custom torch operator to replace with 
3942        enabled: Whether the substitution is enabled or disabled 
4043    Returns: 
4144        torch.fx.GraphModule 
4245    """ 
4346
44-     def  register_substitution (subgraph_insertion_fn ):
47+     def  enable_substitution (subgraph_insertion_fn ):
4548        """Function for use if substitution is enabled""" 
46-         module_replacement  =  ModuleReplacement (
49+         replacement  =  Substitution (
4750            new_operator = new_operator , subgraph_insertion_fn = subgraph_insertion_fn 
4851        )
49-         MODULE_SUBSTITUTION_REGISTRY [ module_to_replace ] =  module_replacement 
52+         SUBSTITUTION_REGISTRY [ module_or_function_to_replace ] =  replacement 
5053        return  subgraph_insertion_fn 
5154
5255    def  disable_substitution (subgraph_insertion_fn ):
5356        """Function for use if substitution is disabled""" 
5457        return  subgraph_insertion_fn 
5558
56-     return  register_substitution  if  enabled  else  disable_substitution 
59+     return  enable_substitution  if  enabled  else  disable_substitution 
5760
5861
59- def  pre_aot_module_replacement (gm : torch .fx .GraphModule ):
60-     """Perform module-level  graph replacement  prior to AOT tracing 
62+ def  pre_aot_substitutions (gm : torch .fx .GraphModule ):
63+     """Perform graph substitutions  prior to AOT tracing 
6164
6265    Args: 
63-         gm: FX GraphModule to perform module replacement  on 
66+         gm: FX GraphModule to perform substitution  on 
6467    Returns: 
6568        torch.fx.GraphModule 
6669
@@ -71,48 +74,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
7174
7275    # Iterate over graph nodes, extracting module calls, to check for interceptions 
7376    for  n  in  gm .graph .nodes :
77+         exists_in_registry  =  False 
78+         to_replace  =  None 
79+ 
7480        if  n .op  ==  "call_module" :
75-             # Extract submodule from graph 
81+             # Extract submodule from graph, validate in registry  
7682            submodule  =  gm .get_submodule (n .target )
77- 
78-             # If submodule is a member of the substitution registry, replace it 
79-             if  type (submodule ) in  MODULE_SUBSTITUTION_REGISTRY :
80- 
81-                 try :
82-                     replacement  =  MODULE_SUBSTITUTION_REGISTRY [type (submodule )]
83-                     op , insertion_fn  =  (
84-                         replacement .new_operator ,
85-                         replacement .subgraph_insertion_fn ,
86-                     )
87-                     logger .debug (
88-                         f"Replacing module of type { type (submodule )}   with { op }  " 
83+             to_replace  =  type (submodule )
84+             exists_in_registry  =  to_replace  in  SUBSTITUTION_REGISTRY 
85+         elif  n .op  ==  "call_function" :
86+             # Extract function from graph, validate in registry 
87+             to_replace  =  n .target 
88+             exists_in_registry  =  n .target  in  SUBSTITUTION_REGISTRY 
89+ 
90+         # If submodule/function is a member of the substitution registry, replace it 
91+         if  exists_in_registry :
92+             try :
93+                 replacement  =  SUBSTITUTION_REGISTRY [to_replace ]
94+                 op , insertion_fn  =  (
95+                     replacement .new_operator ,
96+                     replacement .subgraph_insertion_fn ,
97+                 )
98+                 logger .debug (f"Replacing node of type { to_replace }   with { op }  " )
99+ 
100+                 # Insert new node prior to older node 
101+                 with  gm .graph .inserting_before (n ):
102+                     new_node  =  insertion_fn (
103+                         gm , n , submodule  if  n .op  ==  "call_module"  else  None 
89104                    )
90105
91-                     # Insert new node prior to older node 
92-                     with  gm .graph .inserting_before (n ):
93-                         new_node  =  insertion_fn (gm , submodule , n )
94- 
95-                     # If submodule is not a native torch.nn module, it must be manually excluded 
96-                     # from Dynamo tracing 
97-                     if  not  type (submodule ).__module__ .startswith ("torch.nn" ):
98-                         torch ._dynamo .allowed_functions ._allowed_function_ids .add (
99-                             id (type (submodule ))
100-                         )
101- 
102-                     # Replace all original node uses and clean up graph 
103-                     n .replace_all_uses_with (new_node )
104-                     gm .graph .eliminate_dead_code ()
105-                     gm .graph .lint ()
106-                     gm .recompile ()
107- 
108-                 # A module replacement can fail in the event that the specific instance of the submodule cannot 
109-                 # be replaced 
110-                 except  Exception :
111-                     logger .debug (
112-                         f"Encountered error while replacing { type (submodule )}  " ,
113-                         exc_info = True ,
106+                 # If submodule is not a native torch.nn module, it must be manually excluded 
107+                 # from Dynamo tracing 
108+                 if  n .op  ==  "call_module"  and  not  type (submodule ).__module__ .startswith (
109+                     "torch.nn" 
110+                 ):
111+                     torch ._dynamo .allowed_functions ._allowed_function_ids .add (
112+                         id (to_replace )
114113                    )
115-                     continue 
114+ 
115+                 # Replace all original node uses and clean up graph 
116+                 n .replace_all_uses_with (new_node )
117+                 gm .graph .eliminate_dead_code ()
118+                 gm .graph .lint ()
119+                 gm .recompile ()
120+ 
121+             # A replacement can fail in the event that the specific instance of the submodule/function 
122+             # cannot be replaced 
123+             except  Exception :
124+                 logger .debug (
125+                     f"Encountered error while replacing { to_replace }  " ,
126+                     exc_info = True ,
127+                 )
128+                 continue 
116129
117130    # Perform cleanup and recompilation before returning module 
118131    gm .graph .eliminate_dead_code ()
0 commit comments