@@ -192,7 +192,6 @@ def _support_torch_compile(
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
     cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
-
     old_init = cls.__init__
 
     setattr(cls, IGNORE_COMPILE_KEY, False)
@@ -222,20 +221,33 @@ def __init__(self, **kwargs):
             return
 
         compilation_counter.num_models_seen += 1
-        TorchCompileWrapperWithCustomDispatcher.__init__(
-            self, compilation_level=vllm_config.compilation_config.level)
+        if not hasattr(self.__class__, "compiled_callable"):
+            print(f"init self for {self.__class__}")
+            # only compile the same model once
+            # NOTE: this is probably not right, since parameters can change
+            # and cause us to fall over
+            TorchCompileWrapperWithCustomDispatcher.__init__(
+                self, compilation_level=vllm_config.compilation_config.level)
+            self.__class__.compiled_callable = self.compiled_callable
+        else:
+            print("init reusing the callable")
+            TorchCompileWrapperWithCustomDispatcher.__init__(
+                self,
+                self.__class__.compiled_callable,
+                compilation_level=vllm_config.compilation_config.level)
 
     cls.__init__ = __init__
 
     def __call__(self, *args, **kwargs):
+        print(f"Call to {self.__class__} forward")
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
         # the first compilation needs to have dynamic shapes marked
-        if len(self.compiled_codes) < 1:
+        if len(self.__class__.compiled_codes) < 1:
             sig = inspect.signature(self.__class__.forward)
             bound_args = sig.bind(self, *args, **kwargs)
             bound_args.apply_defaults()
@@ -269,7 +281,8 @@ def __call__(self, *args, **kwargs):
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,
         # with the overhead of guard evaluation and recompilation.
-        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
+        if len(self.__class__.compiled_codes
+               ) < 1 or not self.use_custom_dispatcher:
             # it seems Dynamo reuse the compilation across instances,
             # while we need to make sure the compiled code is not reused.
             # we need to control all the compilation of the model.
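For reference, a minimal standalone sketch of the class-level caching pattern this patch introduces: the first instance compiles the forward and publishes the callable on the class, later instances reuse it. This is not the vLLM `TorchCompileWrapperWithCustomDispatcher`; the `ToyModel` class and its `compiled_callable` attribute are illustrative assumptions, and plain `torch.compile` stands in for vLLM's compilation levels and custom dispatcher.

```python
# Illustrative sketch only (not vLLM code): cache the compiled callable on the
# class, as the patch above does, so instances of the same model class share it.
import torch
import torch.nn as nn


class ToyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8, 8)
        # The first instance compiles the unbound forward and stores it on the
        # class; later instances reuse the same callable object.
        # NOTE: as the patch comments warn, this assumes reuse is safe even
        # though parameters differ between instances (torch.compile still
        # guards and may recompile internally).
        if not hasattr(self.__class__, "compiled_callable"):
            self.__class__.compiled_callable = torch.compile(
                self.__class__.forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def __call__(self, *args, **kwargs):
        # Dispatch through the shared compiled callable, passing `self`
        # explicitly because the unbound class-level forward was compiled.
        return self.__class__.compiled_callable(self, *args, **kwargs)


if __name__ == "__main__":
    a, b = ToyModel(), ToyModel()
    x = torch.randn(2, 8)
    print(a(x).shape, b(x).shape)  # both instances go through one callable
```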