
Commit fea0469

Shirong Wu authored and Wei Wei committed
Apply pass manager to lower (#55)
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/55

Apply pass manager to lower flow

Reviewed By: khabinov

Differential Revision: D35518483

fbshipit-source-id: 48bc9c364cd006cc5a2c1b04d667987827f0a4d4
1 parent c4d4c7e commit fea0469
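The gist of the change: the hand-written sequence in `Lowerer.__call__` (trace, fuse, split, lower) is replaced by a pipeline built with `torch.fx`'s `PassManager`, with side-effect-only steps (the observers, `remove_duplicate_output_args`) wrapped in `inplace_wrapper`. Below is a minimal sketch of that composition idea, assuming only `PassManager.build_from_passlist` and `inplace_wrapper` as used in this diff; the stage functions are illustrative stand-ins, not the repository's actual passes.

from torch.fx.passes.pass_manager import PassManager, inplace_wrapper

def stage_a(x):                 # stand-in for e.g. the const-fold / fuse stage
    return x + 1

def stage_b(x):                 # stand-in for e.g. the split stage
    return x * 2

observed = []

def record(x):                  # side-effect callback, e.g. an observer notification
    observed.append(x)

pipeline = PassManager.build_from_passlist([
    stage_a,
    stage_b,
    inplace_wrapper(record),    # runs the callback, then returns its input unchanged
])

assert pipeline(3) == 8         # stages applied in order: (3 + 1) * 2
assert observed == [8]          # the wrapped callback saw the intermediate result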

File tree

fx/lower.py
fx/passes/lower_pass_manager_builder.py

2 files changed, +117 -104 lines


fx/lower.py

Lines changed: 48 additions & 102 deletions
@@ -1,6 +1,6 @@
 import dataclasses as dc
 import logging
-from typing import Callable, List, Any, Sequence, Type, Set, Optional, Tuple, NamedTuple
+from typing import Callable, Any, Sequence

 import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer

@@ -9,7 +9,7 @@
 import torch
 import torch.fx as fx
 import torch.nn as nn
-from fx2trt_oss.fx.observer import Observer
+
 from torch.fx.passes.splitter_base import SplitResult

 from .fx2trt import (
@@ -21,9 +21,6 @@
 )
 from .passes.pass_utils import chain_passes, PassFunc
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder, LowerPassContext
-from .passes.remove_duplicate_output_args import (
-    remove_duplicate_output_args,
-)
 from .tools.timing_cache_utils import (
     TimingCacheManager,
 )
@@ -40,39 +37,6 @@

 Input = Sequence[Any]

-# ----------------------------------------------------------------------
-# OBSERVERS
-# ----------------------------------------------------------------------
-# List of observers. We can subscribe to them by calling its `add(callback)`
-# function from anywhere in code:
-#
-# >>> from fx2trt_oss.fx.lower import FUSE_PASSES_POST_OBSERVER
-# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input):
-# >>>     # print_module_and_input will be called right after the fuse passes
-# >>>     lower(module, sample_input)
-
-# Observer for the model after the fuse passes.
-FUSE_PASSES_POST_OBSERVER: Observer[
-    Callable[[nn.Module, Input], None]
-] = Observer("FUSE_PASSES_POST_OBSERVER")
-
-# Observer for the TRT split submodules before lowering
-LOWER_SPLIT_PRE_OBSERVER: Observer[
-    Callable[[str, nn.Module, Input], None]
-] = Observer("LOWER_SPLIT_PRE_OBSERVER")
-
-# Observer for the TRT split submodules after lowering
-LOWER_SPLIT_POST_OBSERVER: Observer[
-    Callable[[str, nn.Module, Input], None]
-] = Observer("LOWER_SPLIT_POST_OBSERVER")
-# ----------------------------------------------------------------------
-
-
-class PassContext(NamedTuple):
-    input: Input
-    lower_setting: "LowerSetting"
-    module_name: str = ""
-

 def lower_to_trt(
     module: nn.Module,
@@ -119,16 +83,6 @@ def lower_to_trt(
     return lowerer(module, input)


-def default_split_function(model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting, min_acc_module_size: int = 10) -> SplitResult:
-    splitter_setting = TRTSplitterSetting()
-    splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
-    # TODO: avoid hardcode here by introducing another flag in lowering setting.
-    splitter_setting.min_acc_module_size = min_acc_module_size
-    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
-    splitter.node_support_preview()
-    return splitter.generate_split_results()
-
-
 @dc.dataclass
 class LowerTrtInterpreter:
     lower_setting: LowerSetting
@@ -194,6 +148,41 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
         return interp_result


+def default_split_function(model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting, min_acc_module_size: int = 10) -> SplitResult:
+    splitter_setting = TRTSplitterSetting()
+    splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
+    # TODO: avoid hardcode here by introducing another flag in lowering setting.
+    splitter_setting.min_acc_module_size = min_acc_module_size
+    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
+    splitter.node_support_preview()
+    return splitter.generate_split_results()
+
+
+def create_lower_trt_interpreter(lower_setting: LowerSetting) -> LowerTrtInterpreter:
+    return LowerTrtInterpreter.create(lower_setting)
+
+
+def default_lower_pass(
+    create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter],
+) -> PassFunc:
+
+    def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str) -> nn.Module:
+        """
+        Create a module transformation pass which lowers an `fx.GraphModule` into a
+        `TRTModule`
+        """
+        interpreter = create_trt_interpreter(lower_setting)
+        interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
+        trt_module = TRTModule(
+            engine=interp_res.engine,
+            input_names=interp_res.input_names,
+            output_names=interp_res.output_names,
+            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
+        )
+        return trt_module
+    return lower_pass
+
+
 @dc.dataclass(frozen=True)
 class Lowerer:
     """Lowers a module using fx2trt.
@@ -214,8 +203,6 @@ class Lowerer:
     4. Wraps the executable TRT engine into `TRTModule`, which is an `nn.Module`.
     5. The converted submodule is then set back onto the top-level module

-    # TODO: @kefeilu: also incorporates a validator to do inference (and optionally)
-    # result comparison along the way.

     Attributes:
         trace_func: fx trace function for TRT conversion.
@@ -227,9 +214,10 @@ class Lowerer:

     trace_func: Callable[[nn.Module, Input], fx.GraphModule]
     split_func: Callable[[fx.GraphModule, Input, LowerSetting], SplitResult]
-    lower_pass: PassFunc
+    lower_func: PassFunc
     lower_setting: LowerSetting

+
     @classmethod
     def create(
         cls,
@@ -244,7 +232,7 @@ def create(
                 ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list,
                 leaf_module_list=lower_setting.leaf_module_list),  # type: ignore[arg-type]
             split_func=default_split_function,
-            lower_pass=create_lower_pass(create_lower_trt_interpreter),
+            lower_func=default_lower_pass(create_lower_trt_interpreter),
             lower_setting=lower_setting,
         )

@@ -264,53 +252,11 @@ def __call__(
         pm = LowerPassManagerBuilder(LowerPassContext(
             input=inputs,
             lower_setting=self.lower_setting,
-            trace_func=self.trace_func)).build_lower_pipeline()
-        traced_mod = pm(module)
-        FUSE_PASSES_POST_OBSERVER.observe(traced_mod, inputs)
-
-        # Run split.
-        split_result = self.split_func(traced_mod, inputs, self.lower_setting)  # type: ignore[misc,operator]
-
-        # TesnorRT doesn't like duplicate outputs. Run this pass to eliminate such case.
-        remove_duplicate_output_args(split_result.split_module, split_result.submodule_inputs.keys())
-
-        for submod_name, submod_inputs in split_result.submodule_inputs.items():
-            submod = getattr(split_result.split_module, submod_name)
-
-            LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)
-
-            # We only lower acc submodules.
-            if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-                lowered_module, ctx = self.lower_pass(
-                    submod,
-                    PassContext(submod_inputs, self.lower_setting, submod_name),
-                )
-                setattr(split_result.split_module, submod_name, lowered_module)
-                LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, ctx.input)
-
-        return split_result.split_module
-
-
-def create_lower_pass(
-    create_trt_interpreter: Callable[[PassContext], LowerTrtInterpreter],
-) -> PassFunc:
-
-    def lower_pass(mod: nn.Module, ctx: PassContext) -> Tuple[nn.Module, PassContext]:
-        """
-        Create a module transformation pass which lowers an `fx.GraphModule` into a
-        `TRTModule`
-        """
-        interpreter = create_trt_interpreter(ctx)
-        interp_res: TRTInterpreterResult = interpreter(mod, ctx.input, ctx.module_name)
-        trt_module = TRTModule(
-            engine=interp_res.engine,
-            input_names=interp_res.input_names,
-            output_names=interp_res.output_names,
-            cuda_graph_batch_size=ctx.lower_setting.cuda_graph_batch_size,
-        )
-        return trt_module, ctx
-    return lower_pass
-
-
-def create_lower_trt_interpreter(ctx: PassContext) -> LowerTrtInterpreter:
-    return LowerTrtInterpreter.create(ctx.lower_setting)
+            trace_func=self.trace_func,
+            split_func=self.split_func,
+            lower_func=self.lower_func,
+            ),
+        ).build_lower_pipeline()
+        lower_result = pm(module)
+
+        return lower_result
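With the split and lower stages folded into the pipeline, `Lowerer.__call__` reduces to building the `LowerPassManagerBuilder` pipeline and running it. A hedged usage sketch of the entry point after this change; the `lower_setting=` keyword and the usable `LowerSetting()` defaults are assumptions, check `fx2trt_oss.fx.lower_setting` and `lower_to_trt` for the supported knobs.

import torch
import torch.nn as nn

from fx2trt_oss.fx.lower import Lowerer
from fx2trt_oss.fx.lower_setting import LowerSetting

# Toy eager module and sample inputs; TensorRT lowering expects CUDA tensors.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).cuda().eval()
inputs = [torch.randn(4, 8, device="cuda")]

# Assumed construction; LowerSetting defaults (e.g. explicit_batch_dimension)
# may need tuning for a real model.
lowerer = Lowerer.create(lower_setting=LowerSetting())

# Runs the PassManager pipeline: const fold -> fuse -> split -> per-submodule TRT lowering.
lowered = lowerer(model, inputs)
print(type(lowered), lowered(*inputs).shape)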

fx/passes/lower_pass_manager_builder.py

Lines changed: 69 additions & 2 deletions
@@ -1,14 +1,45 @@
 from typing import Callable, Any, Sequence, NamedTuple
-from torch.fx.passes.pass_manager import PassManager
+from torch.fx.passes.pass_manager import PassManager, inplace_wrapper
 from .lower_basic_pass import run_const_fold
 from fx2trt_oss.fx.lower_setting import LowerSetting
-from functools import wraps
+from functools import partial, wraps
 import torch
 from torch.fx.passes.shape_prop import ShapeProp
+from fx2trt_oss.fx.passes.remove_duplicate_output_args import remove_duplicate_output_args
+from torch.fx.passes.splitter_base import SplitResult
+from torch import nn
+from fx2trt_oss.fx.observer import Observer

 Input = Sequence[Any]


+# ----------------------------------------------------------------------
+# OBSERVERS
+# ----------------------------------------------------------------------
+# List of observers. We can subscribe to them by calling its `add(callback)`
+# function from anywhere in code:
+#
+# >>> from fx2trt_oss.fx.lower import FUSE_PASSES_POST_OBSERVER
+# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input):
+# >>>     # print_module_and_input will be called right after the fuse passes
+# >>>     lower(module, sample_input)
+
+# Observer for the model after the fuse passes.
+FUSE_PASSES_POST_OBSERVER: Observer[
+    Callable[[nn.Module, Input], None]
+] = Observer("FUSE_PASSES_POST_OBSERVER")
+
+# Observer for the TRT split submodules before lowering
+LOWER_SPLIT_PRE_OBSERVER: Observer[
+    Callable[[str, nn.Module, Input], None]
+] = Observer("LOWER_SPLIT_PRE_OBSERVER")
+
+# Observer for the TRT split submodules after lowering
+LOWER_SPLIT_POST_OBSERVER: Observer[
+    Callable[[str, nn.Module, Input], None]
+] = Observer("LOWER_SPLIT_POST_OBSERVER")
+# ----------------------------------------------------------------------
+
 class LowerPassContext(NamedTuple):
     """
     Args:
@@ -21,6 +52,8 @@ class LowerPassContext(NamedTuple):
     input: Input
     lower_setting: "LowerSetting"
     trace_func: Callable
+    split_func: Callable
+    lower_func: Callable

 def wrapper(fn: Callable, input) -> Callable:
     @wraps(fn)
@@ -52,14 +85,48 @@ def graph_optimization_pass(self) -> PassManager:
         ]
         for p in self._build_context.lower_setting.customized_fuse_pass:
             passes.append(wrapper(p, self._build_context.input))
+        for p in self._build_context.lower_setting.lower_basic_fuse_pass:
+            passes.append(wrapper(p, self._build_context.input))
+        passes.append(inplace_wrapper(partial(FUSE_PASSES_POST_OBSERVER.observe, self._build_context.input)))
+
+        return PassManager.build_from_passlist(passes)
+
+
+    def _split_pass(self) -> PassManager:
+        passes = [partial(self._build_context.split_func, inputs=self._build_context.input, lower_setting=self._build_context.lower_setting)]
+        passes.append(inplace_wrapper(
+            lambda split_result: remove_duplicate_output_args(
+                split_result.split_module,
+                split_result.submodule_inputs.keys()
+            )
+        ))
         return PassManager.build_from_passlist(passes)


+    def _lower_pass(self) -> PassManager:
+        def lower_func(split_result: SplitResult) -> nn.Module:
+            for submod_name, submod_inputs in split_result.submodule_inputs.items():
+                submod = getattr(split_result.split_module, submod_name)
+
+                LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)
+
+                # Only acc submodules will be lowered.
+                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    lowered_module = self._build_context.lower_func(submod, submod_inputs, self._build_context.lower_setting, submod_name)
+                    setattr(split_result.split_module, submod_name, lowered_module)
+                    LOWER_SPLIT_POST_OBSERVER.observe(submod_name, lowered_module, submod_inputs)
+
+            return split_result.split_module
+        return PassManager.build_from_passlist([lower_func])
+
+
     def build_lower_pipeline(self) -> PassManager:
         passes = []

         passes.append(self._const_fold_pass())
         passes.append(self.graph_optimization_pass())
+        passes.append(self._split_pass())
+        passes.append(self._lower_pass())

         pm = PassManager.build_from_passlist(passes)
         return pm
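The observers that used to live in fx/lower.py now sit next to the builder, and the split and lower stages fire them from inside the pipeline. A hedged sketch of subscribing to them; `Observer.add()` returning a context manager follows the docstring carried over above, and `lowerer`, `model`, `inputs` are the hypothetical objects from the earlier usage sketch.

from fx2trt_oss.fx.passes.lower_pass_manager_builder import (
    LOWER_SPLIT_PRE_OBSERVER,
    LOWER_SPLIT_POST_OBSERVER,
)

def before_lowering(name, submod, submod_inputs):
    # Fired for each split submodule right before it is handed to the TRT interpreter.
    print(f"[pre ] {name}: {type(submod).__name__}")

def after_lowering(name, trt_mod, submod_inputs):
    # Fired after an acc submodule has been swapped for its lowered TRTModule.
    print(f"[post] {name}: {type(trt_mod).__name__}")

with LOWER_SPLIT_PRE_OBSERVER.add(before_lowering), LOWER_SPLIT_POST_OBSERVER.add(after_lowering):
    lowered = lowerer(model, inputs)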
