diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 28faa96fd26d..1a36d9d6a5de 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -92,17 +92,9 @@ async def get_request( await asyncio.sleep(interval) -async def send_request( - backend: str, - model: str, - api_url: str, - prompt: str, - prompt_len: int, - output_len: int, - best_of: int, - use_beam_search: bool, - pbar: tqdm -) -> None: +async def send_request(backend: str, model: str, api_url: str, prompt: str, + prompt_len: int, output_len: int, best_of: int, + use_beam_search: bool, pbar: tqdm) -> None: request_start_time = time.perf_counter() headers = {"User-Agent": "Benchmark Client"} @@ -155,7 +147,6 @@ async def send_request( pbar.update(1) - async def benchmark( backend: str, model: str, @@ -217,7 +208,10 @@ def main(args: argparse.Namespace): type=str, default="vllm", choices=["vllm", "tgi"]) - parser.add_argument("--protocol", type=str, default="http", choices=["http", "https"]) + parser.add_argument("--protocol", + type=str, + default="http", + choices=["http", "https"]) parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--endpoint", type=str, default="/generate") diff --git a/setup.py b/setup.py index fb37a8d95231..85e51ddd5694 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def _is_neuron() -> bool: torch_neuronx_installed = True try: subprocess.run(["neuron-ls"], capture_output=True, check=True) - except FileNotFoundError as e: + except FileNotFoundError: torch_neuronx_installed = False return torch_neuronx_installed @@ -99,7 +99,8 @@ def get_hipcc_rocm_version(): def get_neuronxcc_version(): import sysconfig site_dir = sysconfig.get_paths()["purelib"] - version_file = os.path.join(site_dir, "neuronxcc", "version", "__init__.py") + version_file = os.path.join(site_dir, "neuronxcc", "version", + "__init__.py") # Check if the command was executed successfully with open(version_file, "rt") as fp: diff --git a/tests/conftest.py b/tests/conftest.py index 8d6afdbd0035..9a6403b72e15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -61,7 +61,10 @@ def __init__( ).cuda() if tokenizer_name is None: tokenizer_name = model_name - self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) + self.tokenizer = get_tokenizer( + tokenizer_name, + trust_remote_code=True, + ) def generate( self, @@ -178,9 +181,11 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, + prompt_embeds: List[torch.Tensor] = None, ) -> List[Tuple[List[int], str]]: req_outputs = self.model.generate(prompts, - sampling_params=sampling_params) + sampling_params=sampling_params, + prompt_embeds=prompt_embeds) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt @@ -199,9 +204,12 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, + prompt_embeds: List[torch.Tensor] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, greedy_params) + outputs = self.generate(prompts, + greedy_params, + prompt_embeds=prompt_embeds) return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] @@ -210,12 +218,15 @@ def generate_beam_search( prompts: List[str], beam_width: int, max_tokens: int, + prompt_embeds: List[torch.Tensor] = None, ) -> List[Tuple[List[int], str]]: beam_search_params = SamplingParams(n=beam_width, use_beam_search=True, 
temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, beam_search_params) + outputs = self.generate(prompts, + beam_search_params, + prompt_embeds=prompt_embeds) return outputs diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 40858a517b31..9b3d496ac943 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -3,6 +3,7 @@ Run `pytest tests/models/test_models.py --forked`. """ import pytest +from vllm.sampling_params import SamplingParams MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -39,3 +40,44 @@ def test_models( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models_from_prompt_embeds( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + tokenizer = vllm_model.model.llm_engine.tokenizer + input_embeddings = vllm_model.model.llm_engine.workers[ + 0].model_runner.model.get_input_embeddings() + + prompt_embeds = [] + for prompt in example_prompts: + token_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda") + token_embeds = input_embeddings(token_ids) + prompt_embeds.append(token_embeds[0]) + + outputs_from_prompts = vllm_model.model.generate( + example_prompts, + sampling_params=SamplingParams(temperature=0.0, max_tokens=max_tokens), + prompt_embeds=None) + outputs_from_embeds = vllm_model.model.generate( + None, + sampling_params=SamplingParams(temperature=0.0, max_tokens=max_tokens), + prompt_embeds=prompt_embeds, + ) + del vllm_model + + for output_prompt, output_embed in zip(outputs_from_prompts, + outputs_from_embeds): + assert output_prompt.outputs[0].token_ids == output_embed.outputs[ + 0].token_ids, ( + f"output_prompt: {output_prompt.outputs[0].token_ids}\n", + f"output_embed: {output_embed.outputs[0].token_ids}", + ) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index bcd0cd60bfc5..a0acf67c26f2 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -19,24 +19,29 @@ def __init__(self, vocab_size: int, fake_logits: torch.Tensor): self.fake_logits = fake_logits def forward(self, *args, **kwargs): - with patch("vllm.model_executor.layers.sampler._prune_hidden_states", - lambda x, y: x), patch( - "vllm.model_executor.layers.sampler._get_logits", - lambda *args, **kwargs: self.fake_logits): + with patch( + "vllm.model_executor.layers.sampler._prune_hidden_states", + lambda x, y: x, + ), patch( + "vllm.model_executor.layers.sampler._get_logits", + lambda *args, **kwargs: self.fake_logits, + ): return super().forward(*args, **kwargs) def _prepare_test( - batch_size: int + batch_size: int, ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), device="cuda", dtype=torch.float16) - fake_logits = torch.full((batch_size, vocab_size), - 1e-2, - device=input_tensor.device, - dtype=input_tensor.dtype) + fake_logits = torch.full( + (batch_size, vocab_size), + 1e-2, + device=input_tensor.device, + dtype=input_tensor.dtype, + ) sampler = MockLogitsSampler(32000, fake_logits) model_runner = ModelRunner(None, None, None) return input_tensor, fake_logits, sampler, model_runner diff --git 
a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index cbf2978c01c2..9aed2ab87186 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -3,6 +3,7 @@ from functools import partial from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, AsyncIterator) +import torch from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -372,6 +373,7 @@ async def add_request( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, prefix_pos: Optional[int] = None, + prompt_embeds: Optional[torch.Tensor] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -404,7 +406,9 @@ async def add_request( sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, - prefix_pos=prefix_pos) + prefix_pos=prefix_pos, + prompt_embeds=prompt_embeds, + ) return stream @@ -415,6 +419,7 @@ async def generate( request_id: str, prompt_token_ids: Optional[List[int]] = None, prefix_pos: Optional[int] = None, + prompt_embeds: Optional[torch.Tensor] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -492,7 +497,8 @@ async def generate( sampling_params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, - prefix_pos=prefix_pos) + prefix_pos=prefix_pos, + prompt_embeds=prompt_embeds) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 85f520871cd9..8515ef13ebdb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,6 +4,7 @@ import time from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union) +import torch from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -340,6 +341,7 @@ def add_request( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, prefix_pos: Optional[int] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, ) -> None: """Add a request to the engine's request pool. @@ -349,8 +351,8 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + prompt: The prompt string. Can be None if prompt_token_ids + or prompt_embeds are provided. sampling_params: The sampling parameters for text generation. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. @@ -385,17 +387,27 @@ def add_request( >>> SamplingParams(temperature=0.0)) >>> # continue the request processing >>> ... + prompt_embeds: The prompt embeddings. If set, + input prompt and prompt_token_ids are ignored """ if arrival_time is None: arrival_time = time.monotonic() - if prompt_token_ids is None: + + # If prompt_embeds is set, prompt_token_ids is filled with 0 + if prompt_embeds is not None: + prompt_token_ids = [0] * prompt_embeds.size(0) + elif prompt_token_ids is None: assert prompt is not None prompt_token_ids = self.tokenizer.encode(prompt) # Create the sequences. 
block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + seq = Sequence(seq_id, + prompt, + prompt_token_ids, + block_size, + prompt_embeds=prompt_embeds) # Check whether the input specifies prefix prefix = self.scheduler.prefix_pool.add_or_get_prefix( @@ -835,10 +847,17 @@ def _log_system_stats( def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" + + # With prompt embeds, all_input_ids holds only the output token ids + if seq.data.has_prompt_embeds_forwarding(): + all_input_ids = seq.get_output_token_ids() + else: + all_input_ids = seq.get_token_ids() + (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( self.tokenizer, - all_input_ids=seq.get_token_ids(), + all_input_ids=all_input_ids, prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, read_offset=seq.read_offset, diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index f7b8d258fae4..2534b33a1fa0 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -5,6 +5,7 @@ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse import uvicorn +import torch from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -28,12 +29,18 @@ async def generate(request: Request) -> Response: The request should be a JSON object with the following fields: - prompt: the prompt to use for the generation. + - prompt_embeds: the prompt embeddings to use for the generation + instead of the prompt. - stream: whether to stream the results or not. - other fields: the sampling parameters (See `SamplingParams` for details).
""" request_dict = await request.json() prompt = request_dict.pop("prompt") prefix_pos = request_dict.pop("prefix_pos", None) + prompt_embeds = request_dict.pop("prompt_embeds", None) + if prompt_embeds is not None: + prompt_embeds = torch.tensor(prompt_embeds).to("cuda") + prompt = None stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() @@ -41,7 +48,8 @@ async def generate(request: Request) -> Response: results_generator = engine.generate(prompt, sampling_params, request_id, - prefix_pos=prefix_pos) + prefix_pos=prefix_pos, + prompt_embeds=prompt_embeds) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: @@ -67,7 +75,12 @@ async def stream_results() -> AsyncGenerator[bytes, None]: assert final_output is not None prompt = final_output.prompt - text_outputs = [prompt + output.text for output in final_output.outputs] + if prompt: + text_outputs = [ + prompt + output.text for output in final_output.outputs + ] + else: + text_outputs = [output.text for output in final_output.outputs] ret = {"text": text_outputs} return JSONResponse(ret) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b819e233c06b..13e9cfe4268d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,6 +3,8 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +import torch + from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.outputs import RequestOutput @@ -106,7 +108,7 @@ def __init__( self.request_counter = Counter() def get_tokenizer( - self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + self, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.tokenizer def set_tokenizer( @@ -122,6 +124,7 @@ def generate( prompt_token_ids: Optional[List[List[int]]] = None, prefix_pos: Optional[Union[int, List[int]]] = None, use_tqdm: bool = True, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -141,14 +144,19 @@ def generate( This is an experimental feature, and may be replaced with automatic prefix caching in the future. use_tqdm: Whether to use tqdm to display the progress bar. + prompt_embeds: A list of embeddings for the prompts. If set, we + directly pass the embeddings instead of passing + `prompt_token_ids`. Returns: A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") + if (prompts is None and prompt_token_ids is None + and prompt_embeds is None): + raise ValueError( + "Either prompts, prompt_token_ids or prompt_token_embeds " + "must be provided.") if isinstance(prompts, str): # Convert a single prompt to a list. prompts = [prompts] @@ -161,14 +169,24 @@ def generate( sampling_params = SamplingParams() # Add requests to the engine. 
- num_requests = len(prompts) if prompts is not None else len( - prompt_token_ids) + if prompts is not None: + num_requests = len(prompts) + elif prompt_token_ids is not None: + num_requests = len(prompt_token_ids) + elif prompt_embeds is not None: + num_requests = len(prompt_embeds) + for i in range(num_requests): prompt = prompts[i] if prompts is not None else None prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] - self._add_request(prompt, sampling_params, token_ids, prefix_pos_i) + token_ids = (None + if prompt_token_ids is None else prompt_token_ids[i]) + embeds = None if prompt_embeds is None else prompt_embeds[i] + self._add_request(prompt, + sampling_params, + token_ids, + prefix_pos_i, + prompt_embeds=embeds) return self._run_engine(use_tqdm) def _add_request( @@ -177,13 +195,15 @@ def _add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], prefix_pos: Optional[int] = None, + prompt_embeds: Optional[torch.Tensor] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids, - prefix_pos=prefix_pos) + prompt_token_ids=prompt_token_ids, + prefix_pos=prefix_pos, + prompt_embeds=prompt_embeds) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index ef49cc5902ea..cf907421fb8b 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -12,6 +12,7 @@ class InputMetadata: max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. block_tables: The block tables. (Seq id -> list of physical block) + prompt_embeds_indices: Indices of input tokens whose embeddings are replaced by the provided prompt embeddings. """ def __init__( @@ -25,6 +26,7 @@ def __init__( context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], use_cuda_graph: bool, + prompt_embeds_indices: Optional[torch.Tensor], ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens @@ -35,6 +37,7 @@ def __init__( self.context_lens = context_lens self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph + self.prompt_embeds_indices = prompt_embeds_indices # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack.
diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 2f2bd5ffb4a6..11e227078af1 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -29,19 +29,26 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.aquila import AquilaConfig @@ -59,13 +66,17 @@ def __init__( ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -147,10 +158,12 @@ def __init__( base=self.rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + ) def forward( self, @@ -254,8 +267,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = inputs_embeds for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer( @@ -289,9 +310,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -303,11 +325,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -321,7 +345,7 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -340,3 +364,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.model.embed_tokens diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index f08c3c8d257f..07fa78fb7b70 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -28,19 +28,28 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.sampling_metadata import SamplingMetadata -from 
vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.baichuan import BaiChuanConfig @@ -83,13 +92,17 @@ def __init__( ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + linear_method=linear_method, + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -116,8 +129,8 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = (self.total_num_heads // @@ -151,10 +164,12 @@ def __init__( alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + ) else: self.rotary_emb = get_rope( self.head_dim, @@ -185,10 +200,12 @@ def forward( class BaiChuanDecoderLayer(nn.Module): - def __init__(self, - config: BaiChuanConfig, - position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + def __init__( + self, + config: BaiChuanConfig, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -244,10 +261,12 @@ def forward( class BaiChuanModel(nn.Module): - def __init__(self, - config: BaiChuanConfig, - position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + def __init__( + self, + config: BaiChuanConfig, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -269,8 +288,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = inputs_embeds residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -287,10 +314,12 @@ def forward( class BaiChuanBaseForCausalLM(nn.Module): - def __init__(self, - config, - position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + def __init__( + self, + config, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.linear_method = linear_method @@ -304,9 +333,10 @@ def forward( 
positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -318,11 +348,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -344,7 +376,7 @@ def load_weights(self, loaded_weight = torch.nn.functional.normalize( loaded_weight) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -364,6 +396,9 @@ def load_weights(self, default_weight_loader) weight_loader(param, loaded_weight) + def get_input_embeddings(self): + return self.model.embed_tokens + class BaichuanForCausalLM(BaiChuanBaseForCausalLM): """Baichuan 13B and Baichuan2 7B/13B.""" diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 4adfb6b78102..81a8ff88c68d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -36,6 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import replace_prompt_embeds from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -246,9 +247,16 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(hidden_states) + inputs_embeds = self.word_embeddings(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = self.word_embeddings_layernorm(inputs_embeds) for i in range(len(self.h)): layer = self.h[i] hidden_states = layer( @@ -281,9 +289,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -328,3 +337,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.transformer.word_embeddings diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index dca8d724f976..1b6a4a6e3ddd 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -23,6 +23,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import 
SamplingMetadata +from vllm.model_executor.utils import replace_prompt_embeds from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -307,8 +308,15 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: inputs_embeds = self.embedding(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) # Run encoder. hidden_states = self.encoder( @@ -340,9 +348,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -373,3 +382,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.transformer.embedding diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 2b5e022312e3..f92186cb6ce9 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -29,21 +29,30 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, ) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig @@ -61,7 +70,8 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - dtype=torch.float32) + dtype=torch.float32, + ) num_remaining_heads = min(closest_power_of_2, total_num_heads - closest_power_of_2) extra_powers = torch.arange(1, @@ -133,12 +143,13 @@ def __init__( bias=config.bias, skip_bias_add=True, linear_method=linear_method, - reduce_results=self.reduce_row_parallel_results) + reduce_results=self.reduce_row_parallel_results, + ) self.use_rotary = config.rotary self.use_alibi = 
config.alibi - assert not (self.use_rotary and self.use_alibi), ( - "Rotary and alibi are mutually exclusive.") + assert not (self.use_rotary and + self.use_alibi), "Rotary and alibi are mutually exclusive." if self.use_rotary: rope_theta = getattr(config, "rope_theta", 10000) @@ -150,10 +161,12 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + ) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -161,16 +174,20 @@ def __init__( alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * self.inv_norm_factor) alibi_slopes = alibi_slopes[head_start:head_end].tolist() - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes, + ) else: - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + ) def forward( self, @@ -201,11 +218,13 @@ def __init__( super().__init__() hidden_size = config.hidden_size - self.dense_h_to_4h = ColumnParallelLinear(hidden_size, - 4 * hidden_size, - bias=config.bias, - skip_bias_add=True, - linear_method=linear_method) + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + bias=config.bias, + skip_bias_add=True, + linear_method=linear_method, + ) quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) self.reduce_row_parallel_results = not (config.new_decoder_architecture @@ -216,7 +235,8 @@ def __init__( bias=config.bias, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, - linear_method=linear_method) + linear_method=linear_method, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. 
@@ -344,8 +364,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.word_embeddings(input_ids) + inputs_embeds = self.word_embeddings(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = inputs_embeds for i in range(len(self.h)): layer = self.h[i] hidden_states = layer( @@ -381,12 +409,14 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, positions, kv_caches, input_metadata, + prompt_embeds, ) return hidden_states @@ -399,11 +429,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -445,3 +477,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.transformer.word_embeddings diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 661da0fe0434..2eaa0b550558 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -26,18 +26,23 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + VocabParallelEmbedding, ) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -193,8 +198,16 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds @@ -226,9 +239,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: 
InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -240,11 +254,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): @@ -271,3 +287,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.transformer.wte diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index ef4c1d4143c8..45130a9363e2 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -27,18 +27,23 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + VocabParallelEmbedding, ) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -85,10 +90,12 @@ def __init__( bias=True, linear_method=linear_method, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scale=self.scale, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads, + ) def forward( self, @@ -100,7 +107,8 @@ def forward( q, k, v = qkv.split( [ self.hidden_size // self.tensor_model_parallel_world_size, - self.kv_dim, self.kv_dim + self.kv_dim, + self.kv_dim, ], dim=-1, ) @@ -212,8 +220,15 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds @@ -245,9 +260,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: 
torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -259,11 +275,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): @@ -277,3 +295,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.transformer.wte diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5bab30d9d442..95adb4359ed8 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -25,19 +25,26 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -143,7 +150,8 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner + inner_dim = (4 * config.n_embd + if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.attn = GPTJAttention(config, linear_method) self.mlp = GPTJMLP(inner_dim, config, linear_method) @@ -192,8 +200,17 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.wte(input_ids) + inputs_embeds = self.wte(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = inputs_embeds + for i in range(len(self.h)): layer = self.h[i] hidden_states = layer( @@ -231,9 +248,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + 
prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -245,11 +263,13 @@ def sample( sampling_metadata, self.lm_head.bias) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -263,7 +283,7 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "attn.bias" in name or "attn.masked_bias" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -282,3 +302,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.transformer.wte diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 8f7e1063e0c1..fbf9f5199717 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -25,19 +25,26 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -209,8 +216,17 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.embed_in(input_ids) + inputs_embeds = self.embed_in(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = inputs_embeds + for i in range(len(self.layers)): layer = self.layers[i] hidden_states = layer( @@ -246,9 +262,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = 
self.gpt_neox(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -260,11 +277,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): @@ -292,3 +311,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.gpt_neox.embed_in diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 5d0b93793c89..e2c6ad5b407a 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -9,19 +9,26 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -38,13 +45,17 @@ def __init__( ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + linear_method=linear_method, + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -210,8 +221,15 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) + if prompt_embeds is not None: + hidden_states = replace_prompt_embeds( + hidden_states, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -246,9 +264,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -260,11 +279,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -278,7 +299,7 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -297,3 +318,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.model.embed_tokens diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 3791aa893893..449a3b896178 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -31,19 +31,26 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -60,13 +67,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + 
[intermediate_size] * 2, + bias=False, + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + linear_method=linear_method, + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -136,10 +147,12 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + ) def forward( self, @@ -246,8 +259,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = inputs_embeds residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -282,9 +303,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -296,11 +318,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -338,3 +362,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.model.embed_tokens diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 70d033fec69f..1b778bb7c709 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -31,19 +31,26 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from 
vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.utils import replace_prompt_embeds +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -60,13 +67,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + linear_method=linear_method, + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -81,14 +92,16 @@ def forward(self, x): class MistralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -133,11 +146,13 @@ def __init__(self, max_position=max_position, base=self.rope_theta, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + ) def forward( self, @@ -173,7 +188,8 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, linear_method=linear_method, - sliding_window=config.sliding_window) + sliding_window=config.sliding_window, + ) self.mlp = MistralMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, @@ -242,8 +258,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) + if prompt_embeds is not None: + hidden_states = replace_prompt_embeds( + hidden_states, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -278,9 +302,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -292,11 +317,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, 
shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -310,7 +337,7 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -329,3 +356,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.model.embed_tokens diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 22a876e2ef69..9d5605bf3067 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -19,6 +19,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import replace_prompt_embeds from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -231,8 +232,17 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.wte(input_ids) + inputs_embeds = self.wte(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = inputs_embeds + for i in range(len(self.blocks)): block = self.blocks[i] hidden_states = block( @@ -267,9 +277,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -296,3 +307,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.transformer.wte diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 393b2dcabcd5..eccc969743eb 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -37,6 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import replace_prompt_embeds from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -242,8 +243,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + pos_embeds = self.embed_positions(positions) if self.project_in is not None: inputs_embeds, _ = self.project_in(inputs_embeds) @@ -276,8 +285,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> 
torch.Tensor: - return self.decoder(input_ids, positions, kv_caches, input_metadata) + return self.decoder(input_ids, positions, kv_caches, input_metadata, + prompt_embeds) class OPTForCausalLM(nn.Module): @@ -300,9 +311,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -352,3 +364,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index fbc7320fb45a..80d96b63963e 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -24,6 +24,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import replace_prompt_embeds from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -202,8 +203,15 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.wte(input_ids) + if prompt_embeds is not None: + hidden_states = replace_prompt_embeds( + hidden_states, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) residual = None for i in range(len(self.h)): layer = self.h[i] @@ -238,9 +246,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return hidden_states def sample( @@ -286,3 +295,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.transformer.wte diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index 53daa6c4cd93..9add25e31611 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -25,6 +25,7 @@ import torch from torch import nn +from vllm.model_executor.utils import replace_prompt_embeds from vllm.transformers_utils.configs.yi import YiConfig from vllm.model_executor.input_metadata import InputMetadata @@ -241,8 +242,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) + if prompt_embeds is not None: + inputs_embeds = replace_prompt_embeds( + inputs_embeds, + prompt_embeds, + input_metadata.prompt_embeds_indices, + ) + hidden_states = inputs_embeds residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -277,9 +286,10 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, + prompt_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + input_metadata, prompt_embeds) return 
hidden_states def sample( @@ -328,3 +338,6 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def get_input_embeddings(self): + return self.model.embed_tokens diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 336bc1cd005c..dae90c1a104e 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -33,3 +33,13 @@ def set_weight_attrs( assert not hasattr( weight, key), (f"Overwriting existing tensor attribute: {key}") setattr(weight, key, value) + + +def replace_prompt_embeds( + inputs_embeds: torch.Tensor, + prompt_embeds: torch.Tensor, + prompt_embeds_indices: torch.Tensor, +): + inputs_embeds[prompt_embeds_indices] = torch.index_select( + prompt_embeds, 0, prompt_embeds_indices) + return inputs_embeds diff --git a/vllm/sequence.py b/vllm/sequence.py index ca647afce9f1..af5e2a12e2fb 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -3,6 +3,8 @@ import enum from typing import Dict, List, Optional, Union +import torch + from vllm.block import LogicalTokenBlock from vllm.prefix import Prefix from vllm.sampling_params import SamplingParams @@ -57,6 +59,8 @@ class SequenceData: Attributes: prompt_token_ids: The token IDs of the prompt. + prompt_embeds: The embeddings of the prompt + (If set, it takes priority over prompt_token_ids) output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ @@ -64,8 +68,10 @@ class SequenceData: def __init__( self, prompt_token_ids: List[int], + prompt_embeds: Optional[torch.Tensor] = None, ) -> None: self.prompt_token_ids = prompt_token_ids + self.prompt_embeds = prompt_embeds self.output_token_ids: List[int] = [] self.cumulative_logprob = 0.0 @@ -85,14 +91,21 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids + def get_output_token_ids(self) -> List[int]: + return self.output_token_ids + def get_last_token_id(self) -> int: if not self.output_token_ids: return self.prompt_token_ids[-1] return self.output_token_ids[-1] + def has_prompt_embeds_forwarding(self) -> bool: + return self.prompt_embeds is not None + def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_embeds={self.prompt_embeds}, " f"output_token_ids={self.output_token_ids}, " f"cumulative_logprob={self.cumulative_logprob})") @@ -106,20 +119,24 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. 
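# --- Illustrative aside (not part of the patch) ----------------------------
# A minimal, standalone sketch of what the new replace_prompt_embeds helper in
# vllm/model_executor/utils.py is expected to do: rows listed in
# prompt_embeds_indices take their embeddings from the externally supplied
# prompt_embeds, while the remaining rows keep the embeddings looked up from
# token IDs. The (num_seqs, max_prompt_len, hidden_size) shapes and the
# *_sketch name are assumptions chosen for illustration only.
import torch

def replace_prompt_embeds_sketch(inputs_embeds, prompt_embeds, indices):
    # Overwrite the selected rows with the user-provided embeddings.
    inputs_embeds[indices] = torch.index_select(prompt_embeds, 0, indices)
    return inputs_embeds

num_seqs, max_prompt_len, hidden_size = 3, 4, 8
from_token_ids = torch.zeros(num_seqs, max_prompt_len, hidden_size)
external = torch.ones(num_seqs, max_prompt_len, hidden_size)
indices = torch.tensor([1])  # only sequence 1 supplied prompt_embeds

mixed = replace_prompt_embeds_sketch(from_token_ids.clone(), external, indices)
assert torch.all(mixed[1] == 1) and torch.all(mixed[0] == 0)
# ----------------------------------------------------------------------------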
+ prompt_embeds: The embeddings of the prompt + (If set, it takes priority over prompt_token_ids) """ - def __init__( - self, - seq_id: int, - prompt: str, - prompt_token_ids: List[int], - block_size: int, - ) -> None: + def __init__(self, + seq_id: int, + prompt: str, + prompt_token_ids: List[int], + block_size: int, + prompt_embeds: Optional[torch.Tensor] = None) -> None: self.seq_id = seq_id self.prompt = prompt self.block_size = block_size - self.data = SequenceData(prompt_token_ids) + self.data = SequenceData( + prompt_token_ids, + prompt_embeds=prompt_embeds, + ) self.output_logprobs: SampleLogprobs = [] self.output_text = "" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index aa37facb0ff9..c113bbbf8348 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -60,6 +60,22 @@ def __init__( # cache in_wsl result self.in_wsl = in_wsl() + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + + self.max_context_len_to_capture = ( + self.model_config.max_context_len_to_capture + if self.model_config is not None else 0) + # When using CUDA graph, the input block tables must be padded to + # max_context_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables = None # Set after initial profiling. + # cache in_wsl result + self.in_wsl = in_wsl() + def load_model(self) -> None: self.model = get_model(self.model_config) @@ -80,12 +96,15 @@ def _prepare_prompt( input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + prompt_embeds: List[torch.Tensor] = [] + prompt_embeds_indices: List[int] = [] prompt_lens: List[int] = [] context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - for seq_group_metadata in seq_group_metadata_list: + for seq_group_idx, seq_group_metadata in enumerate( + seq_group_metadata_list): assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 @@ -113,6 +132,16 @@ def _prepare_prompt( input_positions.append( list(range(prefix_len, prefix_len + len(prompt_tokens)))) + if seq_data.has_prompt_embeds_forwarding(): + # If prompt_embeds are set, + # the token_ids of the prompt are treated as 0, + # so zero_token_embeds is excluded from prompt_embeds. + prompt_embeds.append(seq_data.prompt_embeds.to("cuda")) + prompt_embeds_indices.append(seq_group_idx) + else: + prompt_embeds.append( + self.zero_token_embeds.repeat(prompt_len, 1)) + if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. 
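# --- Illustrative aside (not part of the patch) ----------------------------
# A standalone sketch of how _prepare_prompt is expected to gather per-sequence
# prompt embeddings: a sequence whose SequenceData carries prompt_embeds
# (has_prompt_embeds_forwarding() is True) contributes its own tensor and its
# group index, while other sequences contribute zero placeholder rows so that
# everything can later be padded and stacked into one batch tensor. The plain
# dicts standing in for SequenceData and the hidden_size value are assumptions
# for illustration only.
import torch

hidden_size = 8
zero_token_embeds = torch.zeros(1, hidden_size)  # placeholder row

seqs = [
    {"prompt_len": 3, "prompt_embeds": None},                         # token-ID prompt
    {"prompt_len": 2, "prompt_embeds": torch.randn(2, hidden_size)},  # embeds supplied
]

prompt_embeds, prompt_embeds_indices = [], []
for seq_group_idx, seq in enumerate(seqs):
    if seq["prompt_embeds"] is not None:      # stands in for has_prompt_embeds_forwarding()
        prompt_embeds.append(seq["prompt_embeds"])
        prompt_embeds_indices.append(seq_group_idx)
    else:
        # Zero rows keep the per-sequence list rectangular-izable later on.
        prompt_embeds.append(zero_token_embeds.repeat(seq["prompt_len"], 1))

print(prompt_embeds_indices)  # -> [1]
# ----------------------------------------------------------------------------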
@@ -176,6 +205,21 @@ def _prepare_prompt( dtype=torch.long, device='cuda') + if prompt_embeds: + padded_prompt_embeds = [ + _pad_embeddings_to_max(embeds, max_prompt_len, + self.zero_token_embeds) + for embeds in prompt_embeds + ] + prompt_embeds = torch.stack(padded_prompt_embeds).to( + dtype=self.model_config.dtype, device="cuda") + else: + prompt_embeds = None + + prompt_embeds_indices = torch.tensor(prompt_embeds_indices, + device="cuda", + dtype=torch.int) + input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, @@ -186,9 +230,10 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + prompt_embeds_indices=prompt_embeds_indices, ) - return (input_tokens, input_positions, input_metadata, prompt_lens, - subquery_lens) + return (input_tokens, input_positions, prompt_embeds, input_metadata, + prompt_lens, subquery_lens) def _prepare_decode( self, @@ -297,8 +342,9 @@ def _prepare_decode( context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, + prompt_embeds_indices=None, ) - return input_tokens, input_positions, input_metadata + return input_tokens, input_positions, None, input_metadata def _prepare_sample( self, @@ -382,10 +428,11 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, input_metadata, prompt_lens, + (input_tokens, input_positions, prompt_embeds, input_metadata, + prompt_lens, subquery_lens) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, input_metadata + (input_tokens, input_positions, prompt_embeds, input_metadata ) = self._prepare_decode(seq_group_metadata_list) subquery_lens = None prompt_lens = [] @@ -434,7 +481,7 @@ def prepare_input_tensors( perform_sampling=False, ) - return input_tokens, input_positions, input_metadata, sampling_metadata + return input_tokens, input_positions, prompt_embeds, input_metadata, sampling_metadata @torch.inference_mode() def execute_model( @@ -442,7 +489,7 @@ def execute_model( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: - input_tokens, input_positions, input_metadata, sampling_metadata = ( + input_tokens, input_positions, prompt_embeds, input_metadata, sampling_metadata = ( self.prepare_input_tensors(seq_group_metadata_list)) # Execute the model. if input_metadata.use_cuda_graph: @@ -455,6 +502,7 @@ def execute_model( positions=input_positions, kv_caches=kv_caches, input_metadata=input_metadata, + prompt_embeds=prompt_embeds, ) # Sample the next token. @@ -644,6 +692,14 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: return x + [pad] * (max_len - len(x)) +def _pad_embeddings_to_max(x: torch.Tensor, max_len: int, + pad: torch.Tensor) -> torch.Tensor: + return torch.cat( + [x, pad.repeat(max_len - x.shape[0], 1)], + dim=0, + ) + + def _make_tensor_with_pad( x: List[List[int]], max_len: int, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7d99c634ded1..188c5bbea7bf 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -79,6 +79,7 @@ def init_model(self) -> None: def load_model(self): self.model_runner.load_model() + self.model_runner.set_zero_token_embeds() @torch.inference_mode() def profile_num_available_blocks(
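# --- Illustrative aside (not part of the patch) ----------------------------
# A standalone sketch of how the new _pad_embeddings_to_max helper in
# vllm/worker/model_runner.py, together with torch.stack, is expected to turn
# variable-length per-sequence embeddings into one
# (num_seqs, max_prompt_len, hidden_size) batch tensor. Tensor sizes and the
# *_sketch name below are assumptions chosen for illustration only.
import torch

def pad_embeddings_to_max_sketch(x, max_len, pad):
    # Append `pad` rows until the sequence reaches max_len rows.
    return torch.cat([x, pad.repeat(max_len - x.shape[0], 1)], dim=0)

hidden_size = 8
pad_row = torch.zeros(1, hidden_size)
per_seq = [torch.randn(3, hidden_size), torch.randn(5, hidden_size)]
max_prompt_len = max(t.shape[0] for t in per_seq)

batch = torch.stack(
    [pad_embeddings_to_max_sketch(t, max_prompt_len, pad_row) for t in per_seq])
print(batch.shape)  # torch.Size([2, 5, 8])
# ----------------------------------------------------------------------------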