@@ -49,6 +49,9 @@ def compile(
49
49
cuda_graph_batch_size = - 1 ,
50
50
is_aten = False ,
51
51
use_experimental_fx_rt = False ,
52
+ max_aux_streams = None ,
53
+ version_compatible = False ,
54
+ optimization_level = None ,
52
55
num_avg_timing_iters = 1 ,
53
56
torch_executed_ops = [],
54
57
torch_executed_modules = [],
@@ -68,14 +71,12 @@ def compile(
68
71
save_timing_cache: Update timing cache with current timing cache data if set to True.
69
72
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
70
73
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
74
+ max_aux_streams: max number of aux stream to use
75
+ version_compatible: enable version compatible feature
76
+ optimization_level: builder optimization level
71
77
Returns:
72
78
A torch.nn.Module lowered by TensorRT.
73
79
"""
74
- if use_experimental_fx_rt and not explicit_batch_dimension :
75
- raise ValueError (
76
- "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
77
- )
78
-
79
80
logger .warn (
80
81
"For ir=fx_ts_compat backend only the "
81
82
+ "following arguments are supported: "
@@ -123,6 +124,9 @@ def compile(
123
124
cuda_graph_batch_size = cuda_graph_batch_size ,
124
125
is_aten = is_aten ,
125
126
use_experimental_rt = use_experimental_fx_rt ,
127
+ max_aux_streams = max_aux_streams ,
128
+ version_compatible = version_compatible ,
129
+ optimization_level = optimization_level ,
126
130
)
127
131
lowerer = Lowerer .create (lower_setting = lower_setting )
128
132
return lowerer (module , inputs )
@@ -162,8 +166,6 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
162
166
interpreter = TRTInterpreter (
163
167
mod ,
164
168
input_specs = self .lower_setting .input_specs ,
165
- explicit_batch_dimension = self .lower_setting .explicit_batch_dimension ,
166
- explicit_precision = self .lower_setting .explicit_precision ,
167
169
logger_level = trt .Logger .VERBOSE
168
170
if self .lower_setting .debug
169
171
else trt .Logger .WARNING ,
@@ -198,7 +200,7 @@ def default_split_function(
198
200
model : fx .GraphModule , inputs : Input , lower_setting : LowerSetting
199
201
) -> SplitResult :
200
202
splitter_setting = TRTSplitterSetting ()
201
- splitter_setting .use_implicit_batch_dim = not lower_setting . explicit_batch_dimension
203
+ splitter_setting .use_implicit_batch_dim = False
202
204
splitter_setting .min_block_size = lower_setting .min_block_size
203
205
splitter_setting .use_experimental_rt = lower_setting .use_experimental_rt
204
206
splitter = TRTSplitter (model , inputs , settings = splitter_setting )
0 commit comments