11import dataclasses as dc
22import logging
3- from typing import Callable , List , Any , Sequence , Type , Set , Optional , Tuple , NamedTuple
3+ from typing import Callable , Any , Sequence
44
55import fx2trt_oss .tracer .acc_tracer .acc_tracer as acc_tracer
66
99import torch
1010import torch .fx as fx
1111import torch .nn as nn
12- from fx2trt_oss . fx . observer import Observer
12+
1313from torch .fx .passes .splitter_base import SplitResult
1414
1515from .fx2trt import (
2121)
2222from .passes .pass_utils import chain_passes , PassFunc
2323from .passes .lower_pass_manager_builder import LowerPassManagerBuilder , LowerPassContext
24- from .passes .remove_duplicate_output_args import (
25- remove_duplicate_output_args ,
26- )
2724from .tools .timing_cache_utils import (
2825 TimingCacheManager ,
2926)
4037
4138Input = Sequence [Any ]
4239
43- # ----------------------------------------------------------------------
44- # OBSERVERS
45- # ----------------------------------------------------------------------
46- # List of observers. We can subscribe to them by calling its `add(callback)`
47- # function from anywhere in code:
48- #
49- # >>> from fx2trt_oss.fx.lower import FUSE_PASSES_POST_OBSERVER
50- # >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input):
51- # >>> # print_module_and_input will be called right after the fuse passes
52- # >>> lower(module, sample_input)
53-
54- # Observer for the model after the fuse passes.
55- FUSE_PASSES_POST_OBSERVER : Observer [
56- Callable [[nn .Module , Input ], None ]
57- ] = Observer ("FUSE_PASSES_POST_OBSERVER" )
58-
59- # Observer for the TRT split submodules before lowering
60- LOWER_SPLIT_PRE_OBSERVER : Observer [
61- Callable [[str , nn .Module , Input ], None ]
62- ] = Observer ("LOWER_SPLIT_PRE_OBSERVER" )
63-
64- # Observer for the TRT split submodules after lowering
65- LOWER_SPLIT_POST_OBSERVER : Observer [
66- Callable [[str , nn .Module , Input ], None ]
67- ] = Observer ("LOWER_SPLIT_POST_OBSERVER" )
68- # ----------------------------------------------------------------------
69-
70-
71- class PassContext (NamedTuple ):
72- input : Input
73- lower_setting : "LowerSetting"
74- module_name : str = ""
75-
7640
7741def lower_to_trt (
7842 module : nn .Module ,
@@ -119,16 +83,6 @@ def lower_to_trt(
11983 return lowerer (module , input )
12084
12185
122- def default_split_function (model : fx .GraphModule , inputs : Input , lower_setting : LowerSetting , min_acc_module_size : int = 10 ) -> SplitResult :
123- splitter_setting = TRTSplitterSetting ()
124- splitter_setting .use_implicit_batch_dim = not lower_setting .explicit_batch_dimension
125- # TODO: avoid hardcode here by introducing another flag in lowering setting.
126- splitter_setting .min_acc_module_size = min_acc_module_size
127- splitter = TRTSplitter (model , inputs , settings = splitter_setting )
128- splitter .node_support_preview ()
129- return splitter .generate_split_results ()
130-
131-
13286@dc .dataclass
13387class LowerTrtInterpreter :
13488 lower_setting : LowerSetting
@@ -194,6 +148,41 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
194148 return interp_result
195149
196150
def default_split_function(
    model: fx.GraphModule,
    inputs: Input,
    lower_setting: LowerSetting,
    min_acc_module_size: int = 10,
) -> SplitResult:
    """Split ``model`` with ``TRTSplitter`` and return its split results.

    Args:
        model: Graph module handed to ``TRTSplitter``.
        inputs: Sample inputs handed to ``TRTSplitter``.
        lower_setting: Only ``explicit_batch_dimension`` is read here; the
            splitter's ``use_implicit_batch_dim`` is set to its negation.
        min_acc_module_size: Value assigned to the splitter setting of the
            same name.

    Returns:
        The value produced by ``TRTSplitter.generate_split_results()``.
    """
    settings = TRTSplitterSetting()
    settings.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
    # TODO: take this threshold from a dedicated flag on the lowering setting
    # instead of relying on the hardcoded default of this function.
    settings.min_acc_module_size = min_acc_module_size

    trt_splitter = TRTSplitter(model, inputs, settings=settings)
    # The support preview is invoked before the split results are generated.
    trt_splitter.node_support_preview()
    return trt_splitter.generate_split_results()
159+
160+
def create_lower_trt_interpreter(lower_setting: LowerSetting) -> LowerTrtInterpreter:
    """Build a ``LowerTrtInterpreter`` from ``lower_setting``.

    Thin factory that delegates to ``LowerTrtInterpreter.create``.
    """
    interpreter = LowerTrtInterpreter.create(lower_setting)
    return interpreter
163+
164+
def default_lower_pass(
    create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter],
) -> PassFunc:
    """Create a module transformation pass that lowers a module into a `TRTModule`.

    Args:
        create_trt_interpreter: Factory invoked with the `LowerSetting` each
            time the returned pass runs; its result is called with the
            module, the input and the module name.

    Returns:
        The `lower_pass` closure defined below.
    """

    def lower_pass(mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str) -> nn.Module:
        """
        Run the interpreter on `mod` and wrap the resulting engine in a
        `TRTModule`
        """
        trt_interpreter = create_trt_interpreter(lower_setting)
        result: TRTInterpreterResult = trt_interpreter(mod, input, module_name)
        # Engine and I/O names come from the interpreter result; the CUDA
        # graph batch size is read from the lowering setting.
        return TRTModule(
            engine=result.engine,
            input_names=result.input_names,
            output_names=result.output_names,
            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
        )

    return lower_pass
184+
185+
197186@dc .dataclass (frozen = True )
198187class Lowerer :
199188 """Lowers a module using fx2trt.
@@ -214,8 +203,6 @@ class Lowerer:
214203 4. Wraps the executable TRT engine into `TRTModule`, which is an `nn.Module`.
215204 5. The converted submodule is then set back onto the top-level module
216205
217- # TODO: @kefeilu: also incorporates a validator to do inference (and optionally)
218- # result comparison along the way.
219206
220207 Attributes:
221208 trace_func: fx trace function for TRT conversion.
@@ -227,9 +214,10 @@ class Lowerer:
227214
228215 trace_func : Callable [[nn .Module , Input ], fx .GraphModule ]
229216 split_func : Callable [[fx .GraphModule , Input , LowerSetting ], SplitResult ]
230- lower_pass : PassFunc
217+ lower_func : PassFunc
231218 lower_setting : LowerSetting
232219
220+
233221 @classmethod
234222 def create (
235223 cls ,
@@ -244,7 +232,7 @@ def create(
244232 ast_rewriter_allow_list = lower_setting .ast_rewriter_allow_list ,
245233 leaf_module_list = lower_setting .leaf_module_list ), # type: ignore[arg-type]
246234 split_func = default_split_function ,
247- lower_pass = create_lower_pass (create_lower_trt_interpreter ),
235+ lower_func = default_lower_pass (create_lower_trt_interpreter ),
248236 lower_setting = lower_setting ,
249237 )
250238
@@ -264,53 +252,11 @@ def __call__(
264252 pm = LowerPassManagerBuilder (LowerPassContext (
265253 input = inputs ,
266254 lower_setting = self .lower_setting ,
267- trace_func = self .trace_func )).build_lower_pipeline ()
268- traced_mod = pm (module )
269- FUSE_PASSES_POST_OBSERVER .observe (traced_mod , inputs )
270-
271- # Run split.
272- split_result = self .split_func (traced_mod , inputs , self .lower_setting ) # type: ignore[misc,operator]
273-
274- # TesnorRT doesn't like duplicate outputs. Run this pass to eliminate such case.
275- remove_duplicate_output_args (split_result .split_module , split_result .submodule_inputs .keys ())
276-
277- for submod_name , submod_inputs in split_result .submodule_inputs .items ():
278- submod = getattr (split_result .split_module , submod_name )
279-
280- LOWER_SPLIT_PRE_OBSERVER .observe (submod_name , submod , submod_inputs )
281-
282- # We only lower acc submodules.
283- if not submod_name .startswith (split_result .non_acc_submodule_prefix ):
284- lowered_module , ctx = self .lower_pass (
285- submod ,
286- PassContext (submod_inputs , self .lower_setting , submod_name ),
287- )
288- setattr (split_result .split_module , submod_name , lowered_module )
289- LOWER_SPLIT_POST_OBSERVER .observe (submod_name , lowered_module , ctx .input )
290-
291- return split_result .split_module
292-
293-
294- def create_lower_pass (
295- create_trt_interpreter : Callable [[PassContext ], LowerTrtInterpreter ],
296- ) -> PassFunc :
297-
298- def lower_pass (mod : nn .Module , ctx : PassContext ) -> Tuple [nn .Module , PassContext ]:
299- """
300- Create a module transformation pass which lowers an `fx.GraphModule` into a
301- `TRTModule`
302- """
303- interpreter = create_trt_interpreter (ctx )
304- interp_res : TRTInterpreterResult = interpreter (mod , ctx .input , ctx .module_name )
305- trt_module = TRTModule (
306- engine = interp_res .engine ,
307- input_names = interp_res .input_names ,
308- output_names = interp_res .output_names ,
309- cuda_graph_batch_size = ctx .lower_setting .cuda_graph_batch_size ,
310- )
311- return trt_module , ctx
312- return lower_pass
313-
314-
315- def create_lower_trt_interpreter (ctx : PassContext ) -> LowerTrtInterpreter :
316- return LowerTrtInterpreter .create (ctx .lower_setting )
255+ trace_func = self .trace_func ,
256+ split_func = self .split_func ,
257+ lower_func = self .lower_func ,
258+ ),
259+ ).build_lower_pipeline ()
260+ lower_result = pm (module )
261+
262+ return lower_result
0 commit comments