From ba8d8d109e02ca321fe8403fff84b995263ab219 Mon Sep 17 00:00:00 2001 From: AzureSilent Date: Sun, 17 Dec 2023 06:31:56 +0000 Subject: [PATCH 1/3] add llava model --- Dockerfile | 8 + requirements-rocm.txt | 1 + requirements.txt | 1 + vllm/__init__.py | 2 + vllm/engine/async_llava_engine.py | 140 +++++++++++++++ vllm/engine/llava_engine.py | 119 +++++++++++++ vllm/entrypoints/llava_llm.py | 207 +++++++++++++++++++++ vllm/entrypoints/llava_server.py | 119 +++++++++++++ vllm/model_executor/model_loader.py | 1 + vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/llama.py | 18 +- vllm/model_executor/models/llava.py | 238 +++++++++++++++++++++++++ vllm/sequence.py | 8 +- vllm/worker/model_runner.py | 49 +++++ vllm/worker/worker.py | 18 +- 15 files changed, 925 insertions(+), 6 deletions(-) create mode 100644 vllm/engine/async_llava_engine.py create mode 100644 vllm/engine/llava_engine.py create mode 100644 vllm/entrypoints/llava_llm.py create mode 100644 vllm/entrypoints/llava_server.py create mode 100644 vllm/model_executor/models/llava.py diff --git a/Dockerfile b/Dockerfile index 13a38f2eba64..68920788fbcf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -71,6 +71,14 @@ COPY vllm vllm EXPOSE 8000 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"] + +FROM vllm-base as vllm-llava + +COPY --from=build /workspace/vllm/*.so /workspace/vllm/ +COPY vllm vllm + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.llava_server"] + # openai api server alternative FROM vllm-base AS vllm-openai # install additional dependencies for openai api server, and mixtral diff --git a/requirements-rocm.txt b/requirements-rocm.txt index c2e0dc3f464f..811867378036 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -15,3 +15,4 @@ fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. aioprometheus[starlette] +pillow # Required for image processing. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 04b19b97babf..10c352f7de16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. aioprometheus[starlette] +pillow # Required for image processing.
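For orientation, here is a minimal usage sketch of the offline entrypoint this series adds (vllm.entrypoints.llava_llm.LLaVA); it is not part of the patch. The checkpoint name, the chat-style prompt template, and the assumption that the tokenizer decodes image token index 32000 to "<image>" are illustrative only.

from PIL import Image
from vllm import LLaVA, SamplingParams

# Hypothetical llava-style checkpoint whose repo ships a CLIP image processor
# alongside its tokenizer, as expected by LLaVAEngine below.
llm = LLaVA(model="llava-hf/llava-1.5-7b-hf")
sampling_params = SamplingParams(temperature=0.2, max_tokens=128)

# One image placeholder token per image; images may be PIL images, URLs,
# local paths, or base64-encoded strings (see llava_llm.generate below).
prompt = "USER: <image>\nWhat is shown in this picture? ASSISTANT:"
outputs = llm.generate(prompts=[prompt],
                       sampling_params=sampling_params,
                       images=[Image.open("example.jpg")])
for output in outputs:
    print(output.outputs[0].text)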
diff --git a/vllm/__init__.py b/vllm/__init__.py index 3121d1169027..4e3d1ce3be6c 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,6 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.engine.ray_utils import initialize_cluster from vllm.entrypoints.llm import LLM +from vllm.entrypoints.llava_llm import LLaVA from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams @@ -12,6 +13,7 @@ __all__ = [ "LLM", + "LLaVA", "SamplingParams", "RequestOutput", "CompletionOutput", diff --git a/vllm/engine/async_llava_engine.py b/vllm/engine/async_llava_engine.py new file mode 100644 index 000000000000..c264baf867b5 --- /dev/null +++ b/vllm/engine/async_llava_engine.py @@ -0,0 +1,140 @@ +from vllm.engine.llava_engine import LLaVAEngine +from vllm.engine.async_llm_engine import AsyncLLMEngine, _AsyncLLMEngine, AsyncStream, AsyncEngineDeadError +import asyncio +import time +from typing import (List, Optional, Type) +from PIL import Image +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams + +logger = init_logger(__name__) + + +class _AsyncLLaVAEngine(LLaVAEngine, _AsyncLLMEngine): + + async def step_async(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + The workers are run asynchronously if possible. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + + This override is needed to pass the runner_method to the model + runner so that it knows it is running a llava model. It won't be needed in + the future once the execute_llava_model function is merged into + execute_model. + """ + seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() + if scheduler_outputs.is_empty(): + return ignored + + # Execute the model. + output = await self._run_workers_async( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + runner_method="execute_llava_model", + ) + + return self._process_model_outputs(output, scheduler_outputs) + ignored + + +class AsyncLLaVAEngine(AsyncLLMEngine): + + _engine_class: Type[_AsyncLLaVAEngine] = _AsyncLLaVAEngine + + async def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + images: Optional[List[Image.Image]] = None) -> AsyncStream: + if self.log_requests: + shortened_prompt = prompt + shortened_token_ids = prompt_token_ids + if self.max_log_len is not None: + if shortened_prompt is not None: + shortened_prompt = shortened_prompt[:self.max_log_len] + if shortened_token_ids is not None: + shortened_token_ids = shortened_token_ids[:self. + max_log_len] + logger.info(f"Received request {request_id}: " + f"prompt: {shortened_prompt!r}, " + f"sampling params: {sampling_params}, " + f"prompt token ids: {shortened_token_ids}."
+ f"images: {0 if images is None else len(images)}") + + if not self.is_running: + if self.start_engine_loop: + self.start_background_loop() + else: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + stream = self._request_tracker.add_request( + request_id, + prompt=prompt, + sampling_params=sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + images=images) + + return stream + + async def generate( + self, + prompt: Optional[str], + sampling_params: SamplingParams, + request_id: str, + prompt_token_ids: Optional[List[int]] = None, + images: Optional[List[Image.Image]] = None) -> RequestOutput: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + images: A list of PIL images for the prompt. It supports multiple + images, although most llava models are trained with only one image. + + Yields: + The output `RequestOutput` objects from the LLMEngine for the + request. + """ + # Preprocess the request. + # This should not be used for logging, as it is monotonic time. + arrival_time = time.monotonic() + + try: + stream = await self.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + images=images) + + async for request_output in stream: + yield request_output + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the + # request. + self._abort(request_id) + raise e diff --git a/vllm/engine/llava_engine.py b/vllm/engine/llava_engine.py new file mode 100644 index 000000000000..8fc9cd00cd4f --- /dev/null +++ b/vllm/engine/llava_engine.py @@ -0,0 +1,119 @@ +from vllm.engine.llm_engine import LLMEngine +from transformers import CLIPImageProcessor +import time +from functools import partial +from typing import List, Optional + +from vllm.engine.ray_utils import ray +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Sequence, SequenceGroup) +from PIL import Image +import numpy as np + + +class LLaVAEngine(LLMEngine): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.image_processor = CLIPImageProcessor.from_pretrained( + self.model_config.tokenizer) + + def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + images: Optional[List[Image.Image]] = None, + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string. Can be None if prompt_token_ids is + provided. 
+ sampling_params: The sampling parameters for text generation. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + images: A list of PIL images for the prompt. It supports multiple + images, although most llava models are trained with only one image. + """ + if arrival_time is None: + arrival_time = time.monotonic() + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = self.tokenizer.encode(prompt) + + if images is not None and len(images) > 0: + pixel_values = self.image_processor( + images, return_tensors="pt")['pixel_values'] + else: + pixel_values = None + + # prepare prompt. expand image token and extract image features + num_workers = len(self.workers) + # random select a worker + worker = self.workers[np.random.randint(num_workers)] + if self.parallel_config.worker_use_ray: + execute_model_methord = partial(worker.execute_method.remote, + 'execute_model_methord') + else: + execute_model_methord = getattr(worker, 'execute_model_methord') + outputs = execute_model_methord('prepare_promt', prompt_token_ids, + pixel_values) + if self.parallel_config.worker_use_ray: + outputs = ray.get(outputs) + processed_token_ids, image_features = outputs + prompt_token_ids = processed_token_ids.tolist() + if image_features is not None: + extra_data = {'image_features': image_features} + else: + extra_data = None + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + seq = Sequence(seq_id, + prompt, + prompt_token_ids, + block_size, + extra_data=extra_data) + + # Create the sequence group. + seq_group = SequenceGroup(request_id, [seq], sampling_params, + arrival_time) + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def step(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() + if scheduler_outputs.is_empty(): + return ignored + + # Execute the model. 
+ output = self._run_workers( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + runner_method="execute_llava_model", + ) + + return self._process_model_outputs(output, scheduler_outputs) diff --git a/vllm/entrypoints/llava_llm.py b/vllm/entrypoints/llava_llm.py new file mode 100644 index 000000000000..ed6faffb3454 --- /dev/null +++ b/vllm/entrypoints/llava_llm.py @@ -0,0 +1,207 @@ +from vllm.engine.llava_engine import LLaVAEngine +from PIL import Image +import requests +import base64 +from io import BytesIO +import numpy as np +from typing import List, Optional, Union + +from tqdm import tqdm +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from vllm.engine.arg_utils import EngineArgs +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.utils import Counter + + +class LLaVA: + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + engine_args = EngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + **kwargs, + ) + self.llm_engine = LLaVAEngine.from_engine_args(engine_args) + self.request_counter = Counter() + + def get_tokenizer( + self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + + def get_image_processor(self): + return self.llm_engine.image_processor + + def set_image_processor(self, image_processor): + self.llm_engine.image_processor = image_processor + + def generate( + self, + prompts: Optional[Union[str, List[str]]] = None, + sampling_params: Optional[SamplingParams] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + images: Optional[Union[Image.Image, List[Image.Image]]] = None, + ) -> List[RequestOutput]: + """Generates the completions for the input prompts. + + NOTE: This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: A list of prompts to generate completions for. + sampling_params: The sampling parameters for text generation. If + None, we use the default sampling parameters. + prompt_token_ids: A list of token IDs for the prompts. If None, we + use the tokenizer to convert the prompts to token IDs. + use_tqdm: Whether to use tqdm to display the progress bar. + + Returns: + A list of `RequestOutput` objects containing the generated + completions in the same order as the input prompts. 
+ """ + if prompts is None and prompt_token_ids is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + if isinstance(prompts, str): + # Convert a single prompt to a list. + prompts = [prompts] + + if (prompts is not None and prompt_token_ids is not None + and len(prompts) != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + # process images + if images is None: + images = [] + elif not isinstance(images, list): + images = [images] + _images = [] + image_id = 0 + for image in images: + if isinstance(image, str): + if image.startswith("http"): + _images.append( + Image.open(requests.get(image, stream=True).raw)) + elif image.startswith("data:"): + _images.append( + Image.open( + BytesIO(base64.b64decode(image.split(",")[1])))) + elif image.startswith("/"): + _images.append(Image.open(image)) + else: + _images.append(Image.open(BytesIO( + base64.b64decode(image)))) + elif isinstance(image, Image.Image): + _images.append(image) + else: + raise ValueError("image must be str or PIL.Image") + + # image_token_index = self.config.image_token_index + image_token_index = 32000 + image_token = self.get_tokenizer().decode(image_token_index) + + # Add requests to the engine. + num_requests = len(prompts) if prompts is not None else len( + prompt_token_ids) + for i in range(num_requests): + prompt = prompts[i] if prompts is not None else None + token_ids = None if prompt_token_ids is None else prompt_token_ids[ + i] + + image_token_num = 0 + if prompt is not None: + image_token_num = prompt.count(image_token) + if token_ids is not None: + _image_token_num = np.sum( + np.asarray(token_ids) == image_token_index) + if image_token_num != _image_token_num: + raise ValueError("image_token_num != _image_token_num") + else: + image_token_num = _image_token_num + if image_token_num > 0: + assert image_id + image_token_num <= len( + _images + ), " The input provided to the model are wrong. The number of image tokens is not equal to the number of images provided." + images = _images[image_id:image_id + image_token_num] + image_id += image_token_num + else: + images = None + + self._add_request(prompt, sampling_params, token_ids, images) + return self._run_engine(use_tqdm) + + def _add_request( + self, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]], + images: Optional[List[Image.Image]] = None, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids, + images=images) + + def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm(total=num_requests, desc="Processed prompts") + # Run the engine. + outputs: List[RequestOutput] = [] + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + pbar.update(1) + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. 
+ outputs = sorted(outputs, key=lambda x: int(x.request_id)) + return outputs diff --git a/vllm/entrypoints/llava_server.py b/vllm/entrypoints/llava_server.py new file mode 100644 index 000000000000..0c353e01d5c3 --- /dev/null +++ b/vllm/entrypoints/llava_server.py @@ -0,0 +1,119 @@ +import argparse +import json +from typing import AsyncGenerator + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +import uvicorn +from PIL import Image +import requests +import base64 +from io import BytesIO + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llava_engine import AsyncLLaVAEngine +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds. +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - images: a list of strings, each string is either a url or a base64 encoded image. + - other fields: the sampling parameters (See `SamplingParams` for details). + + Currently use base64 to send file data. But it is not very efficient. + Due to the limitation of http, it is not easy to send both file and json body in a post request. + There are some other ways to do it, but will need to change the request format: + https://stackoverflow.com/questions/65504438/how-to-add-both-file-and-json-body-in-a-fastapi-post-request + """ + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + images = request_dict.pop("images", None) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + # decode images + if images is None: + images = [] + elif not isinstance(images, list): + images = [images] + _images = [] + for image in images: + if isinstance(image, str): + if image.startswith("http"): + _images.append(Image.open( + requests.get(image, stream=True).raw)) + elif image.startswith("data:"): + _images.append( + Image.open(BytesIO(base64.b64decode(image.split(",")[1])))) + else: + _images.append(Image.open(BytesIO(base64.b64decode(image)))) + if len(_images) == 0: + _images = None + + results_generator = engine.generate(prompt, + sampling_params, + request_id, + images=_images) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + async for request_output in results_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await engine.abort(request_id) + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + prompt = final_output.prompt + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLaVAEngine.from_engine_args(engine_args) + + uvicorn.run(app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index e7bd7548afd2..2b5b558113c0 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -31,6 +31,7 @@ "GPTNeoXForCausalLM": GPTNeoXForCausalLM, "InternLMForCausalLM": InternLMForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, + "LlavaForConditionalGeneration": LlavaForConditionalGeneration, "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* "MistralForCausalLM": MistralForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 28a0aa772d84..3a672ec51ebc 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -17,6 +17,7 @@ from vllm.model_executor.models.qwen import QWenLMHeadModel from vllm.model_executor.models.chatglm import ChatGLMForCausalLM from vllm.model_executor.models.yi import YiForCausalLM +from vllm.model_executor.models.llava import LlavaForConditionalGeneration __all__ = [ "AquilaForCausalLM", @@ -31,6 +32,7 @@ "GPTNeoXForCausalLM", "InternLMForCausalLM", "LlamaForCausalLM", + "LlavaForConditionalGeneration", "MPTForCausalLM", "OPTForCausalLM", "PhiForCausalLM", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index cc83f5dd75f2..22c3cbb2e047 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -251,8 +251,12 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], + inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = inputs_embeds residual = None for i in range(len(self.layers)): cache_event = None if cache_events is None else cache_events[i] @@ -283,6 +287,9 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) + def get_input_embeddings(self): + return self.model.embed_tokens + def forward( self, input_ids: torch.Tensor, @@ -290,9 +297,14 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], + inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) + hidden_states = self.model(input_ids, + positions, + kv_caches, + input_metadata, + cache_events, + inputs_embeds=inputs_embeds) return hidden_states def sample( diff --git a/vllm/model_executor/models/llava.py 
b/vllm/model_executor/models/llava.py new file mode 100644 index 000000000000..ae5c85b235aa --- /dev/null +++ b/vllm/model_executor/models/llava.py @@ -0,0 +1,238 @@ +"""Inference-only LLaVA model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import LlavaConfig, AutoModel +from transformers.activations import ACT2FN + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.sequence import SamplerOutput +from vllm.logger import init_logger +import numpy as np + +logger = init_logger(__name__) +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class LlavaMultiModalProjector(nn.Module): + + def __init__(self, config: LlavaConfig): + super().__init__() + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaForConditionalGeneration(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + + self.config = config + self.linear_method = linear_method + + self.vision_tower = AutoModel.from_config(config.vision_config) + self.language_model = LlamaForCausalLM(config.text_config, + linear_method) + self.multi_modal_projector = LlavaMultiModalProjector(config) + + self.vocab_size = config.vocab_size + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def prepare_promt(self, + input_ids: List[int], + pixel_values: torch.Tensor = None): + input_ids = np.asarray(input_ids) + assert len( + input_ids.shape + ) == 1, f"input_ids should be 1D array, got {input_ids.shape}" + + # Create a mask to know where image tokens are + image_token_mask = input_ids == self.config.image_token_index + non_image_indices = np.where( + input_ids != self.config.image_token_index) + + # check the number of image tokens and images + num_image_tokens = image_token_mask.sum() + num_images = 0 if pixel_values is None else pixel_values.shape[0] + assert num_images == num_image_tokens, f" The input provided to the model are wrong. The number of image tokens ({num_image_tokens}) is not equal to the number of images ({num_images}) provided." + + # expand each image token to image_hidden_dim + if pixel_values is not None: + # get image features + pixel_values = pixel_values.to('cuda') + image_outputs = self.vision_tower(pixel_values, + output_hidden_states=True) + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. 
+ + selected_image_feature = image_outputs.hidden_states[ + self.config.vision_feature_layer] + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + image_features = self.multi_modal_projector( + selected_image_feature).cpu() + nb_images, image_hidden_dim, embed_dim = image_features.shape + + # Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = np.cumsum( + (image_token_mask * (image_hidden_dim - 1) + 1), -1) - 1 + text_to_overwrite = new_token_positions[non_image_indices] + + final_input_ids = np.ones( + (num_images * (image_hidden_dim - 1)) + len(input_ids), + dtype=input_ids.dtype) * self.config.image_token_index + final_input_ids[text_to_overwrite] = input_ids[non_image_indices] + + input_ids = final_input_ids + image_features = image_features.contiguous().reshape(-1, embed_dim) + else: + image_features = None + return input_ids, image_features + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + image_features: Optional[List[torch.Tensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + + if inputs_embeds is None: + # Extra the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + # TODO change the vision_tower to parallel version or pre-compute the image features somewhere else + # currently, if put the vision_tower here will cause duplicated process. + + # repace the embedding of image tokens with the image features. + if image_features is not None and input_ids.shape[1] != 1: + image_token_mask = input_ids == self.config.image_token_index + # image_features is a list of tensor, len(image_features) == batch_size + # each tensor is a concatenate of image features, there shapes are not the same, + # and may be None if the prompt have no image tokens. 
+ # shape: [image_num * image_hidden_dim, embed_dim], image_hidden_dim: feature tokens per image + for i, features in enumerate(image_features): + if features is not None: # the prompt have a image + inputs_embeds[i][image_token_mask[i]] = features.to( + inputs_embeds) + else: + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + pass + + hidden_states = self.language_model(input_ids, + positions, + kv_caches, + input_metadata, + cache_events, + inputs_embeds=inputs_embeds) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + return self.language_model.sample(hidden_states, sampling_metadata) + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + unused_keys = [] + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + + if name.startswith( + "model.language_model"): # load language model weights + name = name[6:] # remove "model." prefix + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if params_dict.get(name, None) is None: + unused_keys.append(name) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + elif name.startswith("model.vision_tower") or name.startswith( + 'model.multi_modal_projector' + ): # load vision model weights + name = name[6:] # remove "model." prefix + if params_dict.get(name, None) is None: + unused_keys.append(name) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + else: + # duplicate keys with out 'model.' prefix + pass + + if len(unused_keys) > 0: + unused_keys.sort() + logger.warning( + f"These keys found in checkpoint but not used in model! {unused_keys}" + ) diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d36eeac0aa0..225634de9d7e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -53,20 +53,24 @@ class SequenceData: Args: prompt_token_ids: The token IDs of the prompt. + extra_data: Extra data for the multimodality models. Attributes: prompt_token_ids: The token IDs of the prompt. output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. + extra_data: Extra data for the multimodality models. 
""" def __init__( self, prompt_token_ids: List[int], + extra_data: Optional[dict] = None, ) -> None: self.prompt_token_ids = prompt_token_ids self.output_token_ids: List[int] = [] self.cumulative_logprob = 0.0 + self.extra_data = extra_data def append_token_id(self, token_id: int, logprob: float) -> None: self.output_token_ids.append(token_id) @@ -105,6 +109,7 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + extra_data: Extra data for the multimodality models. """ def __init__( @@ -113,12 +118,13 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + extra_data: Optional[dict] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt self.block_size = block_size - self.data = SequenceData(prompt_token_ids) + self.data = SequenceData(prompt_token_ids, extra_data=extra_data) self.output_logprobs: SampleLogprobs = [] self.output_text = "" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2209c994e2b8..3678d7ff3fb0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -291,6 +291,55 @@ def execute_model( ) return output + @torch.inference_mode() + def execute_llava_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + cache_events: Optional[List[torch.cuda.Event]] = None, + ) -> SamplerOutput: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + # Prepare input tensors. + is_prompt = seq_group_metadata_list[0].is_prompt + image_features = None + if is_prompt: + inputs = self._prepare_prompt(seq_group_metadata_list) + input_tokens, input_positions, input_metadata = inputs + + image_features = [] + for seq_group_metadata in seq_group_metadata_list: + extra_data = seq_group_metadata.seq_data[list( + seq_group_metadata.seq_data.keys())[0]].extra_data + if extra_data is not None and 'image_features' in extra_data: + image_features.append( + extra_data.get('image_features', None)) + else: + image_features.append(None) + + else: + inputs = self._prepare_decode(seq_group_metadata_list) + input_tokens, input_positions, input_metadata = inputs + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + input_metadata.prompt_lens) + + # Execute the model. + hidden_states = self.model( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + cache_events=cache_events, + image_features=image_features, + ) + + # Sample the next token. + output = self.model.sample( + hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) + return output + @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6f5e16f0011f..245d1fe55249 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -122,6 +122,7 @@ def execute_model( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], + runner_method: str = "execute_model", ) -> SamplerOutput: # Issue cache operations. 
issued_cache_op = False @@ -144,10 +145,23 @@ def execute_model( event.wait() return {} - output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache, cache_events) + output = self.model_runner.__getattribute__(runner_method)( + seq_group_metadata_list, self.gpu_cache, cache_events) return output + @torch.inference_mode() + def execute_model_methord( + self, + model_methord: str, + *args, + **kwargs, + ): + """Directly execute a non-distributed model method on the model runner. Just a temporary hack, + used for the image token handling of the llava model. + """ + return self.model_runner.model.__getattribute__(model_methord)( + *args, **kwargs) + def _init_distributed_environment( parallel_config: ParallelConfig, From c58ca15a04246acd11fda26a289da92b48d12a9a Mon Sep 17 00:00:00 2001 From: AzureSilent Date: Tue, 19 Dec 2023 02:11:48 +0000 Subject: [PATCH 2/3] fix CUDA graph mode --- vllm/engine/llava_engine.py | 13 +-- vllm/model_executor/models/llava.py | 153 ++++++++++++++++++++-------- vllm/sequence.py | 8 +- vllm/worker/model_runner.py | 45 ++++---- 4 files changed, 149 insertions(+), 70 deletions(-) diff --git a/vllm/engine/llava_engine.py b/vllm/engine/llava_engine.py index 032583f96e9c..d201db4cd7b6 100644 --- a/vllm/engine/llava_engine.py +++ b/vllm/engine/llava_engine.py @@ -52,13 +52,17 @@ def add_request( assert prompt is not None prompt_token_ids = self.tokenizer.encode(prompt) + # process images + extra_data = None if images is not None and len(images) > 0: pixel_values = self.image_processor( images, return_tensors="pt")['pixel_values'] + extra_data = {'pixel_values': pixel_values} else: pixel_values = None - # prepare prompt. expand image token and extract image features + # Check the validity of the input and expand each image token to the + # number of tokens per image, so the scheduler can allocate proper resources. num_workers = len(self.workers) # random select a worker worker = self.workers[np.random.randint(num_workers)] @@ -71,12 +75,9 @@ def add_request( pixel_values) if self.parallel_config.worker_use_ray: outputs = ray.get(outputs) - processed_token_ids, image_features = outputs + processed_token_ids = outputs prompt_token_ids = processed_token_ids.tolist() - if image_features is not None: - extra_data = {'image_features': image_features} - else: - extra_data = None + + # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3502303aa9d0..da9dda54b7e4 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -60,12 +60,31 @@ def __init__( self.vocab_size = config.vocab_size self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + patches_per_image = int(config.vision_config.image_size / + config.vision_config.patch_size)**2 + if self.config.vision_feature_select_strategy == "default": + self.tokens_per_image = patches_per_image + elif self.config.vision_feature_select_strategy == "full": + self.tokens_per_image = patches_per_image + 1 + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + def get_input_embeddings(self): return self.language_model.get_input_embeddings() def prepare_promt(self, input_ids: List[int], pixel_values: torch.Tensor = None): + """ + 1. Check the validity of the input. + 2. Expand each image token to the number of tokens per image.
+ So the scheduler can allocate proper resources. + + We do not extract the image features here. + This function deals with only one request/promt. + """ input_ids = np.asarray(input_ids) assert len( input_ids.shape @@ -76,52 +95,86 @@ def prepare_promt(self, non_image_indices = np.where( input_ids != self.config.image_token_index) - # check the number of image tokens and images + # check if the number of image tokens and images are matched num_image_tokens = image_token_mask.sum() num_images = 0 if pixel_values is None else pixel_values.shape[0] assert num_images == num_image_tokens, f" The input provided to the model are wrong. The number of image tokens ({num_image_tokens}) is not equal to the number of images ({num_images}) provided." - # expand each image token to image_hidden_dim - if pixel_values is not None: - # get image features - pixel_values = pixel_values.to('cuda') - image_outputs = self.vision_tower(pixel_values, - output_hidden_states=True) - # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. - - selected_image_feature = image_outputs.hidden_states[ - self.config.vision_feature_layer] - if self.config.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise ValueError( - f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" - ) - image_features = self.multi_modal_projector( - selected_image_feature).cpu() - nb_images, image_hidden_dim, embed_dim = image_features.shape - + # expand each image token to number of tokens per image + if num_images > 0: # Compute the positions where text should be written # Calculate new positions for text tokens in merged image-text sequence. - # `image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `image_token_mask` identifies image tokens. # `torch.cumsum` computes how each image token shifts subsequent text token positions. # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. 
new_token_positions = np.cumsum( - (image_token_mask * (image_hidden_dim - 1) + 1), -1) - 1 + (image_token_mask * (self.tokens_per_image - 1) + 1), -1) - 1 text_to_overwrite = new_token_positions[non_image_indices] final_input_ids = np.ones( - (num_images * (image_hidden_dim - 1)) + len(input_ids), + (num_images * (self.tokens_per_image - 1)) + len(input_ids), dtype=input_ids.dtype) * self.config.image_token_index final_input_ids[text_to_overwrite] = input_ids[non_image_indices] input_ids = final_input_ids - image_features = image_features.contiguous().reshape(-1, embed_dim) else: - image_features = None - return input_ids, image_features + final_input_ids = input_ids + return final_input_ids + + def extract_visual_features( + self, + input_ids: torch.Tensor, + pixel_values: Optional[List[torch.Tensor]] = None, + image_features: Optional[List[torch.Tensor]] = None, + ): + """ + process batched inputs, extract visual features from pixel_values + pixel_values: each element is a tensor of shape [num_images, 3, height, width] + image_features: extracted visual features + """ + if input_ids.shape[1] == 1: + # in the case of generation with cache + return None + _pixel_values = [ + values for values in pixel_values if values is not None + ] + if len(_pixel_values) < 1: + res_image_features = image_features + else: + _pixel_values = torch.cat(_pixel_values, dim=0).to('cuda') + # TODO change the vision_tower to parallel version + image_outputs = self.vision_tower(_pixel_values, + output_hidden_states=True) + selected_image_feature = image_outputs.hidden_states[ + self.config.vision_feature_layer] + if self.config.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + input_image_features = image_features if image_features is not None else [ + None + ] * input_ids.shape[0] + projected_image_features = self.multi_modal_projector( + selected_image_feature) + nb_images, image_hidden_dim, embed_dim = projected_image_features.shape + res_image_features = [] + # flatten the image tokens for each prompt + for i, value in enumerate(pixel_values): + if value is None: + # if the prompt have no pixel_values, use the input image feature + res_image_features.append( + input_image_features[i].to('cuda')) + else: + res_image_features.append( + projected_image_features[:value.shape[0]].contiguous( + ).reshape(-1, embed_dim)) + projected_image_features = projected_image_features[ + value.shape[0]:] + return res_image_features def forward( self, @@ -131,28 +184,38 @@ def forward( input_metadata: InputMetadata, image_features: Optional[List[torch.Tensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: + """ + image_features is a list of tensor, len(image_features) == batch_size + each tensor is a concatenation of image features, there shapes are not the same, + and may be None if the prompt have no image tokens. 
+ shape: [image_num * image_hidden_dim, embed_dim], image_hidden_dim: feature tokens per image + """ if inputs_embeds is None: - # Extra the input embeddings + # Extract the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) - # TODO change the vision_tower to parallel version or pre-compute the image features somewhere else - # currently, if put the vision_tower here will cause duplicated process. - # repace the embedding of image tokens with the image features. - if image_features is not None and input_ids.shape[1] != 1: - image_token_mask = input_ids == self.config.image_token_index - # image_features is a list of tensor, len(image_features) == batch_size - # each tensor is a concatenate of image features, there shapes are not the same, - # and may be None if the prompt have no image tokens. - # shape: [image_num * image_hidden_dim, embed_dim], image_hidden_dim: feature tokens per image - for i, features in enumerate(image_features): - if features is not None: # the prompt have a image - inputs_embeds[i][image_token_mask[i]] = features.to( - inputs_embeds) + if input_ids.shape[1] != 1: + # Extract the image features + if pixel_values is not None: + # TODO Putting the image processing here seems not to impact the CUDA graph, but it consumes + # more memory during the graph_runner trace. + # Moving it outside may change the model_runner too much and would not be graceful. + image_features = self.extract_visual_features( + input_ids, pixel_values, image_features) + # if image_features is None: + # image_features = [] + # print(input_ids.shape, [f if f is None else f.shape for f in image_features]) + if image_features is not None: + image_token_mask = input_ids == self.config.image_token_index + for i, features in enumerate(image_features): + if features is not None: # the prompt has an image + inputs_embeds[i][image_token_mask[ + i]] = features.to(inputs_embeds) else: - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of - # generation with cache + # we are in the case of generation with cache pass hidden_states = self.language_model(input_ids, diff --git a/vllm/sequence.py b/vllm/sequence.py index 225634de9d7e..cae47281b3fa 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -109,7 +109,13 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. - extra_data: Extra data for the multimodality models. + extra_data: Extra data for the multimodality models. This data will be + sent directly to the model.forward. E.g. if three Sequences have + extra_data = {'pix': [1,2] } + extra_data = {'pix': [1], 'text': 'str'} + extra_data = None + this will call model.forward(..., + pix=[[1,2], [1], None], text=[None, 'str', None]) """ def __init__( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fc423a343a6b..fedb0ecfef09 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,5 +1,6 @@ import time from typing import Dict, List, Tuple, Union +from collections import defaultdict import numpy as np import torch @@ -367,21 +368,23 @@ def execute_llava_model( # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors.
- image_features = None + extra_kwargs = {} if is_prompt: inputs = self._prepare_prompt(seq_group_metadata_list) input_tokens, input_positions, input_metadata = inputs - image_features = [] - for seq_group_metadata in seq_group_metadata_list: - extra_data = seq_group_metadata.seq_data[list( - seq_group_metadata.seq_data.keys())[0]].extra_data - if extra_data is not None and 'image_features' in extra_data: - image_features.append( - extra_data.get('image_features', None)) - else: - image_features.append(None) - + # Collect extra data for each prompt from seq_group_metadata_list, e.g. image pixel values, image features + if input_tokens.shape[1] > 1: + extra_kwargs = defaultdict( + lambda: [None for _ in range(input_tokens.shape[0])]) + # Not in the stage of generation with cache + for i, seq_group_metadata in enumerate( + seq_group_metadata_list): + extra_data = seq_group_metadata.seq_data[list( + seq_group_metadata.seq_data.keys())[0]].extra_data + if extra_data is not None: + for key, v in extra_data.items(): + extra_kwargs[key][i] = v else: inputs = self._prepare_decode(seq_group_metadata_list) input_tokens, input_positions, input_metadata = inputs @@ -394,13 +397,11 @@ def execute_llava_model( model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - hidden_states = model_executable( - input_ids=input_tokens, - positions=input_positions, - kv_caches=kv_caches, - input_metadata=input_metadata, - image_features=image_features, - ) + hidden_states = model_executable(input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + **extra_kwargs) # Sample the next token. output = self.model.sample( @@ -546,6 +547,7 @@ def forward( positions: torch.Tensor, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], input_metadata: InputMetadata, + **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. del kv_caches @@ -556,6 +558,13 @@ def forward( self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping) self.input_buffers["context_lens"].copy_(input_metadata.context_lens) self.input_buffers["block_tables"].copy_(input_metadata.block_tables) + for key, value in kwargs.items(): + # Hack: only support values that do not change the graph. + # The image_features are only used to substitute the image-token embeddings and won't change the graph. + if self.input_buffers.get(key, None) is not None: + self.input_buffers[key].copy_(value) + else: + self.input_buffers[key] = value # Run the graph.
self.graph.replay() From a8b0dbcb3e54ccc6474b4845d1fe56f12b7fe3db Mon Sep 17 00:00:00 2001 From: AzureSilent Date: Fri, 29 Dec 2023 13:48:29 +0000 Subject: [PATCH 3/3] bug fix --- vllm/entrypoints/llava_llm.py | 11 +++++------ vllm/model_executor/models/llava.py | 13 +++++++------ vllm/transformers_utils/config.py | 5 +++++ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/llava_llm.py b/vllm/entrypoints/llava_llm.py index ed6faffb3454..b146148d69a0 100644 --- a/vllm/entrypoints/llava_llm.py +++ b/vllm/entrypoints/llava_llm.py @@ -53,6 +53,9 @@ def __init__( self.llm_engine = LLaVAEngine.from_engine_args(engine_args) self.request_counter = Counter() + self.image_token_index = self.llm_engine.model_config.hf_config.image_token_index + self.image_token = self.get_tokenizer().decode(self.image_token_index) + def get_tokenizer( self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.tokenizer @@ -136,10 +139,6 @@ def generate( else: raise ValueError("image must be str or PIL.Image") - # image_token_index = self.config.image_token_index - image_token_index = 32000 - image_token = self.get_tokenizer().decode(image_token_index) - # Add requests to the engine. num_requests = len(prompts) if prompts is not None else len( prompt_token_ids) @@ -150,10 +149,10 @@ def generate( image_token_num = 0 if prompt is not None: - image_token_num = prompt.count(image_token) + image_token_num = prompt.count(self.image_token) if token_ids is not None: _image_token_num = np.sum( - np.asarray(token_ids) == image_token_index) + np.asarray(token_ids) == self.image_token_index) if image_token_num != _image_token_num: raise ValueError("image_token_num != _image_token_num") else: diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index da9dda54b7e4..1298cc52a648 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -250,10 +250,11 @@ def load_weights(self, for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + if name.startswith("model."): + name = name[6:] # remove "model." prefix - if name.startswith( - "model.language_model"): # load language model weights - name = name[6:] # remove "model." prefix + if name.startswith("language_model"): # load language model weights + # name = name[6:] # remove "model." prefix if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name @@ -277,10 +278,10 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - elif name.startswith("model.vision_tower") or name.startswith( - 'model.multi_modal_projector' + elif name.startswith("vision_tower") or name.startswith( + 'multi_modal_projector' ): # load vision model weights - name = name[6:] # remove "model." prefix + # name = name[6:] # remove "model." 
prefix if params_dict.get(name, None) is None: unused_keys.append(name) else: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8b16e559b24f..18378e359318 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -36,4 +36,9 @@ def get_config(model: str, if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) + + if config.model_type == "llava": + config.num_attention_heads = config.text_config.num_attention_heads + config.hidden_size = config.text_config.hidden_size + config.num_hidden_layers = config.text_config.num_hidden_layers return config
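To round out the series, here is a minimal, hypothetical client for the /generate route added in vllm/entrypoints/llava_server.py; it is not part of the patch. It assumes a server already running locally on port 8000 (for example via python -m vllm.entrypoints.llava_server --model <llava checkpoint>); the prompt template is illustrative, and all JSON fields other than prompt, stream, and images are forwarded to SamplingParams.

import base64
import json

import requests

# Encode a local image as base64; the server also accepts http(s) URLs and data: URIs.
with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt": "USER: <image>\nDescribe this picture. ASSISTANT:",
    "images": [image_b64],
    "stream": False,
    # Remaining fields become SamplingParams arguments.
    "temperature": 0.2,
    "max_tokens": 128,
}
response = requests.post("http://localhost:8000/generate", json=payload)
# Non-streaming responses look like {"text": ["<prompt + completion>", ...]}.
print(json.dumps(response.json(), indent=2))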