
Commit 0364a68

See title; consolidate so that we only compile a class once
Signed-off-by: Lucas Kabela <[email protected]>
1 parent db7f9a8 commit 0364a68
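
The idea in a nutshell: the first instance of a decorated model class runs compilation and stashes the result on the class, and every later instance reuses it. Below is a minimal, self-contained sketch of that pattern; TinyModel and compiled_forward are illustrative names, and plain torch.compile stands in for vLLM's TorchCompileWrapperWithCustomDispatcher machinery.

import torch


class TinyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        if not hasattr(self.__class__, "compiled_forward"):
            # First instance of this class: compile forward once and cache
            # the compiled callable on the class itself.
            self.__class__.compiled_forward = torch.compile(self.__class__.forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


a, b = TinyModel(), TinyModel()
# Both instances share the single compiled callable; no second compilation.
assert a.__class__.compiled_forward is b.__class__.compiled_forward
out = a.__class__.compiled_forward(a, torch.randn(2, 8))

As the NOTE in the decorators.py diff below acknowledges, this assumes all instances of a class can safely share one compiled artifact, which may not hold if their parameters differ.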

File tree

5 files changed: +32 -17 lines changed


vllm/compilation/backends.py

Lines changed: 0 additions & 1 deletion

@@ -474,7 +474,6 @@ def configure_post_pass(self):
         inductor_config[PASS_KEY] = self.post_grad_pass_manager
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
-
         vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors

vllm/compilation/decorators.py

Lines changed: 18 additions & 5 deletions

@@ -192,7 +192,6 @@ def _support_torch_compile(
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
     cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
-
     old_init = cls.__init__
 
     setattr(cls, IGNORE_COMPILE_KEY, False)
@@ -222,20 +221,33 @@ def __init__(self, **kwargs):
             return
 
         compilation_counter.num_models_seen += 1
-        TorchCompileWrapperWithCustomDispatcher.__init__(
-            self, compilation_level=vllm_config.compilation_config.level)
+        if not hasattr(self.__class__, "compiled_callable"):
+            print(f"init self for {self.__class__}")
+            # only compile the same model once
+            # NOTE: this is probably not right, since parameters can change
+            # and cause us to fall over
+            TorchCompileWrapperWithCustomDispatcher.__init__(
+                self, compilation_level=vllm_config.compilation_config.level)
+            self.__class__.compiled_callable = self.compiled_callable
+        else:
+            print("init reusing the callable")
+            TorchCompileWrapperWithCustomDispatcher.__init__(
+                self,
+                self.__class__.compiled_callable,
+                compilation_level=vllm_config.compilation_config.level)
 
     cls.__init__ = __init__
 
     def __call__(self, *args, **kwargs):
+        print(f"Call to {self.__class__} forward")
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
         # the first compilation needs to have dynamic shapes marked
-        if len(self.compiled_codes) < 1:
+        if len(self.__class__.compiled_codes) < 1:
             sig = inspect.signature(self.__class__.forward)
             bound_args = sig.bind(self, *args, **kwargs)
             bound_args.apply_defaults()
@@ -269,7 +281,8 @@ def __call__(self, *args, **kwargs):
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,
         # with the overhead of guard evaluation and recompilation.
-        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
+        if len(self.__class__.compiled_codes
+               ) < 1 or not self.use_custom_dispatcher:
             # it seems Dynamo reuse the compilation across instances,
             # while we need to make sure the compiled code is not reused.
             # we need to control all the compilation of the model.

vllm/compilation/wrapper.py

Lines changed: 5 additions & 3 deletions

@@ -53,7 +53,7 @@ def __init__(self,
 
         self.compiled_callable = compiled_callable
         self.original_code_object = self.__class__.forward.__code__
-        self.compiled_codes: list[CodeType] = []
+        self.__class__.compiled_codes = []  # type: ignore[attr-defined]
         torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
 
         # read the env var to determine whether to use the custom dispatcher
@@ -91,7 +91,8 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
         if frame.f_locals["self"] is not self:
             return
 
-        self.compiled_codes.append(new_code)
+        self.__class__.compiled_codes.append(  # type: ignore[attr-defined]
+            new_code)
         debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path
         if isinstance(debug_dump_dir, str) and debug_dump_dir != "":
            rank = self.vllm_config.parallel_config.rank
@@ -131,6 +132,7 @@ def dispatch_to_code(self, index: int):
 
         See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
         """  # noqa
-        self.__class__.forward.__code__ = self.compiled_codes[index]
+        self.__class__.forward.__code__ = self.__class__.compiled_codes[  # type: ignore[attr-defined]
+            index]
         yield
         self.__class__.forward.__code__ = self.original_code_object
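
The dispatch_to_code hunk builds on the fact that assigning to a function's __code__ redirects every caller of that method, i.e. the swap is inherently class-wide. A minimal CPython-only sketch of that mechanism (Greeter and _compiled are made-up names; no torch involved):

from contextlib import contextmanager


class Greeter:

    def forward(self) -> str:
        return "original"


def _compiled(self) -> str:
    return "compiled"


_original_code = Greeter.forward.__code__


@contextmanager
def dispatch_to(new_code):
    # Swap the code object in, restore it on exit -- the same shape as
    # dispatch_to_code in the diff above.
    Greeter.forward.__code__ = new_code
    try:
        yield
    finally:
        Greeter.forward.__code__ = _original_code


g = Greeter()
with dispatch_to(_compiled.__code__):
    assert g.forward() == "compiled"  # affects every instance of Greeter
assert g.forward() == "original"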

vllm/model_executor/models/gemma3n.py

Lines changed: 1 addition & 0 deletions

@@ -1048,6 +1048,7 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 
 class Gemma3nForCausalLM(nn.Module):
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 8 additions & 8 deletions

@@ -518,10 +518,10 @@ def forward(
         return x
 
 
-@set_model_tag("Qwen2_5_VisionPatchEmbed")
-@support_torch_compile(dynamic_arg_dims={
-    "x": 0,
-})
+# @set_model_tag("Qwen2_5_VisionPatchEmbed")
+# @support_torch_compile(dynamic_arg_dims={
+#     "x": 0,
+# })
 class Qwen2_5_VisionPatchEmbed(nn.Module):
 
     def __init__(
@@ -551,10 +551,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-@set_model_tag("Qwen2_5_VisionPatchMerger")
-@support_torch_compile(dynamic_arg_dims={
-    "x": 0,
-})
+# @set_model_tag("Qwen2_5_VisionPatchMerger")
+# @support_torch_compile(dynamic_arg_dims={
+#     "x": 0,
+# })
 class Qwen2_5_VisionPatchMerger(nn.Module):
 
     def __init__(
