11import logging
2- from typing import Callable , List , Sequence , Tuple
2+ from typing import List , Sequence
33
44import torch
55from torch_tensorrt .dynamo .lowering .passes .pass_utils import (
66 clean_up_graph_after_modifications ,
7+ get_metadata ,
8+ set_metadata ,
79)
810
911logger = logging .getLogger (__name__ )
@@ -13,27 +15,25 @@ def view_to_reshape(
1315 gm : torch .fx .GraphModule , sample_inputs : Sequence [torch .Tensor ]
1416) -> torch .fx .GraphModule :
1517 """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
16- orig , replacement = view_replacement ()
17-
18- if torch .fx .subgraph_rewriter .replace_pattern (gm , orig , replacement ):
19- gm = clean_up_graph_after_modifications (gm )
20- logger .debug (f"Graph after replacing view with reshape:\n { gm .graph } " )
21-
22- return gm
23-
24-
25- def view_replacement () -> Tuple [
26- torch .fx .GraphModule ,
27- Callable [[torch .Tensor , List [torch .SymInt ]], torch .Tensor ],
28- ]:
29- """Constructs the original and replacement functions for view"""
18+ orig_op = torch .ops .aten .view .default
19+ replacement_op = torch .ops .aten .reshape .default
3020
3121 # Original graph
3222 def orig (input : torch .Tensor , shape : List [torch .SymInt ]) -> torch .Tensor :
33- return torch . ops . aten . view . default (input , shape )
23+ return orig_op (input , shape )
3424
3525 # Replacement graph
3626 def replacement (input : torch .Tensor , shape : List [torch .SymInt ]) -> torch .Tensor :
37- return torch . ops . aten . reshape . default (input , shape )
27+ return replacement_op (input , shape )
3828
39- return orig , replacement
29+ # Store metadata of the orig_op
30+ metadata = get_metadata (gm , orig_op )
31+
32+ if torch .fx .subgraph_rewriter .replace_pattern (gm , orig , replacement ):
33+ gm = clean_up_graph_after_modifications (gm )
34+ logger .debug (f"Graph after replacing view with reshape:\n { gm .graph } " )
35+
36+ # Copy the orig_op's metadata to the replacement op
37+ set_metadata (gm , replacement_op , metadata )
38+
39+ return gm
0 commit comments