From 678c291c3708549329559be5a21c13377ad58ced Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Sat, 16 Nov 2024 21:22:36 +0000 Subject: [PATCH 01/25] Patch multi_modal_placeholders to RequestOutput * confirm that `offline_inference_vision_language.py` and `benchmark_throughput.py` runs * FIXME: the placeholders in output, however, is empty - will fix in next commit Signed-off-by: Linkun Chen --- vllm/outputs.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index badf50d0602d..205a479e8355 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -5,6 +5,7 @@ from typing import Union from vllm.lora.request import LoRARequest +from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) @@ -95,6 +96,7 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: Optional[List[int]], + multi_modal_placeholders: MultiModalPlaceholderDict, prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -107,6 +109,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.mutli_modal_placeholders = multi_modal_placeholders self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -141,6 +144,7 @@ def new( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, + multi_modal_placeholders=MultiModalPlaceholderDict(), prompt_logprobs=None, # TODO outputs=[completion_output], finished=finished, @@ -154,8 +158,7 @@ def from_seq_group( finished = seq_group.is_finished() if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[ - seq_group.request_id] + group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id] if finished: group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group(seq_group) @@ -198,8 +201,8 @@ def from_seq_group( # num_cached_tokens should be the same for all the sequences num_cached_tokens = None for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return( - text_buffer_length, delta) + output_text = seq.get_output_text_to_return(text_buffer_length, + delta) output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, @@ -276,7 +279,8 @@ def from_seq_group( seq_group.set_finished_time(finished_time) init_args = (seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished, seq_group.metrics, + seq_group.multi_modal_placeholders, prompt_logprobs, + outputs, finished, seq_group.metrics, seq_group.lora_request, encoder_prompt, encoder_prompt_token_ids, num_cached_tokens) @@ -293,6 +297,7 @@ def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"multi_modal_placeholders={self.mutli_modal_placeholders}, " f"encoder_prompt={self.encoder_prompt!r}, " f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " From a1cdcb32801a478ac2240546ddc896017bd54318 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Sun, 17 Nov 2024 23:47:28 +0000 Subject: [PATCH 02/25] pipe multi_modal_placeholders from intput to final output * add test for pixtral * fix a typo Signed-off-by: Linkun Chen --- 
.../vision_language/test_pixtral.py | 82 ++++++++++++++++++- vllm/model_executor/models/pixtral.py | 16 +++- vllm/outputs.py | 11 +-- 3 files changed, 102 insertions(+), 7 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index d8a98a0f84d3..bbae8138d44b 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -6,15 +6,20 @@ import uuid from dataclasses import asdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from transformers import AutoProcessor import pytest +from mistral_common.multimodal import download_image from mistral_common.protocol.instruct.messages import ImageURLChunk from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk -from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt +from vllm import (EngineArgs, LLMEngine, SamplingParams, TokensPrompt, + TextPrompt, RequestOutput) +from vllm.logger import init_logger from vllm.multimodal import MultiModalDataBuiltins +from vllm.multimodal.inputs import PlaceholderRange from vllm.sequence import Logprob, SampleLogprobs from ....utils import VLLM_PATH, large_gpu_test @@ -23,6 +28,8 @@ if TYPE_CHECKING: from _typeshed import StrPath +logger = init_logger(__name__) + MODELS = ["mistralai/Pixtral-12B-2409"] IMG_URLS = [ "https://picsum.photos/id/237/400/300", @@ -49,6 +56,20 @@ def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]: }] +def _create_msg_format_hf(urls: List[str]) -> List[Dict[str, Any]]: + return [{ + "role": + "user", + "content": [{ + "type": "text", + "content": PROMPT, + }, *({ + "type": "image", + "image": download_image(url) + } for url in urls)], + }] + + def _create_engine_inputs(urls: List[str]) -> TokensPrompt: msg = _create_msg_format(urls) @@ -70,6 +91,23 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt: return engine_inputs +def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt: + msg = _create_msg_format_hf(urls) + + tokenizer = AutoProcessor.from_pretrained("mistral-community/pixtral-12b") + prompt = tokenizer.apply_chat_template(msg) + + images = [] + for chunk in msg[0]["content"]: + if chunk["type"] == "image": + images.append(chunk["image"]) + + mm_data = MultiModalDataBuiltins(image=images) + engine_inputs = TextPrompt(prompt=prompt, multi_modal_data=mm_data) + + return engine_inputs + + MSGS = [ _create_msg_format(IMG_URLS[:1]), _create_msg_format(IMG_URLS[:2]), @@ -191,3 +229,45 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None: outputs_1_lst=logprobs, name_0="h100_ref", name_1="output") + + +@large_gpu_test(min_gb=24) +@pytest.mark.parametrize( + "prompt,expected_ranges", + [(_create_engine_inputs_hf(IMG_URLS[:1]), [{ + "offset": 10, + "length": 494 + }]), + (_create_engine_inputs_hf(IMG_URLS[1:4]), [{ + "offset": 10, + "length": 266 + }, { + "offset": 276, + "length": 1056 + }, { + "offset": 1332, + "length": 418 + }])]) +def test_multi_modal_placeholders( + vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None: + with vllm_runner( + "mistral-community/pixtral-12b", + max_model_len=8192, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + ) as vllm_model: + outputs = vllm_model.model.generate(prompt) + + assert len(outputs) == 1, f"{len(outputs)=}" + output: RequestOutput 
= outputs[0] + assert hasattr(output, + "multi_modal_placeholders"), f"{output.__dict__=}" + assert "image" in output.multi_modal_placeholders, \ + f"{output.multi_modal_placeholders.keys()=}" + image_placeholder_ranges: list[ + PlaceholderRange] = output.multi_modal_placeholders["image"] + assert len(image_placeholder_ranges) == len( + expected_ranges), f"{image_placeholder_ranges=}" + for real_range, expected_range in zip(image_placeholder_ranges, + expected_ranges): + assert real_range == expected_range, \ + f"{real_range=} {expected_range=}" diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index a3e30ea2dd29..790a260d43ec 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -30,6 +30,7 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) from vllm.sequence import IntermediateTensors, SequenceData @@ -773,15 +774,28 @@ def input_processor_for_pixtral_hf( replace_tokens[-1] = image_end_id replace_tokens_list.append(replace_tokens) + reverse_offsets: List[int] = [] # Backward iteration for replacement without affecting known indices for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices), reversed(replace_tokens_list)): + reverse_offsets.append( + len(new_token_ids) - placeholder_idx + len(replace_tokens)) new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens + placeholder_ranges: List[PlaceholderRange] = [] + for reverse_offset, replace_tokens in zip(reversed(reverse_offsets), + replace_tokens_list): + placeholder_ranges.append( + PlaceholderRange( + offset=len(new_token_ids) - reverse_offset, + length=len(replace_tokens), + )) + # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) class PixtralHFMLP(nn.Module): diff --git a/vllm/outputs.py b/vllm/outputs.py index 205a479e8355..a02f1e97b5b5 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -109,7 +109,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids - self.mutli_modal_placeholders = multi_modal_placeholders + self.multi_modal_placeholders = multi_modal_placeholders self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -158,7 +158,8 @@ def from_seq_group( finished = seq_group.is_finished() if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id] + group: SequenceGroupBase = seq_id_to_seq_group[ + seq_group.request_id] if finished: group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group(seq_group) @@ -201,8 +202,8 @@ def from_seq_group( # num_cached_tokens should be the same for all the sequences num_cached_tokens = None for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return(text_buffer_length, - delta) + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, @@ -297,7 +298,7 @@ def 
__repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"multi_modal_placeholders={self.mutli_modal_placeholders}, " + f"multi_modal_placeholders={self.multi_modal_placeholders}, " f"encoder_prompt={self.encoder_prompt!r}, " f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " From f60964a976b664ce9ddaf286e3a8cf8ba0850524 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 16 Nov 2024 10:45:26 -0800 Subject: [PATCH 03/25] [V1] Add code owners for V1 (#10397) Signed-off-by: Woosuk Kwon Signed-off-by: Linkun Chen --- .github/CODEOWNERS | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index cd721971d01d..3cb91fc0f823 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,13 +3,16 @@ # This lists cover the "core" components of vLLM that require careful review /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/core @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/engine/llm_engine.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/executor/executor_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/model_executor/layers/sampler.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -CMakeLists.txt @tlrmchlsmth @WoosukKwon +/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +CMakeLists.txt @tlrmchlsmth + +# vLLM V1 +/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic # Test ownership /tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo From 578e482b3eef1b816a7a32b943adf481b279018b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 16 Nov 2024 18:02:14 -0800 Subject: [PATCH 04/25] [2/N][torch.compile] make compilation cfg part of vllm cfg (#10383) Signed-off-by: youkaichao Signed-off-by: Linkun Chen --- tests/compile/piecewise/test_simple.py | 8 +- tests/compile/piecewise/test_toy_llama.py | 22 +- tests/compile/test_basic_correctness.py | 2 +- tests/compile/test_full_graph.py | 2 +- tests/compile/test_fusion.py | 2 +- tests/compile/test_wrapper.py | 4 +- tests/compile/utils.py | 2 +- .../model_executor/test_enabled_custom_ops.py | 52 ++--- tests/tpu/test_compilation.py | 2 +- tests/tpu/test_custom_dispatcher.py | 2 +- vllm/compilation/backends.py | 20 +- vllm/compilation/config.py | 159 --------------- vllm/compilation/decorators.py | 10 +- vllm/compilation/fusion.py | 2 +- vllm/compilation/inductor_pass.py | 2 +- vllm/compilation/levels.py | 8 - vllm/compilation/wrapper.py | 11 +- vllm/config.py | 189 ++++++++++++++++++ vllm/envs.py | 13 -- 
vllm/model_executor/custom_op.py | 27 +-- vllm/model_executor/model_loader/loader.py | 7 +- vllm/platforms/interface.py | 20 +- vllm/platforms/tpu.py | 21 +- vllm/plugins/__init__.py | 30 ++- vllm/v1/worker/gpu_model_runner.py | 10 +- vllm/worker/model_runner.py | 7 +- vllm/worker/tpu_model_runner.py | 8 +- 27 files changed, 359 insertions(+), 283 deletions(-) delete mode 100644 vllm/compilation/config.py delete mode 100644 vllm/compilation/levels.py diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index c631850ecded..45f56cbbd4b1 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -11,8 +11,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig +from vllm.plugins import set_current_vllm_config from vllm.utils import direct_register_custom_op global_counter = 0 @@ -82,7 +82,9 @@ def test_simple_piecewise_compile(): os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) - model = SillyModel(vllm_config=VllmConfig(), prefix='') + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = SillyModel(vllm_config=vllm_config, prefix='') inputs = torch.randn(100).cuda() diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index c363a587a818..8032304e9580 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -15,12 +15,10 @@ from torch.library import Library from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig -from vllm.plugins import set_compilation_config +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.plugins import set_compilation_config, set_current_vllm_config from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -272,9 +270,11 @@ def run_model(llama_config, CompilationLevel.NO_COMPILATION) set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda() + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda() B = 16 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() @@ -395,9 +395,11 @@ def benchmark(): else: set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda().to(torch.bfloat16) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda().to(torch.bfloat16) B = 256 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 833589ba5dc9..08747ebc58b7 100644 --- 
a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -3,7 +3,7 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.utils import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index f00334934cb4..4dfdfe21a67d 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,6 +1,6 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index e4d3defafb95..4db79b070fd8 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -3,10 +3,10 @@ from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs -from vllm.compilation.config import CompilationConfig from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.config import CompilationConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 3668c1fab6b8..74f66baaa5ea 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -3,6 +3,7 @@ import torch from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel class MyMod(torch.nn.Module): @@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable) + super().__init__(compiled_callable, + compilation_level=CompilationLevel.DYNAMO_ONCE) def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 222c63a342a4..729f10676888 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.platforms import current_platform TEST_MODELS = [ diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index af267f804ffa..c3219bc50646 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -3,11 +3,13 @@ import pytest +from vllm.config import CompilationConfig, VllmConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.plugins import set_current_vllm_config # Registered subclass for test @@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation): ]) def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], default_on: bool): - os.environ["VLLM_CUSTOM_OPS"] = env os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) + vllm_config = 
VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + assert CustomOp.default_on() == default_on - # Reset default_on (computed once): - CustomOp.default_on.cache_clear() + ops_enabled = [bool(x) for x in ops_enabled] - assert CustomOp.default_on() == default_on + assert RMSNorm(1024).enabled() == ops_enabled[0] + assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] - ops_enabled = [bool(x) for x in ops_enabled] + assert SiluAndMul().enabled() == ops_enabled[1] + assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] - assert RMSNorm(1024).enabled() == ops_enabled[0] - assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + assert GeluAndMul().enabled() == ops_enabled[2] + assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] - assert SiluAndMul().enabled() == ops_enabled[1] - assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + # If registered, subclasses should follow their own name + assert Relu3().enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - assert GeluAndMul().enabled() == ops_enabled[2] - assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + # Unregistered subclass + class SiluAndMul2(SiluAndMul): + pass - # If registered, subclasses should follow their own name - assert Relu3().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - - # Unregistered subclass - class SiluAndMul2(SiluAndMul): - pass - - # Subclasses should not require registration - assert SiluAndMul2().enabled() == SiluAndMul().enabled() + # Subclasses should not require registration + assert SiluAndMul2().enabled() == SiluAndMul().enabled() @pytest.mark.parametrize( "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) def test_enabled_ops_invalid(env: str): - os.environ["VLLM_CUSTOM_OPS"] = env - CustomOp.default_on.cache_clear() - - with pytest.raises(AssertionError): - RMSNorm(1024).enabled() + with pytest.raises(Exception): # noqa + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + RMSNorm(1024).enabled() diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 86d9af88e49e..941abe17a337 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,7 +5,7 @@ import depyf -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel # disable custom dispatcher, let Dynamo takes over # all the control diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 923d0f168080..53b10c06135a 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,6 +1,6 @@ import os -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import compare_two_settings diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5682faa15806..22c613931f08 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -10,13 +10,12 @@ import torch.fx as fx import vllm.envs as envs +from vllm.config import CompilationConfig, CompilationLevel from vllm.logger import init_logger from vllm.utils import combine_fx_passes, weak_ref_tensors -from .config import CompilationConfig from .counter import compilation_counter from .fusion import 
FusionPass -from .levels import CompilationLevel from .reshapes import RedundantReshapesPass logger = init_logger(__name__) @@ -392,7 +391,10 @@ class VllmBackend: sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] - def __init__(self, post_grad_passes: Sequence[Callable] = ()): + def __init__( + self, + compilation_configs: CompilationConfig, + ): global global_graph_pool if global_graph_pool is None: global_graph_pool = torch.cuda.graph_pool_handle() @@ -401,11 +403,13 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()): # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool - self.post_grad_passes = post_grad_passes + self.post_grad_passes = [] self.sym_tensor_indices = [] self.input_buffers = [] + self.compilation_configs = compilation_configs + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -437,10 +441,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: assert not self._called, "VllmBackend can only be called once" self.graph = graph - # config is read now, because only here can + # config is updated now, because only here can # we get the sizes to capture for cudagraph # from compilation context - self.compilation_configs = CompilationConfig.select_and_init_config() + self.compilation_configs.init_during_runtime() self.add_passes_to_config() self.split_gm, self.piecewise_graphs = split_graph( @@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]: return backend_str assert level == CompilationLevel.PIECEWISE - return VllmBackend() + from vllm.plugins import get_current_vllm_config + compilation_config = get_current_vllm_config().compilation_config + return VllmBackend(compilation_config) diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py deleted file mode 100644 index 3e663505c627..000000000000 --- a/vllm/compilation/config.py +++ /dev/null @@ -1,159 +0,0 @@ -import copy -from pathlib import Path -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field, PrivateAttr - -import vllm.envs as envs -from vllm.logger import init_logger - -from .compile_context import get_compile_context - -logger = init_logger(__name__) - - -class CompilationConfig(BaseModel): - """ - Configuration for compilation. - It has two parts: - - CudaGraph capture: - - use_cudagraph: whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses. - Note that this is orthogonal to the cudagraph capture out - side of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. - - None: capture sizes are inferred from compilation context. - - List[int]: capture sizes are specified. - - cudagraph_num_of_warmups: number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs. - - cudagraph_copy_inputs: whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. 
Default is False. - - Inductor compilation: - - use_inductor: whether to use inductor compilation. - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for different sizes specified - in inductor_compile_sizes, using configurations - in inductor_compile_config. - - inductor_compile_sizes: sizes to compile for inductor. - - inductor_specialize_for_cudagraph_no_more_than: an optional integer - to specialize inductor for cudagraph sizes no more than the - specified size. It is useful when we want to specialize inductor - with a subset of cudagraph sizes. - - inductor_compile_config: additional configurations for inductor. - - None: use default configurations. - - inductor_passes: additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses json format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - Custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. - - Why we have different sizes for cudagraph and inductor: - - cudagraph: a cudagraph captured for a specific size can only be used - for the same size. We need to capture all the sizes we want to use. - - inductor: a graph compiled by inductor for a general shape can be used - for different sizes. Inductor can also compile for specific sizes, - where it can have more information to optimize the graph with fully - static shapes. However, we find the general shape compilation is - sufficient for most cases. It might be beneficial to compile for - certain small batchsizes, where inductor is good at optimizing. - """ - use_inductor: bool = True - inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) - inductor_compile_config: Dict = Field(default_factory=dict) - inductor_passes: Dict[str, str] = Field(default_factory=dict) - - use_cudagraph: bool = False - non_cudagraph_ops: List[str] = Field(default_factory=list) - cudagraph_num_of_warmups: int = 0 - cudagraph_capture_sizes: Optional[List[int]] = None - cudagraph_copy_inputs: bool = False - - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True - - # not configurable, computed after init - compile_sizes: List[int] = PrivateAttr - capture_sizes: List[int] = PrivateAttr - - def model_post_init(self, __context: Any) -> None: - for k, v in self.inductor_passes.items(): - if not isinstance(v, str): - assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v - continue - - # resolve function from qualified name - names = v.split(".") - module = ".".join(names[:-1]) - func_name = names[-1] - func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func - - def init_during_runtime(self): - """To complete the initialization of config, - we need to know the compile context, which is only available - during the first run of the model. 
- """ - context = get_compile_context() - context = copy.deepcopy(context) if context is not None else [] - sizes_to_specialize: List[int] = context - if self.cudagraph_capture_sizes is None: - self.capture_sizes = sizes_to_specialize - else: - self.capture_sizes = self.cudagraph_capture_sizes - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - sizes_to_specialize, self.cudagraph_capture_sizes) - if self.inductor_specialize_for_cudagraph_no_more_than is not None: - assert self.inductor_compile_sizes is None, ( - "inductor_compile_sizes should be None when " - "inductor_specialize_for_cudagraph_no_more_than is not None") - self.compile_sizes = [ - x for x in self.capture_sizes - if x <= self.inductor_specialize_for_cudagraph_no_more_than - ] - else: - assert self.inductor_compile_sizes is not None, ( - "inductor_compile_sizes should not be None when " - "inductor_specialize_for_cudagraph_no_more_than is None") - self.compile_sizes = self.inductor_compile_sizes - - @staticmethod - def select_and_init_config() -> "CompilationConfig": - """The order of selecting config is: - 1. Use the config specified in environment variable. - 2. Use the config specified in plugins. - 3. Use the default config. - """ - config_path = envs.VLLM_TORCH_COMPILE_CONFIG - if config_path is not None: - with open(config_path) as json_file: - config = CompilationConfig.model_validate_json( - json_file.read()) - else: - from vllm.plugins import get_compilation_config - predefined_config = get_compilation_config() - config = predefined_config if predefined_config is not None else ( - CompilationConfig()) - - config.init_during_runtime() - return config diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index ca1e96a33c01..4b78491bc5a4 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -3,10 +3,8 @@ import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import supports_dynamo @@ -126,12 +124,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. 
- self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [ + self.do_not_compile = \ + vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() if self.do_not_compile: return - TorchCompileWrapperWithCustomDispatcher.__init__(self) + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) cls.__init__ = __init__ # type: ignore diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index eb43604b1399..e6a3afef85e1 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,8 +6,8 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm.compilation.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass +from vllm.config import CompilationConfig from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index b23351fa1975..8082a08b4001 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -2,7 +2,7 @@ import torch -from vllm.compilation.config import CompilationConfig +from vllm.config import CompilationConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py deleted file mode 100644 index 19a3a2b52687..000000000000 --- a/vllm/compilation/levels.py +++ /dev/null @@ -1,8 +0,0 @@ -# constants for the levels of the compilation process - - -class CompilationLevel: - NO_COMPILATION = 0 - DYNAMO_AS_IS = 1 - DYNAMO_ONCE = 2 - PIECEWISE = 3 diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 7366ed4d16b0..2a1aecc11ce2 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -8,8 +8,7 @@ import torch import vllm.envs as envs - -from .levels import CompilationLevel +from vllm.config import CompilationLevel class TorchCompileWrapperWithCustomDispatcher: @@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Optional[Callable] = None): + def __init__(self, + compiled_callable: Optional[Callable] = None, + compilation_level: int = 0): if compiled_callable is None: # default compilation settings @@ -38,7 +39,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): backend = get_torch_compile_backend() if backend is None: from vllm.compilation.backends import select_default_backend - backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + backend = select_default_backend(compilation_level) compiled_callable = torch.compile( self.forward, @@ -54,7 +55,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = \ - envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE + compilation_level >= CompilationLevel.DYNAMO_ONCE def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. 
diff --git a/vllm/config.py b/vllm/config.py index 64b2f75e092d..7e37edbe594b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,10 +3,12 @@ import json import warnings from dataclasses import dataclass, field, replace +from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) import torch +from pydantic import BaseModel, Field, PrivateAttr from transformers import PretrainedConfig import vllm.envs as envs @@ -2052,6 +2054,185 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +class CompilationConfig(BaseModel): + """ + Configuration for compilation. + It has three parts: + - Top-level Compilation control: + - level: the level of compilation. + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation. + - custom_ops: fine-grained control over which custom ops to enable/disable. + Use 'all' to enable all, 'none' to disable all. + Also specify a list of custom op names to enable (prefixed with a '+'), + or disable (prefixed with a '-'). + Examples: + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + By default, all custom ops are enabled when running without Inductor + and disabled when running with Inductor (compile_level >= Inductor). + - CudaGraph capture: + - use_cudagraph: whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses. + Note that this is orthogonal to the cudagraph capture out + side of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future. + - cudagraph_capture_sizes: sizes to capture cudagraph. + - None: capture sizes are inferred from compilation context. + - List[int]: capture sizes are specified. + - cudagraph_num_of_warmups: number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs. + - cudagraph_copy_inputs: whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False. + - Inductor compilation: + - use_inductor: whether to use inductor compilation. + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for different sizes specified + in inductor_compile_sizes, using configurations + in inductor_compile_config. + - inductor_compile_sizes: sizes to compile for inductor. + - inductor_specialize_for_cudagraph_no_more_than: an optional integer + to specialize inductor for cudagraph sizes no more than the + specified size. It is useful when we want to specialize inductor + with a subset of cudagraph sizes. + - inductor_compile_config: additional configurations for inductor. + - None: use default configurations. + - inductor_passes: additional passes for inductor. 
It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses json format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + - custom inductor passes: + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graph. Default is . + - enable_fusion: whether to enable the custom fusion pass. + TODO better pass enabling system. + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ # noqa + level: int = 0 + custom_ops: List[str] = Field(default_factory=list) + + use_inductor: bool = True + inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None + inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) + inductor_compile_config: Dict = Field(default_factory=dict) + inductor_passes: Dict[str, str] = Field(default_factory=dict) + + use_cudagraph: bool = False + non_cudagraph_ops: List[str] = Field(default_factory=list) + cudagraph_num_of_warmups: int = 0 + cudagraph_capture_sizes: Optional[List[int]] = None + cudagraph_copy_inputs: bool = False + + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + + # not configurable, computed after init + compile_sizes: List[int] = PrivateAttr + capture_sizes: List[int] = PrivateAttr + + def model_post_init(self, __context: Any) -> None: + self.level = envs.VLLM_TORCH_COMPILE_LEVEL + + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be a function or a qualified name") + self.inductor_compile_config[k] = v + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func + + def init_during_runtime(self): + """To complete the initialization of config, + we need to know the compile context, which is only available + during the first run of the model. 
+ """ + from vllm.compilation.compile_context import get_compile_context + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context + if self.cudagraph_capture_sizes is None: + self.capture_sizes = sizes_to_specialize + else: + self.capture_sizes = self.cudagraph_capture_sizes + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + sizes_to_specialize, self.cudagraph_capture_sizes) + if self.inductor_specialize_for_cudagraph_no_more_than is not None: + assert self.inductor_compile_sizes is None, ( + "inductor_compile_sizes should be None when " + "inductor_specialize_for_cudagraph_no_more_than is not None") + self.compile_sizes = [ + x for x in self.capture_sizes + if x <= self.inductor_specialize_for_cudagraph_no_more_than + ] + else: + assert self.inductor_compile_sizes is not None, ( + "inductor_compile_sizes should not be None when " + "inductor_specialize_for_cudagraph_no_more_than is None") + self.compile_sizes = self.inductor_compile_sizes + + @staticmethod + def select_and_init_config() -> "CompilationConfig": + """The order of selecting config is: + 1. Use the config specified in environment variable. + 2. Use the config specified in plugins. + 3. Use the default config. + """ + config_path = envs.VLLM_TORCH_COMPILE_CONFIG + if config_path is not None: + with open(config_path) as json_file: + config = CompilationConfig.model_validate_json( + json_file.read()) + else: + from vllm.plugins import get_compilation_config + predefined_config = get_compilation_config() + config = predefined_config if predefined_config is not None else ( + CompilationConfig()) + + return config + + @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. This @@ -2073,6 +2254,8 @@ class VllmConfig: observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None + compilation_config: CompilationConfig = field(default=None, + init=True) # type: ignore @staticmethod def _get_quantization_config( @@ -2133,6 +2316,12 @@ def __post_init__(self): self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + if self.compilation_config is None: + self.compilation_config = CompilationConfig.select_and_init_config( + ) + + current_platform.check_and_update_config(self) + def __str__(self): return ("model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " diff --git a/vllm/envs.py b/vllm/envs.py index f320e35971f9..716e835a555f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -69,7 +69,6 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None - VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False @@ -217,18 +216,6 @@ def get_default_config_root(): "VLLM_TORCH_COMPILE_CONFIG": lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None), - # Fine-grained control over which custom ops to enable/disable. - # Use 'all' to enable all, 'none' to disable all. - # Also specify a list of custom op names to enable (prefixed with a '+'), - # or disable (prefixed with a '-'). 
- # Examples: - # - 'all,-op1' to enable all except op1 - # - 'none,+op1,+op2' to enable only op1 and op2 - # By default, all custom ops are enabled when running without Inductor - # and disabled when running with Inductor (compile_level >= Inductor). - "VLLM_CUSTOM_OPS": - lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), - # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 24d75f4df4e0..6ae7d7cf6964 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,12 +1,10 @@ -from functools import lru_cache from typing import Dict, Type import torch.nn as nn -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.plugins import get_current_vllm_config from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -87,6 +85,8 @@ def dispatch_forward(self): @classmethod def enabled(cls) -> bool: # if no name, then it was not registered + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops if not hasattr(cls, "name"): print_warning_once( f"Custom op {cls.__name__} was not registered, " @@ -94,22 +94,25 @@ def enabled(cls) -> bool: f"It will be enabled/disabled based on the global settings.") return CustomOp.default_on() - enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS - disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS + enabled = f"+{cls.name}" in custom_ops + disabled = f"-{cls.name}" in custom_ops assert not (enabled and disabled), f"Cannot enable and disable {cls.name}" return (CustomOp.default_on() or enabled) and not disabled - # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE - # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence. @staticmethod - @lru_cache def default_on() -> bool: - count_none = envs.VLLM_CUSTOM_OPS.count("none") - count_all = envs.VLLM_CUSTOM_OPS.count("all") - assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" - return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \ + """ + On by default if level < CompilationLevel.PIECEWISE + Specifying 'all' or 'none' in custom_op takes precedence. + """ + from vllm.config import CompilationLevel + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops + count_none = custom_ops.count("none") + count_all = custom_ops.count("all") + return compilation_config.level < CompilationLevel.PIECEWISE and \ not count_none > 0 or count_all > 0 # Dictionary of all custom ops (classes, indexed by registered name). 
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 140b61fe6d56..0f8b81c3ef40 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -42,6 +42,7 @@ safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.plugins import set_current_vllm_config from vllm.utils import is_pin_memory_available @@ -97,7 +98,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class - return model_class(vllm_config=vllm_config, prefix=prefix) + with set_current_vllm_config(vllm_config): + return model_class(vllm_config=vllm_config, prefix=prefix) msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. " @@ -121,7 +123,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config - return model_class(**kwargs) + with set_current_vllm_config(vllm_config): + return model_class(**kwargs) class BaseModelLoader(ABC): diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 81d8bdae2383..970c0d1be617 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,10 +1,15 @@ import enum import random -from typing import NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union import numpy as np import torch +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class PlatformEnum(enum.Enum): CUDA = enum.auto() @@ -129,6 +134,19 @@ def seed_everything(cls, seed: int) -> None: np.random.seed(seed) torch.manual_seed(seed) + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + """ + Check and update the configuration for the current platform. + + It can raise an exception if the configuration is not compatible with + the current platform, or it can update the configuration to make it + compatible with the current platform. + + The config is passed by reference, so it can be modified in place. + """ + pass + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8d0ce47df404..c2e22bfc09f2 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,18 +1,16 @@ import os +from typing import TYPE_CHECKING import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_torch_compile_backend from .interface import Platform, PlatformEnum -if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) - -assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\ - "TPU does not support Inductor." 
+if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None set_torch_compile_backend("openxla") @@ -31,3 +29,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def inference_mode(cls): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + from vllm.config import CompilationLevel + compilation_config = vllm_config.compilation_config + if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + compilation_config.level = CompilationLevel.DYNAMO_ONCE + assert compilation_config.level < CompilationLevel.PIECEWISE,\ + "TPU does not support Inductor." diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 7b1bbb14c530..c20b9ec891d5 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,11 +1,11 @@ import logging +from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Optional, Union import vllm.envs as envs if TYPE_CHECKING: - from vllm.compilation.config import CompilationConfig - from vllm.config import VllmConfig + from vllm.config import CompilationConfig, VllmConfig else: CompilationConfig = None VllmConfig = None @@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]): def get_compilation_config() -> Optional[CompilationConfig]: return _compilation_config + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig): + """ + Temporarily set the current VLLM config. + Used during model initialization. + We save the current VLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the VLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + try: + _current_vllm_config = vllm_config + yield + finally: + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + assert _current_vllm_config is not None, "Current VLLM config is not set." + return _current_vllm_config diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eebd1de96537..d60f93a44f6d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,4 +1,3 @@ -import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -8,11 +7,8 @@ import torch.distributed import torch.nn as nn -from vllm import envs from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -99,7 +95,7 @@ def __init__( pin_memory=self.pin_memory, ) - self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. @@ -517,9 +513,9 @@ def load_model(self) -> None: # CUDA graphs do not work properly with the custom CUDA kernels. # FIXME(woosuk): Disable inductor to reduce the compilation time # and avoid any potential issues with the inductor. 
- os.environ["VLLM_CUSTOM_OPS"] = "none" set_compilation_config( CompilationConfig( + custom_ops=["none"], use_cudagraph=True, non_cudagraph_ops=["vllm.unified_v1_flash_attention"], use_inductor=True, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 22ee3f9f863e..fd89f9544556 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -19,8 +19,7 @@ from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture @@ -1142,8 +1141,8 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ - and supports_dynamo(): + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): from vllm.plugins import get_torch_compile_backend backend = get_torch_compile_backend() or "eager" self.model = torch.compile( diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a72118613732..d7a641857a61 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -140,7 +140,7 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - self.model = ModelWrapper(model) + self.model = ModelWrapper(model, self.vllm_config) def _dummy_run( self, @@ -669,13 +669,15 @@ def execute_model( class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, vllm_config: VllmConfig): self.model = model compiled_callable = torch.compile(self.forward, backend="openxla", fullgraph=True, dynamic=False) - super().__init__(compiled_callable) + super().__init__( + compiled_callable, + compilation_level=vllm_config.compilation_config.level) def __call__(self, *args, is_prompt: bool, **kwargs): if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: From 7fa97cf556eb753b30f5d83b85e6f2b2d3619f62 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sat, 16 Nov 2024 21:18:46 -0800 Subject: [PATCH 05/25] [V1] Refactor model executable interface for all text-only language models (#10374) Signed-off-by: Roger Wang Signed-off-by: Linkun Chen --- vllm/model_executor/models/arctic.py | 16 ++++++++++++++-- vllm/model_executor/models/baichuan.py | 16 ++++++++++++++-- vllm/model_executor/models/bloom.py | 17 ++++++++++++++--- vllm/model_executor/models/commandr.py | 16 ++++++++++++++-- vllm/model_executor/models/dbrx.py | 16 ++++++++++++++-- vllm/model_executor/models/deepseek.py | 16 ++++++++++++++-- vllm/model_executor/models/deepseek_v2.py | 16 ++++++++++++++-- vllm/model_executor/models/eagle.py | 13 ++++++++++--- vllm/model_executor/models/exaone.py | 7 ++++++- vllm/model_executor/models/falcon.py | 16 ++++++++++++++-- vllm/model_executor/models/gemma.py | 7 ++++++- vllm/model_executor/models/gemma2.py | 12 ++++++++++-- vllm/model_executor/models/gpt2.py | 7 +++++-- vllm/model_executor/models/gpt_bigcode.py | 17 +++++++++++++---- vllm/model_executor/models/gpt_j.py | 
16 ++++++++++++++-- vllm/model_executor/models/gpt_neox.py | 16 ++++++++++++++-- vllm/model_executor/models/granite.py | 7 ++++++- vllm/model_executor/models/granitemoe.py | 16 ++++++++++++++-- vllm/model_executor/models/internlm2.py | 9 +++++++-- vllm/model_executor/models/jais.py | 14 ++++++++++++-- vllm/model_executor/models/jamba.py | 16 ++++++++++++++-- vllm/model_executor/models/mamba.py | 15 +++++++++++++-- vllm/model_executor/models/minicpm.py | 7 ++++++- vllm/model_executor/models/mixtral.py | 16 ++++++++++++++-- vllm/model_executor/models/mixtral_quant.py | 16 ++++++++++++++-- vllm/model_executor/models/mpt.py | 16 ++++++++++++++-- vllm/model_executor/models/nemotron.py | 7 ++++++- vllm/model_executor/models/olmo.py | 19 +++++++++++++------ vllm/model_executor/models/olmoe.py | 16 ++++++++++++++-- vllm/model_executor/models/orion.py | 16 ++++++++++++++-- vllm/model_executor/models/persimmon.py | 8 +++++++- vllm/model_executor/models/phi.py | 16 ++++++++++++++-- vllm/model_executor/models/phi3_small.py | 19 +++++++++++-------- vllm/model_executor/models/phimoe.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen2.py | 2 +- vllm/model_executor/models/qwen2_cls.py | 7 ++++++- vllm/model_executor/models/qwen2_moe.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen2_rm.py | 7 ++++++- vllm/model_executor/models/solar.py | 4 +++- vllm/model_executor/models/stablelm.py | 16 ++++++++++++++-- vllm/model_executor/models/starcoder2.py | 16 ++++++++++++++-- vllm/model_executor/models/xverse.py | 16 ++++++++++++++-- 43 files changed, 483 insertions(+), 90 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 9ee2a2cc09a2..d52418ee0f6f 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -389,6 +389,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -396,9 +399,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -439,6 +446,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -446,9 +456,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff 
--git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index aabbd31192a4..01ce7c42cd39 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -284,6 +284,9 @@ def __init__( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -291,9 +294,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -363,6 +370,9 @@ def __init__( self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -370,9 +380,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 84adf574af5e..cf2eee817276 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -251,6 +251,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) + def forward( self, input_ids: torch.Tensor, @@ -258,10 +261,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(hidden_states) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -301,6 +307,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -308,9 +317,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, 
IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index cd5c1d684471..fbb09a64cde9 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -280,6 +280,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -287,9 +290,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -354,6 +361,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + @torch.no_grad() def forward( self, @@ -362,9 +372,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index fff8710f6b47..3952ff31e5ce 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -321,6 +321,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.d_model)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -328,9 +331,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] @@ -376,6 +383,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -383,9 +393,11 @@ def forward( kv_caches: 
List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index a9bf1440c4d6..36dfea5a6565 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -353,6 +353,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -360,9 +363,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -401,6 +408,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -408,9 +418,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4fb1eed15a2e..1e32fe60c7a5 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -445,6 +445,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -452,9 +455,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -495,6 +502,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( 
self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -502,9 +512,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 85c51e840458..f138d1363026 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -78,6 +78,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def sampler(self): return self.model.sampler + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -86,11 +89,14 @@ def forward( attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - tok_embeds = self.model.model.embed_tokens(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + inputs_embeds = self.fc( - torch.cat([tok_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) inputs_embeds[positions == 0] = 0 # masking inputs at position=0 @@ -100,7 +106,8 @@ def forward( positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors) + intermediate_tensors=intermediate_tensors, + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index cd3e7da657e0..52dd603ca558 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -479,6 +479,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -486,9 +489,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index b3dbf063ac29..e97abe949ccd 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -367,6 +367,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return 
self.word_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -374,9 +377,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.word_embeddings(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -432,6 +439,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.LongTensor, @@ -439,9 +449,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 55baba809e58..ace13664c6ea 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -390,6 +390,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -397,9 +400,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index eeb3fd98a7ea..a60b4e73a76d 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -272,6 +272,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], @@ -285,7 +288,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.normalizer residual = None else: @@ -414,6 +417,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) 
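Every model file in this patch receives the same two-part change, which is easier to read once in condensed form than across forty diffs: the inner transformer exposes `get_input_embeddings`, the outer `*ForCausalLM` wrapper delegates to it, and `forward` accepts an optional precomputed `inputs_embeds` that takes precedence over the token-embedding lookup on the first pipeline rank. A self-contained sketch of that shape, with `Foo*` placeholder names and the kv-cache, attention-metadata, and pipeline-parallel arguments elided for brevity (non-first ranks read `intermediate_tensors` instead):

    from typing import Optional

    import torch
    import torch.nn as nn


    class FooModel(nn.Module):
        """Placeholder inner model; real models also take kv_caches etc."""

        def __init__(self, vocab_size: int = 32, hidden_size: int = 8) -> None:
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
            return self.embed_tokens(input_ids)

        def forward(
            self,
            input_ids: torch.Tensor,
            inputs_embeds: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            # Callers such as multimodal wrappers can hand in already-merged
            # embeddings; otherwise fall back to the token lookup.
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            return hidden_states


    class FooForCausalLM(nn.Module):
        """Placeholder wrapper; it simply delegates both new entry points."""

        def __init__(self) -> None:
            super().__init__()
            self.model = FooModel()

        def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
            return self.model.get_input_embeddings(input_ids)

        def forward(
            self,
            input_ids: torch.Tensor,
            inputs_embeds: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            return self.model(input_ids, inputs_embeds=inputs_embeds)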
+ def forward( self, input_ids: torch.Tensor, @@ -421,9 +427,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index cc85693f9952..fa0fdad28d16 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -209,6 +209,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -220,7 +223,7 @@ def forward( ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) + inputs_embeds = self.get_input_embeddings(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds else: @@ -262,7 +265,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.transformer.make_empty_intermediate_tensors) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.transformer.wte(input_ids) + return self.transformer.get_input_embeddings(input_ids) def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index ab25c66c3a88..b2fc79d0d36d 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,11 +228,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = inputs_embeds + self.wpe(position_ids) else: hidden_states = intermediate_tensors["hidden_states"] @@ -285,6 +289,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -292,9 +299,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + 
attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index a83d03480dde..cec3fd12a67d 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -201,6 +201,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -208,9 +211,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -250,6 +257,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -257,9 +267,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 794b141bfa4a..11f286d6bcba 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -214,6 +214,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_in(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -221,9 +224,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_in(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -262,6 +269,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.gpt_neox.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.gpt_neox.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -269,9 +279,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, 
intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index d1e6e31f2b8d..cb2583e69d88 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -409,6 +409,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -416,9 +419,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 2ed115c56af4..f437dd521a7d 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -277,6 +277,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -284,9 +287,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.embedding_multiplier residual = None else: @@ -366,6 +373,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -373,9 +383,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 21fa6983063b..19bfe16e4d5f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -290,7 +290,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.tok_embeddings(input_ids) + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ 
-335,6 +335,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -342,9 +345,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 65800c44e5a9..ee49ffb3cd87 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -250,6 +250,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -257,9 +260,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) if self.wpe is not None: position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds @@ -311,6 +316,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -318,9 +326,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 88fb8d5cf555..5612dd688638 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -292,6 +292,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -299,8 +302,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None for i in 
range(len(self.layers)): layer = self.layers[i] @@ -381,12 +388,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: max_batch_size = (_get_graph_batch_size( @@ -409,7 +420,8 @@ def forward(self, mamba_cache_tensors[1], state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params) + attn_metadata, mamba_cache_params, + inputs_embeds) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 55c575e22a0f..ac0d265a961f 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -106,15 +106,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embeddings(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None for i in range(len(self.layers)): @@ -168,12 +175,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.backbone.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: max_batch_size = (_get_graph_batch_size( @@ -194,7 +205,7 @@ def forward(self, state_indices_tensor) hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_params) + mamba_cache_params, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 2db953329fd9..6b67266c5336 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -504,6 +504,9 @@ def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = MiniCPMModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -511,9 +514,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - 
attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 3eb2f60fd4fc..eebf5bab5a28 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -281,6 +281,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -288,9 +291,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -363,6 +370,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -370,9 +380,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 95cfb6f54dc1..af2e9586988d 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -318,6 +318,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -325,9 +328,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -368,6 +375,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -375,9 +385,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, 
intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index e15c0fe8db06..3c74ef2448ab 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -237,6 +237,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.d_model)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -244,9 +247,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -283,6 +290,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -290,9 +300,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index e09d7088a69c..eb45beae7d21 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -440,6 +440,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -447,9 +450,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 3467ae589649..98d4e1ec320a 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -248,6 +248,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -255,17 +258,16 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. """ if get_pp_group().is_first_rank: - # Get embeddings of input. - # shape: (batch_size, seq_len, d_model) - inputs_embeds = self.embed_tokens(input_ids) - - # embed positions - hidden_states = inputs_embeds + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -315,6 +317,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -322,6 +327,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, @@ -329,6 +335,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, ) return hidden_states diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 3d31919edd86..f4eebab8c98d 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -269,6 +269,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -276,9 +279,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -326,6 +333,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -333,9 +343,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, 
intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 38821c828834..39d659c49cbc 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -237,6 +237,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "hidden_states", ], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -244,9 +247,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -286,6 +293,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -293,9 +303,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 2e34a7cc3087..62c509153a11 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -235,6 +235,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -248,7 +251,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -282,6 +285,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 262f6996fc37..a2ab0d74c48d 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], 
config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,9 +228,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -303,6 +310,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -310,9 +320,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 8a5fb6d303e6..2139cec44180 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -324,11 +324,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) def forward( self, @@ -337,9 +334,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) if (self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0): hidden_states = hidden_states * self.mup_embedding_multiplier @@ -397,8 +398,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.dummy_token_indices = None - def get_input_embeddings(self): - return self.model.embed_tokens + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def set_input_embeddings(self, value): self.model.embed_tokens = value @@ -433,6 +434,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: output_hidden_states = self.model( input_ids=input_ids, @@ -440,6 +442,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, 
intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, ) output_hidden_states = output_hidden_states return output_hidden_states diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 6d71a8949111..b7e70f8fa2c6 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -465,6 +465,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -472,9 +475,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -560,6 +567,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -567,9 +577,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 3d26ede722dd..447632cefcd9 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -578,6 +578,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config) if hasattr( config, "visual") else None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -586,6 +589,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], pixel_values: Optional[QwenImageInputs], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: img_pos = None # If pixel / visual embeddings are provided, this is a visual model @@ -606,6 +610,10 @@ def forward( ) if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.wte(input_ids) # Merge the image embeddings into the hidden states if actually have # visual features and the corresponding image tokens @@ -915,6 +923,9 @@ def _get_image_input_type( ) return None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -922,7 +933,8 @@ def forward( kv_caches: List[torch.Tensor], 
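Taken together, the hunks in this patch repeat the same two-level plumbing for every decoder: the inner `*Model` accepts an optional `inputs_embeds` tensor and prefers it over its own embedding lookup on the first pipeline-parallel rank, while the outer `*ForCausalLM` only exposes `get_input_embeddings` and forwards the new keyword. A minimal, self-contained sketch of that shape — `ToyModel`/`ToyForCausalLM` and the sizes are illustrative stand-ins, not the vLLM signatures:

from typing import Optional

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Inner model: owns the embedding table."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(2)])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Prefer caller-supplied embeddings (e.g. text embeddings already
        # merged with image features); otherwise embed the token ids.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


class ToyForCausalLM(nn.Module):
    """Outer wrapper: delegates, mirroring the `*ForCausalLM` hunks."""

    def __init__(self) -> None:
        super().__init__()
        self.model = ToyModel()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model(input_ids, inputs_embeds=inputs_embeds)


ids = torch.tensor([[1, 2, 3]])
lm = ToyForCausalLM()
embeds = lm.get_input_embeddings(ids)
# Both call paths produce the same hidden states.
assert torch.allclose(lm(ids), lm(None, inputs_embeds=embeds))

Accepting either token ids or precomputed embeddings through one forward path is what lets multimodal wrappers merge image features into the embeddings before the decoder ever sees them.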
attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - pixel_values: Optional[torch.Tensor] = None + pixel_values: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if intermediate_tensors is not None: input_ids = None @@ -932,7 +944,7 @@ def forward( hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - pixel_values) + pixel_values, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 431e397e1e10..8f10df808c21 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -309,7 +309,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 120403e94868..07eb330620a4 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -72,6 +72,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): normalize=False, softmax=True) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -79,9 +82,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) logits, _ = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 51c0cd5664fd..249d94b5d95e 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -344,6 +344,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -351,9 +354,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -395,6 +402,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -402,9 +412,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: 
Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 55843d832534..6db467af334f 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -85,6 +85,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -92,9 +95,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) logits, _ = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 4f03ca501fb6..affb2c975ce4 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -456,9 +456,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 1125f9e9f961..99acce596602 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,9 +228,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -265,6 +272,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -272,9 +282,11 @@ def forward( kv_caches: 
List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index ce7a7957f52c..0ef940acebb9 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -221,6 +221,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -228,9 +231,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -273,6 +280,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -280,9 +290,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 153527da20d7..51172d8782a7 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -252,6 +252,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -259,9 +262,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -335,6 +342,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( 
self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -342,9 +352,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( From 629f5120b70cb1426d18c5b8eb19f779702ef367 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Sun, 17 Nov 2024 00:58:22 -0600 Subject: [PATCH 06/25] [CI/Build] Fix IDC hpu [Device not found] issue (#10384) Signed-off-by: Chendi Xue Signed-off-by: Linkun Chen --- .buildkite/run-hpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-hpu-test.sh b/.buildkite/run-hpu-test.sh index 4505dc7a9373..fa4f74fca7a1 100644 --- a/.buildkite/run-hpu-test.sh +++ b/.buildkite/run-hpu-test.sh @@ -13,4 +13,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --runtime=habana --name=hpu-test --network=host -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py \ No newline at end of file +docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py \ No newline at end of file From 7539ab8d83dfe4357caf1deda4c29201b698a3d5 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 17 Nov 2024 15:12:04 +0800 Subject: [PATCH 07/25] [Bugfix][CPU] Fix CPU embedding runner with tensor parallel (#10394) Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Linkun Chen --- vllm/worker/cpu_embedding_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_embedding_model_runner.py index 7053075bf4d8..d0b8fec48d74 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_embedding_model_runner.py @@ -66,6 +66,10 @@ def execute_model( hidden_states = model_executable(**execute_model_kwargs) + # Only perform pooling in the driver worker. 
+ if not self.is_driver_worker: + return [] + return [ self.model.pooler(hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) From bce660d722520d7e236e6bdd8e5ecbb102900972 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 16 Nov 2024 23:14:23 -0800 Subject: [PATCH 08/25] [platforms] refactor cpu code (#10402) Signed-off-by: youkaichao Signed-off-by: Linkun Chen --- vllm/executor/cpu_executor.py | 68 +---------------------------------- vllm/platforms/cpu.py | 60 +++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 67 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 4ceb5a837dd7..1542a2ae367e 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -2,9 +2,6 @@ from functools import partial from typing import Any, Awaitable, List, Optional, Set, Tuple, Union -import vllm.envs as envs -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) @@ -13,7 +10,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, +from vllm.utils import (get_distributed_init_method, get_open_port, get_vllm_instance_id, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -57,13 +54,6 @@ def _init_executor(self) -> None: os.environ["LOCAL_WORLD_SIZE"] = str( self.parallel_config.tensor_parallel_size) - self.model_config = _verify_and_get_model_config(self.model_config) - self.cache_config = _verify_and_get_cache_config(self.cache_config) - self.scheduler_config = _verify_and_get_scheduler_config( - self.scheduler_config) - self.parallel_config = _verify_and_get_parallel_config( - self.parallel_config) - # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. 
@@ -313,62 +303,6 @@ async def check_health_async(self) -> None: self.check_health() -def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if not config.enforce_eager: - logger.warning( - "CUDA graph is not supported on CPU, fallback to the eager " - "mode.") - config.enforce_eager = True - return config - - -def _verify_and_get_scheduler_config( - config: SchedulerConfig) -> SchedulerConfig: - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if config.chunked_prefill_enabled: - logger.warning("Chunked prefill is not supported on CPU, disable it.") - config.chunked_prefill_enabled = False - - return config - - -def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if config.enable_prefix_caching: - logger.warning("Prefix caching is not supported on CPU, disable it.") - config.enable_prefix_caching = False - - kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE - - if kv_cache_space >= 0: - if kv_cache_space == 0: - config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore - logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " - "for CPU backend is not set, using 4 by default.") - else: - config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore - else: - raise RuntimeError( - "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" - f" {kv_cache_space}, expect a positive integer value.") - - return config - - -def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig: - if (config.distributed_executor_backend is not None - and config.distributed_executor_backend != "mp"): - logger.warning( - "%s is not supported on CPU, fallback to mp distributed executor " - "backend.", config.distributed_executor_backend) - config.distributed_executor_backend = "mp" - return config - - def _driver_method_invoker(driver, method: str, *args, **kwargs): return getattr(driver, method)(*args, **kwargs) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 5243f59203af..42bee31dfb0e 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,8 +1,19 @@ +from typing import TYPE_CHECKING + import psutil import torch +from vllm.logger import init_logger + from .interface import Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + +logger = init_logger(__name__) + class CpuPlatform(Platform): _enum = PlatformEnum.CPU @@ -18,3 +29,52 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def inference_mode(cls): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + import vllm.envs as envs + from vllm.utils import GiB_bytes + model_config = vllm_config.model_config + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if not model_config.enforce_eager: + logger.warning( + "CUDA graph is not supported on CPU, fallback to the eager " + "mode.") + model_config.enforce_eager = True + + cache_config = vllm_config.cache_config + + if cache_config.enable_prefix_caching: + logger.warning( + "Prefix caching is not supported on CPU, disable it.") + cache_config.enable_prefix_caching = False + + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE + + if 
kv_cache_space >= 0: + if kv_cache_space == 0: + cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore + logger.warning( + "Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " + "for CPU backend is not set, using 4 by default.") + else: + cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa + else: + raise RuntimeError( + "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + scheduler_config = vllm_config.scheduler_config + if scheduler_config.chunked_prefill_enabled: + logger.warning( + "Chunked prefill is not supported on CPU, disable it.") + scheduler_config.chunked_prefill_enabled = False + + parallel_config = vllm_config.parallel_config + if (parallel_config.distributed_executor_backend is not None + and parallel_config.distributed_executor_backend != "mp"): + logger.warning(("%s is not supported on CPU, fallback to mp " + "distributed executor backend."), + parallel_config.distributed_executor_backend) + parallel_config.distributed_executor_backend = "mp" From 305708bd04dbe1146a349935ebe628873157879c Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Sun, 17 Nov 2024 16:44:44 +0800 Subject: [PATCH 09/25] [Hardware] [HPU]add `mark_step` for hpu (#10239) Signed-off-by: Kunshang Ji Signed-off-by: Linkun Chen --- vllm/worker/hpu_model_runner.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 1ff30d685c6b..99cf9a7e6725 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -272,6 +272,19 @@ def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): return indices, offsets +def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"): + if module.__class__.__name__.endswith(suffix): + + def forward_hook(module, args, output): + htorch.core.mark_step() + return output + + module.register_forward_hook(forward_hook) + + for child_name, child_module in module.named_children(): + modify_decoder_layer(child_module) + + class HpuModelAdapter: def __init__(self, model, block_size, dtype, enforce_eager): @@ -636,6 +649,7 @@ def load_model(self) -> None: else: self.model = self.model.to("hpu") htcore.mark_step() + modify_decoder_layer(self.model) torch.hpu.synchronize() with HabanaMemoryProfiler() as m_wrap: From 871a773c9543f154cbc03a73ec901ab4ca039458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=94=B5=E8=84=91=E6=98=9F=E4=BA=BA?= Date: Sun, 17 Nov 2024 16:50:24 +0800 Subject: [PATCH 10/25] [Bugfix] Fix mrope_position_delta in non-last prefill chunk (#10403) Signed-off-by: imkero Signed-off-by: Linkun Chen --- vllm/model_executor/layers/rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b01e4c61fe10..117fe086e5e8 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -922,9 +922,9 @@ def get_input_positions( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] return llm_positions.tolist(), mrope_position_delta From 242bb53c234fa252f62c2667722f69c9fe362cf9 Mon Sep 17 00:00:00 2001 From: wchen61 
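The single-line move in `rotary_embedding.py` above is subtle: `mrope_position_delta` must be derived from the full multimodal position table before it is sliced down to `[context_len:seq_len]`, otherwise every non-last chunk of a chunked prefill computes the delta from a truncated maximum. A tiny self-contained illustration with made-up sizes (the three M-RoPE rows are simply duplicated here for brevity; none of these numbers come from a real model):

import torch

# Hypothetical M-RoPE position table for a 7-token prompt whose image
# tokens share compressed positions (rows: temporal/height/width).
llm_positions = torch.tensor([[0, 1, 2, 2, 2, 3, 4]]).expand(3, -1)
num_input_tokens = 7
context_len, seq_len = 0, 3        # first chunk of a chunked prefill

# Fixed order: derive the delta from the full table, then slice the chunk.
delta = (llm_positions.max() + 1 - num_input_tokens).item()    # 5 - 7 = -2
chunk = llm_positions[:, context_len:seq_len]

# Previous order: slicing first understates the max on non-last chunks.
wrong = (chunk.max() + 1 - num_input_tokens).item()            # 3 - 7 = -4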
Date: Sun, 17 Nov 2024 19:32:40 +0800 Subject: [PATCH 11/25] =?UTF-8?q?[Misc]=20Enhance=20offline=5Finference=20?= =?UTF-8?q?to=20support=20user-configurable=20paramet=E2=80=A6=20(#10392)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: wchen61 Signed-off-by: Linkun Chen --- examples/offline_inference.py | 98 ++++++++++++++++++++++++++++------- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f..391ac6b9b6b0 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,22 +1,80 @@ +from dataclasses import asdict + from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def get_prompts(num_prompts: int): + # The default sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + if num_prompts != len(prompts): + prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] + + return prompts + + +def main(args): + # Create prompts + prompts = get_prompts(args.num_prompts) + + # Create a sampling params object. + sampling_params = SamplingParams(n=args.n, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + max_tokens=args.max_tokens) + + # Create an LLM. + # The default model is 'facebook/opt-125m' + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**asdict(engine_args)) + + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == '__main__': + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + group = parser.add_argument_group("SamplingParams options") + group.add_argument("--num-prompts", + type=int, + default=4, + help="Number of prompts used for inference") + group.add_argument("--max-tokens", + type=int, + default=16, + help="Generated output length for sampling") + group.add_argument('--n', + type=int, + default=1, + help='Number of generated sequences per prompt') + group.add_argument('--temperature', + type=float, + default=0.8, + help='Temperature for text generation') + group.add_argument('--top-p', + type=float, + default=0.95, + help='top_p for text generation') + group.add_argument('--top-k', + type=int, + default=-1, + help='top_k for text generation') -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. -llm = LLM(model="facebook/opt-125m") -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. 
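With `EngineArgs.add_cli_args` folded into the parser, the rewritten script can be exercised without editing it; for example `python examples/offline_inference.py --model facebook/opt-125m --num-prompts 8 --max-tokens 32 --temperature 0.0` runs the script's default model over eight prompts, and any other engine option (dtype, tensor parallelism, and so on) comes along for free from the shared argument parser.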
-for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + args = parser.parse_args() + main(args) From f5312d38c44144c518a1b049dec6de1cd2b9c7ce Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Nov 2024 09:11:01 +0800 Subject: [PATCH 12/25] Fix initialization Signed-off-by: Cyrus Leung Signed-off-by: Linkun Chen --- vllm/outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index a02f1e97b5b5..32160a8c0432 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -144,7 +144,7 @@ def new( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, - multi_modal_placeholders=MultiModalPlaceholderDict(), + multi_modal_placeholders={}, prompt_logprobs=None, # TODO outputs=[completion_output], finished=finished, From 439e324269d343be3b9ba3b45e9e7d8a5f3a0d8b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Nov 2024 09:13:07 +0800 Subject: [PATCH 13/25] Run isort Signed-off-by: Linkun Chen --- tests/models/decoder_only/vision_language/test_pixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index bbae8138d44b..c502575d4f92 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -15,8 +15,8 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk -from vllm import (EngineArgs, LLMEngine, SamplingParams, TokensPrompt, - TextPrompt, RequestOutput) +from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, + TokensPrompt, TextPrompt) from vllm.logger import init_logger from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal.inputs import PlaceholderRange From 60815f258a3983e3071387da7194b407ce6f52e4 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Nov 2024 09:16:04 +0800 Subject: [PATCH 14/25] isort Signed-off-by: Linkun Chen --- tests/models/decoder_only/vision_language/test_pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index c502575d4f92..dad14f9fe3c8 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -6,7 +6,6 @@ import uuid from dataclasses import asdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from transformers import AutoProcessor import pytest from mistral_common.multimodal import download_image @@ -14,6 +13,7 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk +from transformers import AutoProcessor from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, TokensPrompt, TextPrompt) From ec4675566c9ce859fc0988ea6f2dc76dcfa2a58a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Nov 2024 09:18:22 +0800 Subject: [PATCH 15/25] isort Signed-off-by: Linkun Chen --- tests/models/decoder_only/vision_language/test_pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index dad14f9fe3c8..4edfd6862511 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -16,7 +16,7 @@ from transformers import AutoProcessor from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, - TokensPrompt, TextPrompt) + TextPrompt, TokensPrompt) from vllm.logger import init_logger from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal.inputs import PlaceholderRange From ce3ae6fb0cdeee32b4742b4a5865e52dc4959c7e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 18 Nov 2024 09:07:46 +0800 Subject: [PATCH 16/25] [Misc] Add uninitialized params tracking for `AutoWeightsLoader` (#10327) Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Linkun Chen --- vllm/model_executor/model_loader/loader.py | 12 +++++++++++- vllm/model_executor/models/arctic.py | 8 ++++++-- vllm/model_executor/models/baichuan.py | 8 ++++++-- vllm/model_executor/models/bert.py | 8 ++++++-- vllm/model_executor/models/blip.py | 12 ++++++++---- vllm/model_executor/models/blip2.py | 7 ++++--- vllm/model_executor/models/bloom.py | 8 ++++++-- vllm/model_executor/models/chameleon.py | 8 ++++++-- vllm/model_executor/models/chatglm.py | 10 ++++++++-- vllm/model_executor/models/clip.py | 11 ++++++++--- vllm/model_executor/models/commandr.py | 4 +++- vllm/model_executor/models/dbrx.py | 8 ++++++-- vllm/model_executor/models/decilm.py | 8 ++++++-- vllm/model_executor/models/deepseek.py | 8 ++++++-- vllm/model_executor/models/deepseek_v2.py | 8 ++++++-- vllm/model_executor/models/exaone.py | 9 +++++++-- vllm/model_executor/models/falcon.py | 8 ++++++-- vllm/model_executor/models/florence2.py | 17 +++++++++++------ vllm/model_executor/models/fuyu.py | 8 +++++--- vllm/model_executor/models/gemma.py | 4 +++- vllm/model_executor/models/gemma2.py | 9 ++++++--- vllm/model_executor/models/gpt2.py | 8 ++++++-- vllm/model_executor/models/gpt_bigcode.py | 8 ++++++-- vllm/model_executor/models/gpt_j.py | 8 ++++++-- vllm/model_executor/models/gpt_neox.py | 8 ++++++-- vllm/model_executor/models/granite.py | 9 +++++++-- vllm/model_executor/models/granitemoe.py | 8 +++++--- .../models/idefics2_vision_model.py | 11 ++++++++--- vllm/model_executor/models/idefics3.py | 7 ++++--- vllm/model_executor/models/intern_vit.py | 8 ++++++-- vllm/model_executor/models/internlm2.py | 8 ++++++-- vllm/model_executor/models/internvl.py | 7 ++++--- vllm/model_executor/models/jais.py | 8 ++++++-- vllm/model_executor/models/jamba.py | 8 ++++++-- vllm/model_executor/models/llama.py | 15 ++++++++++----- vllm/model_executor/models/llava.py | 7 ++++--- vllm/model_executor/models/llava_next.py | 7 ++++--- vllm/model_executor/models/llava_next_video.py | 7 ++++--- vllm/model_executor/models/llava_onevision.py | 7 ++++--- vllm/model_executor/models/mamba.py | 8 ++++++-- vllm/model_executor/models/medusa.py | 9 +++++++-- vllm/model_executor/models/minicpm.py | 8 ++++++-- vllm/model_executor/models/minicpmv.py | 14 +++++++++----- vllm/model_executor/models/mixtral.py | 8 ++++++-- vllm/model_executor/models/mixtral_quant.py | 8 ++++++-- vllm/model_executor/models/mllama.py | 9 ++++++--- vllm/model_executor/models/mlp_speculator.py | 8 ++++++-- vllm/model_executor/models/mpt.py | 8 ++++++-- vllm/model_executor/models/nemotron.py | 8 ++++++-- vllm/model_executor/models/olmo.py | 8 ++++++-- 
vllm/model_executor/models/olmoe.py | 8 ++++++-- vllm/model_executor/models/opt.py | 8 ++++++-- vllm/model_executor/models/orion.py | 8 ++++++-- vllm/model_executor/models/paligemma.py | 7 ++++--- vllm/model_executor/models/persimmon.py | 8 ++++++-- vllm/model_executor/models/phi.py | 8 ++++++-- vllm/model_executor/models/phi3_small.py | 8 ++++++-- vllm/model_executor/models/phi3v.py | 9 ++++++--- vllm/model_executor/models/phimoe.py | 8 ++++++-- vllm/model_executor/models/pixtral.py | 12 ++++++++---- vllm/model_executor/models/qwen.py | 8 ++++++-- vllm/model_executor/models/qwen2.py | 18 ++++++++++++------ vllm/model_executor/models/qwen2_audio.py | 9 +++++++-- vllm/model_executor/models/qwen2_cls.py | 7 ++++--- vllm/model_executor/models/qwen2_moe.py | 8 ++++++-- vllm/model_executor/models/qwen2_rm.py | 7 ++++--- vllm/model_executor/models/qwen2_vl.py | 8 ++++++-- vllm/model_executor/models/siglip.py | 11 ++++++++--- vllm/model_executor/models/solar.py | 9 +++++++-- vllm/model_executor/models/stablelm.py | 8 ++++++-- vllm/model_executor/models/starcoder2.py | 8 ++++++-- vllm/model_executor/models/ultravox.py | 7 ++++--- vllm/model_executor/models/utils.py | 11 ++++++----- vllm/model_executor/models/xverse.py | 8 ++++++-- 74 files changed, 454 insertions(+), 185 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 0f8b81c3ef40..d9ce85949e4e 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -334,7 +334,17 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: with target_device: model = _initialize_model(vllm_config=vllm_config) - model.load_weights(self._get_all_weights(model_config, model)) + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights( + self._get_all_weights(model_config, model)) + # We only enable strict check for non-quantiized models + # that have loaded weights tracking currently. + if model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index d52418ee0f6f..e58ad19cab54 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -1,5 +1,5 @@ """Inference-only Snowflake Arctic model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -480,7 +480,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -518,6 +519,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("ws", f"experts.{expert_id}.w3.weight", expert_id)) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() logger.info( "It will take ~10 minutes loading from the 16-bit weights. 
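The strict check added to `loader.py` above only works because each model's `load_weights` now returns the set of parameter names it actually consumed, which the per-model hunks below thread through as `loaded_params`. A condensed sketch of that contract, using free-standing helpers rather than the real `AutoWeightsLoader`:

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn


def load_weights(model: nn.Module,
                 weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
    """Copy checkpoint tensors into the model and report what was loaded."""
    params = dict(model.named_parameters())
    loaded: Set[str] = set()
    for name, tensor in weights:
        if name not in params:      # e.g. rotary caches or skipped heads
            continue
        params[name].data.copy_(tensor)
        loaded.add(name)
    return loaded


def check_all_initialized(model: nn.Module, loaded: Set[str]) -> None:
    """Mirror of the new loader check: fail fast on untouched parameters."""
    expected = {name for name, _ in model.named_parameters()}
    missing = expected - loaded
    if missing:
        raise ValueError("Following weights were not initialized from "
                         f"checkpoint: {missing}")


model = nn.Linear(4, 2)
ckpt = [("weight", torch.zeros(2, 4))]     # deliberately omit "bias"
loaded = load_weights(model, ckpt)
try:
    check_all_initialized(model, loaded)
except ValueError as err:
    print(err)                             # reports the missing "bias"

Returning the loaded names instead of loading silently is what lets the loader distinguish genuinely missing checkpoint tensors from parameters a model intentionally skips.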
" @@ -573,3 +575,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 01ce7c42cd39..3749a16a3899 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -18,7 +18,7 @@ # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -404,13 +404,15 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -449,6 +451,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class BaichuanForCausalLM(BaiChuanBaseForCausalLM): diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 42dd6119e76f..d8301a36acb0 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -337,7 +337,8 @@ def forward( return self.encoder(hidden_states, kv_caches, attn_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "query", "q"), @@ -346,6 +347,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "pooler" in name: continue @@ -368,6 +370,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class BertEmbeddingModel(nn.Module): diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index e61201067736..6db6462e97f3 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,6 +1,6 @@ """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -415,7 +415,8 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return self.post_layernorm(hidden_states) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: 
Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -423,6 +424,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() layer_count = len(self.encoder.layers) for name, loaded_weight in weights: @@ -440,8 +442,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -450,3 +452,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 03dc1d15ab69..7d7639b4a92c 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -692,6 +692,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index cf2eee817276..1060d418474e 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -16,7 +16,7 @@ # limitations under the License. 
"""Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -341,8 +341,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if name == "lm_head.weight": continue @@ -371,3 +373,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 7b59c818e0b6..8f91abffaea9 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -1034,7 +1034,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1044,6 +1045,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -1111,3 +1113,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 70e9b607b064..81e56381eabd 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -3,7 +3,8 @@ """Inference-only ChatGLM model compatible with THUDM weights.""" from argparse import Namespace from array import array -from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict +from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple, + TypedDict) import torch from PIL import Image @@ -645,7 +646,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: # Merge two ColumnParallelLinear into one MergedColumnParallelLinear merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = { "transformer.vision.linear_proj.merged_proj.weight": { @@ -655,6 +657,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): } params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: is_weight_to_be_merge = False 
for _, merged_weight_dict in merged_weights_dict.items(): @@ -677,6 +680,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) for combined_name, merged_weight_dict in merged_weights_dict.items(): if combined_name in params_dict: @@ -686,3 +690,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, combined_weight) + loaded_params.add(combined_name) + return loaded_params diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 2d81b9266826..184758f4a8a4 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,6 +1,6 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -483,7 +483,8 @@ def device(self): # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -491,6 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: @@ -508,8 +510,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue + name = name.replace(weight_name, param_name) - param = params_dict[name.replace(weight_name, param_name)] + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -518,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index fbb09a64cde9..9fd083e5a02a 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -402,7 +402,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -447,3 +448,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 3952ff31e5ce..eab338800249 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,4 +1,4 @@ -from typing import 
Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -417,13 +417,15 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: expert_params_mapping = [( "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight", f"mlp.{weight_name}", ) for weight_name in ["w1", "v1", "w2"]] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: for param_name, weight_name in expert_params_mapping: if weight_name not in name: @@ -447,3 +449,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index b38fd9fa49c2..c551853956b9 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -22,7 +22,7 @@ # limitations under the License. """Inference-only DeciLM model compatible with HuggingFace weights.""" -from typing import Iterable, Tuple +from typing import Iterable, Set, Tuple import torch @@ -57,7 +57,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): delattr(config, "num_key_value_heads_per_layer") super().__init__(vllm_config=vllm_config) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -67,6 +68,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -97,6 +99,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor: hidden_size = self.config.hidden_size diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 36dfea5a6565..8c5ad9904e92 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Deepseek model.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -442,7 +442,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -453,6 +454,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -487,3 +489,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1e32fe60c7a5..d2c4ca0bf85e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2 model.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -550,7 +550,8 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -566,6 +567,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.n_routed_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -623,3 +625,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 52dd603ca558..9d739d047954 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -22,7 +22,7 @@ # limitations under the License. 
"""Inference-only Exaone model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -513,7 +513,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -523,6 +524,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".c_fc_1", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -543,6 +545,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) + loaded_params.add(scale_name) continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -576,6 +579,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index e97abe949ccd..2aa4b67d9989 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -18,7 +18,7 @@ """PyTorch Falcon model.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -473,7 +473,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -483,6 +484,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): total_num_kv_heads = total_num_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if name == "lm_head.weight" and self.tie_word_embeddings: # Falcon uses tied embeddings except Falcon-11b. 
@@ -519,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 971a71180164..d3a9ff6915b8 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch import torch.nn as nn @@ -156,7 +156,8 @@ def sample(self, logits: torch.Tensor, next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -165,12 +166,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -183,6 +185,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class Florence2ForConditionalGeneration(nn.Module): @@ -248,10 +252,11 @@ def sample( ) -> SamplerOutput: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: skip_prefixes = [ 'image_projection', "vision_tower", "image_proj_norm", "image_pos_embed", "visual_temporal_embed" ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 31fc098a8bb3..7b46907ac83a 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -16,7 +16,8 @@ """ PyTorch Fuyu model.""" import math from array import array -from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict) import torch import torch.nn as nn @@ -354,6 +355,7 @@ def sample( next_tokens = self.language_model.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index ace13664c6ea..64e03b30bf2f 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -424,7 +424,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return 
next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -469,3 +470,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): logger.warning( "Some weights are not initialized from checkpoints: %s", unloaded_params) + return loaded_params diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index a60b4e73a76d..4ba39223cc07 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -312,7 +312,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -354,6 +355,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): logger.warning( "Some weights are not initialized from checkpoints: %s", unloaded_params) + return loaded_params class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): @@ -451,13 +453,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - loader.load_weights(weights) + return loader.load_weights(weights) class Gemma2EmbeddingModel(nn.Module, SupportsPP): diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fa0fdad28d16..1c61408ae1dd 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -298,8 +298,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final @@ -328,3 +330,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b2fc79d0d36d..50a143cb1b60 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -323,8 +323,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name: continue @@ -344,3 +346,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, 'v') else: weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index cec3fd12a67d..d5defc60764e 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -291,7 +291,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -301,6 +302,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "attn.bias" in name or "attn.masked_bias" in name: continue @@ -330,3 +332,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 11f286d6bcba..0bb5e2f9b95f 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -303,8 +303,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): @@ -337,3 +339,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index cb2583e69d88..c1e2e87f08ec 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM Granite model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -455,7 +455,8 @@ def make_empty_intermediate_tensors( device=device), }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -465,6 +466,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -485,6 +487,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) + loaded_params.add(scale_name) continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -518,6 +521,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index f437dd521a7d..a91a18816995 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GraniteMoe model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -419,7 +419,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: new_weights = {} for n, p in weights: if n.endswith('.block_sparse_moe.input_linear.weight'): @@ -452,4 +453,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): pass else: new_weights[n] = p - mixtral.MixtralForCausalLM.load_weights(self, new_weights.items()) + return mixtral.MixtralForCausalLM.load_weights(self, + new_weights.items()) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index b21bc2a3f9ce..16192928beb1 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -15,7 +15,7 @@ # limitations under the License. """PyTorch Idefics2 model.""" -from typing import Iterable, Optional, Tuple +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn @@ -331,7 +331,8 @@ def forward( last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -339,11 +340,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -352,3 +355,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 0cecc754e916..5d176b2a4e41 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -15,7 +15,7 @@ import math from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple, - Optional, Tuple, TypedDict, Union) + Optional, Set, Tuple, TypedDict, Union) import torch import torch.utils.checkpoint @@ -751,9 +751,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) def get_mm_mapping(self) -> MultiModelKeys: """ diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 9761635d2a6c..bd91a0806ae5 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ 
-5,7 +5,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from functools import partial -from typing import Iterable, Optional, Tuple +from typing import Iterable, Optional, Set, Tuple import torch import torch.nn as nn @@ -469,10 +469,14 @@ def forward( return encoder_outputs - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 19bfe16e4d5f..94b819b5d936 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -369,13 +369,15 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), ("gate_up_proj", "w3", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -402,3 +404,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 92579e3aae94..7ea2f9be2191 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -6,7 +6,7 @@ # -------------------------------------------------------- import re from functools import cached_property, partial -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -663,6 +663,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index ee49ffb3cd87..41db85b67845 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -19,7 +19,7 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -350,8 +350,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, 
weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final @@ -382,3 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5612dd688638..f83f0fce7275 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,5 +1,5 @@ """Inference-only Jamba model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -462,7 +462,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -479,6 +480,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -534,6 +536,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params def _is_moe_layer(name: str): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e53631ef19f3..2b40e9ec73fa 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -350,7 +350,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -360,6 +361,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -375,6 +377,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) + loaded_params.add(scale_name) continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -390,7 +393,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - break else: # Skip loading extra bias for GPTQ models. @@ -408,6 +410,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). 
Thus, handled exceptions should @@ -577,13 +581,14 @@ def sample(self, logits: torch.Tensor, next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - loader.load_weights( + return loader.load_weights( self.maybe_remap_mistral(name, loaded_weight) for name, loaded_weight in weights) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index b13bcfa67681..e7d3161a7cb2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, +from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) import torch @@ -547,6 +547,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index dd2fa6cac969..37e2227a52dc 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -654,6 +654,7 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 5d5598d07bfd..e2880c76cf43 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -1,6 +1,6 @@ import math from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import numpy as np @@ -445,10 +445,11 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( self, # This model doesn't support images for now ignore_unexpected_prefixes=["image_newline"], ) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index a5b210817783..705ca1e4ab6e 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -1,6 +1,6 @@ import math from functools import cached_property -from 
typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import numpy as np @@ -887,6 +887,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ac0d265a961f..405b8f7787ba 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -243,8 +243,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -256,3 +258,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index b05360b55466..b4ed6538bdda 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch import torch.nn as nn @@ -156,8 +156,10 @@ def generate_proposals( sampling_metadata=sampling_metadata, ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() weights_map = {} @@ -181,9 +183,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) if self.token_map is not None: self.token_map.to(device=self.lm_heads[0].weight.device) assert (self.truncated_vocab_size == self.orig_vocab_size) or (self.token_map is not None) + + return loaded_params diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 6b67266c5336..b92bff4d7c28 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -21,7 +21,7 @@ # limitations under the License. 
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -539,7 +539,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -556,6 +557,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -606,3 +608,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index fd8eda997f76..99bf1d42d035 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -24,7 +24,7 @@ import re from functools import partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, - Tuple, TypedDict, Union) + Set, Tuple, TypedDict, Union) import torch import torch.types @@ -602,7 +602,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -612,6 +613,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in name: @@ -630,10 +632,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if is_pp_missing_parameter( - name.replace(weight_name, param_name), self): + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): continue - param = params_dict[name.replace(weight_name, param_name)] + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -646,6 +648,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params def get_mm_mapping(self) -> MultiModelKeys: """ diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index eebf5bab5a28..0faffb4f1b00 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -404,7 +404,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -421,6 +422,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -478,3 +480,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index af2e9586988d..ddd6afcf6a1b 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -409,7 +409,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -418,6 +419,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -448,3 +450,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index db7ee7b2d853..41f62b37f3bd 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -13,7 +13,7 @@ # limitations under the License. 
"""PyTorch Mllama model.""" import math -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import numpy as np @@ -1427,7 +1427,8 @@ def forward( return outputs - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1437,7 +1438,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - updated_params = set() + updated_params: Set[str] = set() for name, loaded_weight in weights: if 'patch_embedding.weight' in name: name = name.replace('patch_embedding.weight', @@ -1457,6 +1458,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + updated_params.add(name) + return updated_params def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 4d7e82880041..f2aa2653c4f5 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Tuple +from typing import Iterable, List, Set, Tuple import torch import torch.nn as nn @@ -188,11 +188,15 @@ def generate_proposals( return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: param = params_dict.get(name.replace("speculator.", "")) if param is not None: weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 3c74ef2448ab..8716e92b0f1c 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,6 +1,6 @@ # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -324,8 +324,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: @@ -336,3 +338,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index eb45beae7d21..ceab299a7950 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Nemotron model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -474,7 +474,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -482,6 +483,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".qkv_proj", ".v_proj", "v"), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -522,3 +524,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 98d4e1ec320a..dc138e2e636a 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -356,7 +356,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -366,6 +367,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -402,3 +404,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index f4eebab8c98d..ab87695d8e65 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -364,7 +364,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -383,6 +384,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -455,3 +457,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 997fe642439e..db85a494980a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only OPT model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -394,7 +394,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -402,6 +403,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "lm_head.weight" in name and self.config.tie_word_embeddings: continue @@ -431,3 +433,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 39d659c49cbc..b01734af8ddd 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -3,7 +3,7 @@ # Copyright (c) OrionStar Inc. # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -327,7 +327,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -337,6 +338,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -368,3 +370,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index eea229359255..dd5256eb87ab 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,4 +1,4 @@ -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch @@ -295,6 +295,7 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 62c509153a11..3b8199f4f166 
100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only persimmon model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -324,8 +324,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -358,3 +360,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index a2ab0d74c48d..0a117bf16c9b 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -34,7 +34,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -345,7 +345,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -353,6 +354,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v") ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -383,3 +385,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 2139cec44180..a78e4d355a31 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -457,9 +457,11 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -471,3 +473,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", 
default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4db65edc174f..2e583bb08e87 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -15,7 +15,7 @@ import itertools import re from functools import cached_property, lru_cache -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import numpy as np @@ -744,7 +744,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", @@ -759,5 +760,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # The HF config doesn't specify whether these are tied, # so we detect it this way - if "embed_tokens" not in autoloaded_weights: + if "embed_tokens.weight" not in autoloaded_weights: self.embed_tokens = self.language_model.model.embed_tokens + autoloaded_weights.add("embed_tokens.weight") + return autoloaded_weights diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index b7e70f8fa2c6..e475d286bd7e 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only PhiMoE model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -598,7 +598,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -613,6 +614,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -666,3 +668,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 790a260d43ec..d44a538d56b8 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, fields from functools import cached_property from itertools import tee -from typing import Iterable, List, Mapping, Optional, Tuple, Union +from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union import numpy import torch @@ -1067,7 +1067,8 @@ def forward( # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: 
Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1077,6 +1078,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() layer_count = len(self.transformer.layers) for name, loaded_weight in weights: @@ -1089,8 +1091,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -1099,3 +1101,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 447632cefcd9..3978c176a214 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -8,7 +8,7 @@ import re from functools import partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Tuple, TypedDict, Union) + Optional, Set, Tuple, TypedDict, Union) import numpy as np import torch @@ -964,13 +964,15 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -999,6 +1001,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class QWenLLM(QWenBaseModel): diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8f10df808c21..370cff5fa153 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -332,7 +332,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -342,6 +343,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -372,6 +374,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): @@ -494,13 +498,14 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - loader.load_weights(weights) + return loader.load_weights(weights) class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): @@ -564,7 +569,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index d30950361ad8..a4965f34b1ca 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -20,7 +20,8 @@ # limitations under the License. 
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from functools import lru_cache -from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import librosa import numpy as np @@ -420,7 +421,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -430,6 +432,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -463,3 +466,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 07eb330620a4..dc5dabf6fc38 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -4,7 +4,7 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. """Inference-only Qwen2-Classification model compatible with HF weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -97,7 +97,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 249d94b5d95e..96a9bc451f4d 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn.functional as F @@ -436,7 +436,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -455,6 +456,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -532,3 +534,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 6db467af334f..988d682d36be 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -3,7 +3,7 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. """Inference-only Qwen2-RM model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -110,7 +110,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) - loader.load_weights(weights) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 2335baf45977..ef6b52db6e17 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -23,7 +23,7 @@ """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Tuple, Type, TypedDict, Union) + Optional, Set, Tuple, Type, TypedDict, Union) import torch import torch.nn as nn @@ -1333,7 +1333,8 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -1343,6 +1344,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "gate_proj", 0), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -1392,3 +1394,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params 
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index acaf4afdecfe..c9e09b879843 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -2,7 +2,7 @@ within a vision language model.""" import math -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -594,7 +594,8 @@ def forward( interpolate_pos_encoding=interpolate_pos_encoding, ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -602,6 +603,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: @@ -619,8 +621,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue + name = name.replace(weight_name, param_name) - param = params_dict[name.replace(weight_name, param_name)] + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -629,3 +632,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index affb2c975ce4..6d6fafc5ab0e 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -21,7 +21,7 @@ # limitations under the License. 
"""Inference-only Solar model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -477,7 +477,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -487,6 +488,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -502,6 +504,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) + loaded_params.add(scale_name) continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -535,6 +538,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 99acce596602..e11d2e916730 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -18,7 +18,7 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -306,7 +306,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -316,6 +317,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -347,3 +349,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 0ef940acebb9..74c66042226d 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" PyTorch Starcoder2 model.""" -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -314,7 +314,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -323,6 +324,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -346,3 +348,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 9fde22c016de..512adbc7db35 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,7 +3,7 @@ import math from functools import cached_property, lru_cache -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union, cast) import numpy as np @@ -504,10 +504,11 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."]) - loader.load_weights(weights, mapper=hf_to_vllm_mapper) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 1d51885f9094..7a4fcce95603 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass, field from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Protocol, Tuple, Union, overload) + Optional, Protocol, Set, Tuple, Union, overload) import torch import torch.nn as nn @@ -172,8 +172,9 @@ def _load_module( if module != self.module: module_load_weights = getattr(module, "load_weights", None) if callable(module_load_weights): - module_load_weights(weights) - return + loaded_params = module_load_weights(weights) + yield from map(lambda x: self._get_qualname(base_prefix, x), + loaded_params) child_modules = dict(module.named_children()) child_params = dict(module.named_parameters(recurse=False)) @@ -222,11 +223,11 @@ def load_weights( weights: Iterable[Tuple[str, torch.Tensor]], *, mapper: Optional[WeightsMapper] = None, - ) -> List[str]: + ) -> Set[str]: if mapper is not None: weights = mapper.apply(weights) - autoloaded_weights = list(self._load_module("", self.module, weights)) + autoloaded_weights = set(self._load_module("", self.module, weights)) return autoloaded_weights diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 
51172d8782a7..bc37a997eabb 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -376,7 +376,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -385,6 +386,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if ("rotary_emb.inv_freq" in name or "rotary_emb.cos_cached" in name @@ -413,3 +415,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params From 466b2cfb7468436882556d5bff77672534b4d480 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=84=8D=F0=9D=95=A0=F0=9D=95=9D=F0=9D=95=9D=F0=9D=95=A0?= =?UTF-8?q?=F0=9D=95=A8=20=F0=9D=95=84=F0=9D=95=92=F0=9D=95=9F?= Date: Mon, 18 Nov 2024 05:29:26 +0200 Subject: [PATCH 17/25] [Bugfix] Ignore ray reinit error when current platform is ROCm or XPU (#10375) Signed-off-by: Hollow Man Signed-off-by: Linkun Chen --- vllm/executor/ray_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 41dd59bc65ec..4f28efd63908 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -234,7 +234,7 @@ def initialize_ray_cluster( if current_platform.is_rocm() or current_platform.is_xpu(): # Try to connect existing ray instance and create a new one if not found try: - ray.init("auto") + ray.init("auto", ignore_reinit_error=True) except ConnectionError: logger.warning( "No existing RAY instance detected. " From 3f092cef27d42ce6f3c8e4c5db9b125ebbfd8638 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:27:48 +0000 Subject: [PATCH 18/25] update RequestOutput.__init__() to take `multi_modal_placeholders` as optional argument also require it to be passed as kwargs, to avoid breaking existing code. 
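A sketch of a call site after this change; positional callers are untouched because the new argument is keyword-only and defaults to None (normalized to an empty dict), and the placeholder offsets below are made up for illustration:

from vllm import RequestOutput

output = RequestOutput(
    request_id="req-0",
    prompt="USER: [IMG]\nDescribe the image.\nASSISTANT:",
    prompt_token_ids=[1, 2, 3],
    prompt_logprobs=None,
    outputs=[],
    finished=False,
    multi_modal_placeholders={"image": [{"offset": 10, "length": 494}]},
)
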
Signed-off-by: Linkun Chen --- vllm/outputs.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 32160a8c0432..25d585d65285 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -96,7 +96,6 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: Optional[List[int]], - multi_modal_placeholders: MultiModalPlaceholderDict, prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -105,11 +104,13 @@ def __init__( encoder_prompt: Optional[str] = None, encoder_prompt_token_ids: Optional[List[int]] = None, num_cached_tokens: Optional[int] = None, + *, + multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, ) -> None: self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids - self.multi_modal_placeholders = multi_modal_placeholders + self.multi_modal_placeholders = multi_modal_placeholders or {} self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -144,7 +145,6 @@ def new( request_id=request_id, prompt=prompt, prompt_token_ids=prompt_token_ids, - multi_modal_placeholders={}, prompt_logprobs=None, # TODO outputs=[completion_output], finished=finished, @@ -158,8 +158,7 @@ def from_seq_group( finished = seq_group.is_finished() if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[ - seq_group.request_id] + group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id] if finished: group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group(seq_group) @@ -202,8 +201,8 @@ def from_seq_group( # num_cached_tokens should be the same for all the sequences num_cached_tokens = None for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return( - text_buffer_length, delta) + output_text = seq.get_output_text_to_return(text_buffer_length, + delta) output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, @@ -280,17 +279,17 @@ def from_seq_group( seq_group.set_finished_time(finished_time) init_args = (seq_group.request_id, prompt, prompt_token_ids, - seq_group.multi_modal_placeholders, prompt_logprobs, - outputs, finished, seq_group.metrics, + prompt_logprobs, outputs, finished, seq_group.metrics, seq_group.lora_request, encoder_prompt, encoder_prompt_token_ids, num_cached_tokens) + init_kwargs = {"multi_modal_placeholders": seq_group.multi_modal_placeholders} if use_cache: request_output = seq_group.cached_request_output - request_output.__init__(*init_args) # type: ignore + request_output.__init__(*init_args, **init_kwargs) # type: ignore else: - request_output = cls(*init_args) + request_output = cls(*init_args, **init_kwargs) return request_output @@ -298,7 +297,6 @@ def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"multi_modal_placeholders={self.multi_modal_placeholders}, " f"encoder_prompt={self.encoder_prompt!r}, " f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " @@ -306,7 +304,8 @@ def __repr__(self) -> str: f"finished={self.finished}, " f"metrics={self.metrics}, " f"lora_request={self.lora_request}, " - f"num_cached_tokens={self.num_cached_tokens})") + f"num_cached_tokens={self.num_cached_tokens}, " + f"multi_modal_placeholders={self.multi_modal_placeholders})") class 
EmbeddingRequestOutput: From dd8427ef342dfa52d6c252db014b414e9e2f5e04 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:27:48 +0000 Subject: [PATCH 19/25] update RequestOutput.__init__() to take `multi_modal_placeholders` as optional argument also require it to be passed as kwargs, to avoid breaking existing code. Signed-off-by: Linkun Chen --- vllm/outputs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 25d585d65285..0e307d337bf1 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -282,7 +282,9 @@ def from_seq_group( prompt_logprobs, outputs, finished, seq_group.metrics, seq_group.lora_request, encoder_prompt, encoder_prompt_token_ids, num_cached_tokens) - init_kwargs = {"multi_modal_placeholders": seq_group.multi_modal_placeholders} + init_kwargs = { + "multi_modal_placeholders": seq_group.multi_modal_placeholders + } if use_cache: request_output = seq_group.cached_request_output From 76ac8b0157de48857602fcbe7406e57724e814dd Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:27:48 +0000 Subject: [PATCH 20/25] update RequestOutput.__init__() to take `multi_modal_placeholders` as optional argument also require it to be passed as kwargs, to avoid breaking existing code. Signed-off-by: Linkun Chen --- vllm/outputs.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 0e307d337bf1..c56aa0478925 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -278,20 +278,26 @@ def from_seq_group( finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) - init_args = (seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished, seq_group.metrics, - seq_group.lora_request, encoder_prompt, - encoder_prompt_token_ids, num_cached_tokens) init_kwargs = { + "request_id": seq_group.request_id, + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + "prompt_logprobs": prompt_logprobs, + "outputs": outputs, + "finished": finished, + "metrics": seq_group.metrics, + "lora_request": seq_group.lora_request, + "encoder_prompt": encoder_prompt, + "encoder_prompt_token_ids": encoder_prompt_token_ids, + "num_cached_tokens": num_cached_tokens, "multi_modal_placeholders": seq_group.multi_modal_placeholders } if use_cache: request_output = seq_group.cached_request_output - request_output.__init__(*init_args, **init_kwargs) # type: ignore - + request_output.__init__(**init_kwargs) # type: ignore else: - request_output = cls(*init_args, **init_kwargs) + request_output = cls(**init_kwargs) return request_output From c963a256f43603194310c27aacccaa555cb350e6 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:59:58 +0000 Subject: [PATCH 21/25] disable mypy type check mypy is not smart enough to validate kwargs Signed-off-by: Linkun Chen --- vllm/outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index c56aa0478925..80cf4702934f 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -297,7 +297,7 @@ def from_seq_group( request_output = seq_group.cached_request_output request_output.__init__(**init_kwargs) # type: ignore else: - request_output = cls(**init_kwargs) + request_output = cls(**init_kwargs) # type: ignore return request_output From 470fbd3921d22a53cac1662f26ca207f1251ca8e Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 05:59:58 +0000 Subject: [PATCH 22/25] disable mypy type check mypy 
is not smart enough to validate kwargs Signed-off-by: Linkun Chen --- vllm/outputs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 80cf4702934f..4ae9b377ae69 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -158,7 +158,8 @@ def from_seq_group( finished = seq_group.is_finished() if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[seq_group.request_id] + group: SequenceGroupBase = seq_id_to_seq_group[ + seq_group.request_id] if finished: group.finish_seq(seq_group) assembled_seq_group = group.maybe_assemble_group(seq_group) @@ -201,8 +202,8 @@ def from_seq_group( # num_cached_tokens should be the same for all the sequences num_cached_tokens = None for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return(text_buffer_length, - delta) + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, From 550be230dfa4ba7b05f450a053bb8256287885e5 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Mon, 18 Nov 2024 06:17:27 +0000 Subject: [PATCH 23/25] remove unnecessary debug code Signed-off-by: Linkun Chen --- tests/models/decoder_only/vision_language/test_pixtral.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index 4edfd6862511..6233860747b9 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -17,7 +17,6 @@ from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, TextPrompt, TokensPrompt) -from vllm.logger import init_logger from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal.inputs import PlaceholderRange from vllm.sequence import Logprob, SampleLogprobs @@ -28,8 +27,6 @@ if TYPE_CHECKING: from _typeshed import StrPath -logger = init_logger(__name__) - MODELS = ["mistralai/Pixtral-12B-2409"] IMG_URLS = [ "https://picsum.photos/id/237/400/300", From d8f785a7c76a5b5e7e7fc1d8d1fea64ae506b2d9 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Wed, 5 Feb 2025 00:48:09 +0000 Subject: [PATCH 24/25] [V1] Enhance check when slicing encoder output Prepare for vllm-project/vllm#11409. For the Pixtral model, we need to insert placeholders in the middle of the encoder output so that it fits into the whole soft embedding, which makes the slicing operation tricky. This PR raises an assertion if something is off.
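As a concrete illustration of why the bound can be violated (the numbers assume a Pixtral-style layout): an image whose features form a 4x3 grid yields 12 encoder embeddings, but its placeholder spans 3 * (4 + 1) = 15 prompt positions because each row of image tokens is followed by a break/end token, so an end index derived purely from placeholder offsets can run past the encoder output. The assertion added below turns that silent out-of-range slice into an explicit failure.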
Signed-off-by: Linkun Chen --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bfc9d1ca83f4..ef56ae4c9c26 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -734,6 +734,7 @@ def _gather_encoder_outputs( assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] encoder_output = self.encoder_cache[req_id][i] + assert end_idx <= encoder_output.shape[0], f"{end_idx=} {encoder_output.shape=}" encoder_outputs.append(encoder_output[start_idx:end_idx]) return encoder_outputs From 89f243bf7f82a9ab23b6f69c72f0f3d4b19007d7 Mon Sep 17 00:00:00 2001 From: Linkun Chen Date: Tue, 11 Feb 2025 08:54:55 +0000 Subject: [PATCH 25/25] [V1][Pixtral-HF] Add custom `slice_encoder_output` for Pixtral Signed-off-by: Linkun Chen --- vllm/model_executor/models/llava.py | 29 ++++++++++++++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 28 +++++++++++++++++++++------- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index b1fee3eeb542..1e0ff3ce3609 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -27,7 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, - NestedTensors) + NestedTensors, PlaceholderRange) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -507,6 +507,33 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: if (config.text_config.architectures is None and config.text_config.model_type == "mistral"): config.text_config.architectures = ["MistralForCausalLM"] + + def _slice_encoder_output( + mm_input: MultiModalKwargs, + encoder_output: torch.Tensor, + mm_pos: PlaceholderRange, + num_computed_tokens: int, + num_scheduled_tokens: int, + ) -> torch.Tensor: + assert "pixel_values" in mm_input + image_input = mm_input["pixel_values"] + ncols, nrows = get_pixtral_hf_image_feature_grid_size( + self.config.vision_config, + image_width=image_input.shape[-1], + image_height=image_input.shape[-2], + ) + placeholder_start = mm_pos["offset"] + # Turn placeholder position into encoder output position + def placeholder_pos_to_encoder_output_pos(placeholder_pos: int) -> int: + return placeholder_pos % (ncols + 1) + placeholder_pos // (ncols + 1) * ncols + start_idx = max(placeholder_pos_to_encoder_output_pos(num_computed_tokens - placeholder_start), 0) + end_idx = min( + placeholder_pos_to_encoder_output_pos(num_computed_tokens + num_scheduled_tokens - placeholder_start), + len(encoder_output)) + assert start_idx <= end_idx, f"{start_idx=} should be no greater than {end_idx=}" + return encoder_output[start_idx:end_idx] + self.slice_encoder_output = _slice_encoder_output + if (config.projector_hidden_act is None and config.vision_config.hidden_act == "gelu"): config.projector_hidden_act = "gelu" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ef56ae4c9c26..b87972c48527 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -726,16 +726,30 @@ def _gather_encoder_outputs( # in the decoder's KV cache. 
continue - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) - assert start_idx < end_idx assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] encoder_output = self.encoder_cache[req_id][i] - assert end_idx <= encoder_output.shape[0], f"{end_idx=} {encoder_output.shape=}" - encoder_outputs.append(encoder_output[start_idx:end_idx]) + if hasattr(self.model, "slice_encoder_output"): + # Per-model custom logic to slice the encoder output. Some + # models (e.g. Pixtral) have dynamic number of special + # tokens (e.g. image_break) in the middle of placeholder + # positions. This allows the model to calculate + # encoder_output slices taking into account the special + # tokens. + encoder_outputs.append(self.model.slice_encoder_output( + mm_input=req_state.mm_inputs[i], + encoder_output=encoder_output, + mm_pos=pos_info, + num_computed_tokens=num_computed_tokens, + num_scheduled_tokens=num_scheduled_tokens)) + else: + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens) + assert start_idx < end_idx + # assert end_idx <= encoder_output.shape[0], f"{end_idx=} {encoder_output.shape=}" + encoder_outputs.append(encoder_output[start_idx:end_idx]) return encoder_outputs def get_model(self) -> nn.Module:
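The subtle part of this hook is the index arithmetic: placeholder positions advance by ncols + 1 per image row (the extra slot being the break/end token), while encoder outputs advance by ncols. A standalone sketch of that mapping with toy sizes; the names here are illustrative and not the vLLM implementation itself:

import torch


def placeholder_to_encoder_pos(placeholder_pos: int, ncols: int) -> int:
    # Each image row occupies ncols patch tokens plus one break token in the
    # prompt, but only ncols entries in the encoder output.
    return placeholder_pos % (ncols + 1) + placeholder_pos // (ncols + 1) * ncols


def slice_image_output(encoder_output: torch.Tensor, ncols: int,
                       placeholder_start: int, num_computed: int,
                       num_scheduled: int) -> torch.Tensor:
    start = max(
        placeholder_to_encoder_pos(num_computed - placeholder_start, ncols), 0)
    end = min(
        placeholder_to_encoder_pos(
            num_computed + num_scheduled - placeholder_start, ncols),
        len(encoder_output))
    assert start <= end
    return encoder_output[start:end]


# Toy example: a 4x3 patch grid -> 12 embeddings, while the placeholder covers
# 3 * (4 + 1) = 15 prompt positions starting at offset 10.
encoder_output = torch.randn(4 * 3, 8)
chunk = slice_image_output(encoder_output, ncols=4, placeholder_start=10,
                           num_computed=10, num_scheduled=15)
assert chunk.shape[0] == 12  # break-token slots contribute no embeddings

Keeping this mapping inside the model rather than in the runner lets the generic start_idx/end_idx path above stay unchanged for models whose placeholders contain no mid-span special tokens.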