diff --git a/.buildkite/download-images.sh b/.buildkite/download-images.sh new file mode 100644 index 000000000000..389a12956c3c --- /dev/null +++ b/.buildkite/download-images.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +set -ex +set -o pipefail + +(which wget && which curl) || (apt-get update && apt-get install -y wget curl) + +# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/ +mkdir -p images +cd images +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg +wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg + +cd - diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6d052d0f7f4a..f6781de61af1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -39,9 +39,15 @@ steps: - label: Models Test commands: - - pytest -v -s models --forked + - bash ../.buildkite/download-images.sh + - pytest -v -s models --ignore=models/test_llava.py --forked soft_fail: true +- label: Llava Test + commands: + - bash ../.buildkite/download-images.sh + - pytest -v -s models/test_llava.py + - label: Prefix Caching Test commands: - pytest -v -s prefix_caching diff --git a/examples/llava_example.py b/examples/llava_example.py new file mode 100644 index 000000000000..a455e9858598 --- /dev/null +++ b/examples/llava_example.py @@ -0,0 +1,84 @@ +import argparse +import os +import subprocess + +import torch + +from vllm import LLM +from vllm.sequence import MultiModalData + +# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`. + + +def run_llava_pixel_values(): + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + image_input_type="pixel_values", + image_token_id=32000, + image_input_shape="1,3,336,336", + image_feature_size=576, + ) + + prompt = "" * 576 + ( + "\nUSER: What is the content of this image?\nASSISTANT:") + + # This should be provided by another online or offline component. + images = torch.load("images/stop_sign_pixel_values.pt") + + outputs = llm.generate(prompt, + multi_modal_data=MultiModalData( + type=MultiModalData.Type.IMAGE, data=images)) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +def run_llava_image_features(): + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + image_input_type="image_features", + image_token_id=32000, + image_input_shape="1,576,1024", + image_feature_size=576, + ) + + prompt = "" * 576 + ( + "\nUSER: What is the content of this image?\nASSISTANT:") + + # This should be provided by another online or offline component. 
+ images = torch.load("images/stop_sign_image_features.pt") + + outputs = llm.generate(prompt, + multi_modal_data=MultiModalData( + type=MultiModalData.Type.IMAGE, data=images)) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +def main(args): + if args.type == "pixel_values": + run_llava_pixel_values() + else: + run_llava_image_features() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Demo on Llava") + parser.add_argument("--type", + type=str, + choices=["pixel_values", "image_features"], + default="pixel_values", + help="image input type") + args = parser.parse_args() + # Download from s3 + s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/" + local_directory = "images" + + # Make sure the local directory exists or create it + os.makedirs(local_directory, exist_ok=True) + + # Use AWS CLI to sync the directory + subprocess.check_call( + ["aws", "s3", "sync", s3_bucket_path, local_directory]) + main(args) diff --git a/requirements-dev.txt b/requirements-dev.txt index 72525d7c1228..78a239bc31e0 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -24,6 +24,10 @@ openai requests ray peft +awscli # Benchmarking aiohttp + +# Multimodal +pillow diff --git a/tests/conftest.py b/tests/conftest.py index 40a25ba01269..3409f87349eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,16 +3,39 @@ import pytest import torch -from transformers import AutoModelForCausalLM +from PIL import Image +from transformers import (AutoModelForCausalLM, AutoProcessor, + LlavaForConditionalGeneration) from vllm import LLM, SamplingParams -from vllm.config import TokenizerPoolConfig +from vllm.config import TokenizerPoolConfig, VisionLanguageConfig +from vllm.sequence import MultiModalData from vllm.transformers_utils.tokenizer import get_tokenizer _TEST_DIR = os.path.dirname(__file__) _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] +# Multi modal related +_PIXEL_VALUES_FILES = [ + os.path.join(_TEST_DIR, "images", filename) for filename in + ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"] +] +_IMAGE_FEATURES_FILES = [ + os.path.join(_TEST_DIR, "images", filename) for filename in + ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"] +] +_IMAGE_FILES = [ + os.path.join(_TEST_DIR, "images", filename) + for filename in ["stop_sign.jpg", "cherry_blossom.jpg"] +] +_IMAGE_PROMPTS = [ + "\nUSER: What's the content of the image?\nASSISTANT:", + "\nUSER: What is the season?\nASSISTANT:" +] +assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len( + _IMAGE_FILES) == len(_IMAGE_PROMPTS) + def _read_prompts(filename: str) -> List[str]: with open(filename, "r") as f: @@ -20,6 +43,39 @@ def _read_prompts(filename: str) -> List[str]: return prompts +@pytest.fixture(scope="session") +def hf_image_prompts() -> List[str]: + return _IMAGE_PROMPTS + + +@pytest.fixture(scope="session") +def hf_images() -> List[Image.Image]: + return [Image.open(filename) for filename in _IMAGE_FILES] + + +@pytest.fixture() +def vllm_images(request) -> "torch.Tensor": + vision_language_config = request.getfixturevalue("model_and_config")[1] + all_images = [] + if vision_language_config.image_input_type == ( + VisionLanguageConfig.ImageInputType.IMAGE_FEATURES): + filenames = _IMAGE_FEATURES_FILES + else: + filenames = _PIXEL_VALUES_FILES + for filename in filenames: + all_images.append(torch.load(filename)) + return 
torch.concat(all_images, dim=0) + + +@pytest.fixture() +def vllm_image_prompts(request) -> List[str]: + vision_language_config = request.getfixturevalue("model_and_config")[1] + return [ + "" * (vision_language_config.image_feature_size - 1) + p + for p in _IMAGE_PROMPTS + ] + + @pytest.fixture def example_prompts() -> List[str]: prompts = [] @@ -42,6 +98,10 @@ def example_long_prompts() -> List[str]: "float": torch.float, } +_VISION_LANGUAGE_MODELS = { + "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, +} + class HfRunner: @@ -53,11 +113,24 @@ def __init__( ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() + self.model_name = model_name + if model_name not in _VISION_LANGUAGE_MODELS: + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ).cuda() + self.processor = None + else: + self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ).cuda() + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + ) if tokenizer_name is None: tokenizer_name = model_name self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) @@ -65,13 +138,28 @@ def __init__( def generate( self, prompts: List[str], + images: Optional[List[Image.Image]] = None, **kwargs, ) -> List[Tuple[List[int], str]]: outputs: List[Tuple[List[int], str]] = [] - for prompt in prompts: - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + if images: + assert len(prompts) == len(images) + for i, prompt in enumerate(prompts): + if self.model_name not in _VISION_LANGUAGE_MODELS: + input_ids = self.tokenizer(prompt, + return_tensors="pt").input_ids + inputs = {"input_ids": input_ids.cuda()} + else: + image = images[i] if images else None + inputs = self.processor(text=prompt, + images=image, + return_tensors="pt") + inputs = { + key: value.cuda() if value is not None else None + for key, value in inputs.items() + } output_ids = self.model.generate( - input_ids.cuda(), + **inputs, use_cache=True, **kwargs, ) @@ -88,10 +176,12 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, + images: Optional["torch.Tensor"] = None, ) -> List[Tuple[List[int], str]]: outputs = self.generate(prompts, do_sample=False, - max_new_tokens=max_tokens) + max_new_tokens=max_tokens, + images=images) for i in range(len(outputs)): output_ids, output_str = outputs[i] outputs[i] = (output_ids[0], output_str[0]) @@ -183,9 +273,16 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, + images: Optional["torch.Tensor"] = None, ) -> List[Tuple[List[int], str]]: - req_outputs = self.model.generate(prompts, - sampling_params=sampling_params) + if images is not None: + assert len(prompts) == images.shape[0] + req_outputs = self.model.generate( + prompts, + sampling_params=sampling_params, + multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE, + data=images) + if images is not None else None) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt @@ -222,9 +319,10 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, + images: Optional[torch.Tensor] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, 
greedy_params) + outputs = self.generate(prompts, greedy_params, images=images) return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py new file mode 100644 index 000000000000..b36d37c2e041 --- /dev/null +++ b/tests/models/test_llava.py @@ -0,0 +1,110 @@ +import gc +from dataclasses import fields +from enum import Enum +from typing import Dict, List, Tuple + +import pytest +import torch +from transformers import AutoTokenizer + +from vllm.config import VisionLanguageConfig + +model_and_vl_config = [ + ("llava-hf/llava-1.5-7b-hf", + VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_feature_size=576, + image_token_id=32000, + image_input_shape=(1, 3, 336, 336))), + ("llava-hf/llava-1.5-7b-hf", + VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, + image_feature_size=576, + image_token_id=32000, + image_input_shape=(1, 576, 1024))) +] + + +def as_dict(vision_language_config: VisionLanguageConfig) -> Dict: + """Flatten vision language config to pure args. + + Compatible with what llm entrypoint expects. + """ + result = {} + for field in fields(vision_language_config): + value = getattr(vision_language_config, field.name) + if isinstance(value, Enum): + result[field.name] = value.name.lower() + elif isinstance(value, tuple): + result[field.name] = ",".join([str(item) for item in value]) + else: + result[field.name] = value + return result + + +def sanitize_vllm_output(vllm_output: Tuple[List[int], str], + vision_language_config: VisionLanguageConfig, + model_id: str): + """Sanitize vllm output to be comparable with hf output. + The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, + x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... + It also reduces `output_str` from "bla" to "bla". + """ + tokenizer = AutoTokenizer.from_pretrained(model_id) + image_token_str = tokenizer.decode(vision_language_config.image_token_id) + image_token_str_len = len(image_token_str) + input_ids, output_str = vllm_output + sanitized_input_ids = input_ids[0:2] + input_ids[2 + vision_language_config + .image_feature_size - 1:] + sanitzied_output_str = output_str[vision_language_config. + image_feature_size * + image_token_str_len:] + return sanitized_input_ids, sanitzied_output_str + + +@pytest.mark.parametrize("worker_use_ray", [False]) +@pytest.mark.parametrize("model_and_config", model_and_vl_config) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, + vllm_image_prompts, vllm_images, model_and_config: tuple, + dtype: str, max_tokens: int, worker_use_ray: bool) -> None: + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the raw images as input. + For vllm runner, we provide image tensors and corresponding + vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. 
+ """ + model_id, vision_language_config = model_and_config + hf_model = hf_runner(model_id, dtype=dtype) + hf_outputs = hf_model.generate_greedy(hf_image_prompts, + max_tokens, + images=hf_images) + del hf_model + + gc.collect() + torch.cuda.empty_cache() + + vllm_model = vllm_runner(model_id, + dtype=dtype, + worker_use_ray=worker_use_ray, + **as_dict(vision_language_config)) + vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, + max_tokens, + images=vllm_images) + del vllm_model + + gc.collect() + torch.cuda.empty_cache() + + for i in range(len(hf_image_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = sanitize_vllm_output( + vllm_outputs[i], vision_language_config, model_id) + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 81189e25d4f1..f816cf388b1c 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -25,7 +25,7 @@ @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) def test_models( hf_runner, diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index e6756195694c..0cd9a4b1d581 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -109,7 +109,7 @@ def create_worker(cls: type, ) (model_config, cache_config, parallel_config, scheduler_config, - device_config, _) = engine_args.create_engine_configs() + device_config, _, _) = engine_args.create_engine_configs() distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 12e3c8eff2ce..930ecad3f175 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -35,7 +35,7 @@ def test_prepare_prompt(batch_size): prompt_len - 1) selected_token_start_idx += prompt_len (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _) = (model_runner._prepare_prompt(seq_group_metadata_list)) + _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens # Verify input metadata is correct for prompts. 
diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index bc86fb574e02..0bbf85f59075 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -11,7 +11,7 @@ def test_swap() -> None: dtype="half", load_format="dummy") (model_config, cache_config, parallel_config, scheduler_config, - device_config, _) = engine_args.create_engine_configs() + device_config, _, _) = engine_args.create_engine_configs() cache_config.num_gpu_blocks = 100 cache_config.num_cpu_blocks = 100 diff --git a/vllm/config.py b/vllm/config.py index 6070d9d9e50f..3ef9497eb032 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,3 +1,4 @@ +import enum import json import os from dataclasses import dataclass @@ -8,7 +9,7 @@ from transformers import PretrainedConfig from vllm.logger import init_logger -from vllm.transformers_utils.config import get_config +from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron if TYPE_CHECKING: @@ -118,8 +119,9 @@ def __init__( self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision) - self.dtype = _get_and_verify_dtype(self.hf_config, dtype) - self.max_model_len = _get_and_verify_max_len(self.hf_config, + self.hf_text_config = get_hf_text_config(self.hf_config) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.max_model_len = _get_and_verify_max_len(self.hf_text_config, max_model_len) self._verify_load_format() self._verify_tokenizer_mode() @@ -218,7 +220,7 @@ def verify_with_parallel_config( self, parallel_config: "ParallelConfig", ) -> None: - total_num_attention_heads = self.hf_config.num_attention_heads + total_num_attention_heads = self.hf_text_config.num_attention_heads tensor_parallel_size = parallel_config.tensor_parallel_size if total_num_attention_heads % tensor_parallel_size != 0: raise ValueError( @@ -226,7 +228,7 @@ def verify_with_parallel_config( " must be divisible by tensor parallel size " f"({tensor_parallel_size}).") - total_num_hidden_layers = self.hf_config.num_hidden_layers + total_num_hidden_layers = self.hf_text_config.num_hidden_layers pipeline_parallel_size = parallel_config.pipeline_parallel_size if total_num_hidden_layers % pipeline_parallel_size != 0: raise ValueError( @@ -241,22 +243,23 @@ def get_sliding_window(self) -> Optional[int]: # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in # addition to sliding window size. We check if that field is present # and if it's False, return None. - if (hasattr(self.hf_config, "use_sliding_window") - and not self.hf_config.use_sliding_window): + if (hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window): return None - return getattr(self.hf_config, "sliding_window", None) + return getattr(self.hf_text_config, "sliding_window", None) def get_vocab_size(self) -> int: - return self.hf_config.vocab_size + return self.hf_text_config.vocab_size def get_hidden_size(self) -> int: - return self.hf_config.hidden_size + return self.hf_text_config.hidden_size def get_head_size(self) -> int: - if hasattr(self.hf_config, "head_dim"): - return self.hf_config.head_dim + if hasattr(self.hf_text_config, "head_dim"): + return self.hf_text_config.head_dim # FIXME(woosuk): This may not be true for all models. 
- return self.hf_config.hidden_size // self.hf_config.num_attention_heads + return (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads) def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" @@ -268,7 +271,7 @@ def get_total_num_kv_heads(self) -> int: new_decoder_arch_falcon = ( self.hf_config.model_type in falcon_model_types and getattr(self.hf_config, "new_decoder_architecture", False)) - if not new_decoder_arch_falcon and getattr(self.hf_config, + if not new_decoder_arch_falcon and getattr(self.hf_text_config, "multi_query", False): # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. @@ -284,13 +287,13 @@ def get_total_num_kv_heads(self) -> int: "multi_query_group_num", ] for attr in attributes: - num_kv_heads = getattr(self.hf_config, attr, None) + num_kv_heads = getattr(self.hf_text_config, attr, None) if num_kv_heads is not None: return num_kv_heads # For non-grouped-query attention models, the number of KV heads is # equal to the number of attention heads. - return self.hf_config.num_attention_heads + return self.hf_text_config.num_attention_heads def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: """Returns the number of KV heads per GPU.""" @@ -303,7 +306,7 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: total_num_kv_heads // parallel_config.tensor_parallel_size) def get_num_layers(self, parallel_config: "ParallelConfig") -> int: - total_num_hidden_layers = self.hf_config.num_hidden_layers + total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size @@ -627,6 +630,48 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): "LoRA is enabled.") +@dataclass +class VisionLanguageConfig: + """Configs the input data format and how models should run for + vision language models.""" + + class ImageInputType(enum.Enum): + """Image input type into the vision language model. + + An image roughly goes through the following transformation: + Raw image --> pixel values --> image features --> image embeddings. + + The difference between different image input types is where the + image encoder (pixel values --> image features) is run. + Different image input types also correspond to different tensor shapes. + + For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336). + IMAGE_FEATURES: (1, 576, 1024). + """ + PIXEL_VALUES = enum.auto() + IMAGE_FEATURES = enum.auto() + + image_input_type: ImageInputType + # The input id corresponding to image token. + image_token_id: int + # Used for running `run_prefill_max_token`. + # For models that support varying resolution, this corresponds to + # worst case scenario (biggest supported resolution). + image_input_shape: tuple + image_feature_size: int + + @classmethod + def get_image_input_enum_type( + cls, value: str) -> "VisionLanguageConfig.ImageInputType": + """Get the image input type from a string.""" + try: + return cls.ImageInputType[value.upper()] + except KeyError as e: + raise ValueError(f"{value} is not a valid choice. 
" + f"Expecting to choose from " + f"{[x.name for x in cls.ImageInputType]}.") from e + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 72b4cf043e90..899816b6a3fa 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -388,6 +388,12 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: computed_block_nums=self.block_manager. get_common_computed_block_ids(seq_group), state=seq_group.state, + # `multi_modal_data` will only be present for the 1st comm + # between engine and worker. + # the subsequent comms can still use delta, but + # `multi_modal_data` will be None. + multi_modal_data=seq_group.multi_modal_data + if scheduler_outputs.prompt_run else None, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fc6665dbe64b..6dcd60a1185c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,9 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, TokenizerPoolConfig) + ParallelConfig, SchedulerConfig, TokenizerPoolConfig, + VisionLanguageConfig) +from vllm.utils import str_to_int_tuple @dataclass @@ -50,6 +52,11 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False + # Related to Vision-language models such as llava + image_input_type: Optional[str] = None + image_token_id: Optional[int] = None + image_input_shape: Optional[str] = None + image_feature_size: Optional[int] = None scheduler_delay_factor: float = 0.0 def __post_init__(self): @@ -305,6 +312,31 @@ def add_cli_args( default=EngineArgs.device, choices=["auto", "cuda", "neuron"], help='Device type for vLLM execution.') + # Related to Vision-language models such as llava + parser.add_argument( + '--image-input-type', + type=str, + default=None, + choices=[ + t.name.lower() for t in VisionLanguageConfig.ImageInputType + ], + help=('The image input type passed into vLLM. ' + 'Should be one of "pixel_values" or "image_features".')) + parser.add_argument('--image-token-id', + type=int, + default=None, + help=('Input id for image token.')) + parser.add_argument( + '--image-input-shape', + type=str, + default=None, + help=('The biggest image input shape (worst for memory footprint) ' + 'given an input type. 
Only used for vLLM\'s profile_run.')) + parser.add_argument( + '--image-feature-size', + type=int, + default=None, + help=('The image feature size along the context dimension.')) parser.add_argument( '--scheduler-delay-factor', type=float, @@ -324,7 +356,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, - DeviceConfig, Optional[LoRAConfig]]: + DeviceConfig, Optional[LoRAConfig], + Optional[VisionLanguageConfig]]: device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, @@ -358,8 +391,25 @@ def create_engine_configs( lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + + if self.image_input_type: + if (not self.image_token_id or not self.image_input_shape + or not self.image_feature_size): + raise ValueError( + 'Specify `image_token_id`, `image_input_shape` and ' + '`image_feature_size` together with `image_input_type`.') + vision_language_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig. + get_image_input_enum_type(self.image_input_type), + image_token_id=self.image_token_id, + image_input_shape=str_to_int_tuple(self.image_input_shape), + image_feature_size=self.image_feature_size, + ) + else: + vision_language_config = None + return (model_config, cache_config, parallel_config, scheduler_config, - device_config, lora_config) + device_config, lora_config, vision_language_config) @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d642915aee19..9da228bb3a26 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -15,6 +15,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.sequence import MultiModalData logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = int( @@ -240,6 +241,7 @@ async def add_request_async( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -252,14 +254,13 @@ async def add_request_async( prompt_token_ids=prompt_token_ids, lora_request=lora_request) - return self.add_request( - request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - arrival_time=arrival_time, - lora_request=lora_request, - ) + return self.add_request(request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + multi_modal_data=multi_modal_data) async def check_health_async(self) -> None: self.model_executor.check_health() @@ -486,6 +487,7 @@ async def add_request( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -534,7 +536,9 @@ async def add_request( sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, - lora_request=lora_request) + lora_request=lora_request, + 
multi_modal_data=multi_modal_data, + ) return stream @@ -545,6 +549,7 @@ async def generate( request_id: str, prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -560,6 +565,7 @@ async def generate( prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. + multi_modal_data: Multi modal data per request. Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -619,6 +625,7 @@ async def generate( prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, + multi_modal_data=multi_modal_data, ) async for request_output in stream: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1984b94024a1..144829739f68 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,7 +5,7 @@ import vllm from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -15,8 +15,9 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupOutput, SequenceOutput, + SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -62,6 +63,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional["VisionLanguageConfig"], executor_class: Type[ExecutorBase], log_stats: bool, ) -> None: @@ -90,6 +92,7 @@ def __init__( self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config + self.vision_language_config = vision_language_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config @@ -102,7 +105,8 @@ def __init__( self.model_executor = executor_class(model_config, cache_config, parallel_config, scheduler_config, - device_config, lora_config) + device_config, lora_config, + vision_language_config) # Ping the tokenizer to ensure liveness if it runs in a # different process. @@ -170,7 +174,6 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( self.parallel_config.tokenizer_pool_config, **init_kwargs) @@ -212,6 +215,7 @@ def add_request( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> None: """Add a request to the engine's request pool. @@ -228,6 +232,7 @@ def add_request( use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. 
If None, we use the current monotonic time. + multi_modal_data: Multi modal data per request. Details: - Set arrival_time to the current time if it is None. @@ -288,7 +293,7 @@ def add_request( # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request) + arrival_time, lora_request, multi_modal_data) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index db223d809ea0..504f5e201c94 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ from typing import List, Optional, Union +import torch from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -8,6 +9,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.sequence import MultiModalData from vllm.utils import Counter @@ -126,6 +128,7 @@ def generate( prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -141,6 +144,7 @@ def generate( use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. + multi_modal_data: Multi modal data. Returns: A list of `RequestOutput` objects containing the generated @@ -160,6 +164,9 @@ def generate( # Use default sampling params. sampling_params = SamplingParams() + if multi_modal_data: + multi_modal_data.data = multi_modal_data.data.to(torch.float16) + # Add requests to the engine. num_requests = len(prompts) if prompts is not None else len( prompt_token_ids) @@ -167,10 +174,17 @@ def generate( prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request(prompt, - sampling_params, - token_ids, - lora_request=lora_request) + self._add_request( + prompt, + sampling_params, + token_ids, + lora_request=lora_request, + # Get ith image while maintaining the batch dim. + multi_modal_data=MultiModalData( + type=multi_modal_data.type, + data=multi_modal_data.data[i].unsqueeze(0)) + if multi_modal_data else None, + ) return self._run_engine(use_tqdm) def _add_request( @@ -179,13 +193,15 @@ def _add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, prompt, sampling_params, prompt_token_ids, - lora_request=lora_request) + lora_request=lora_request, + multi_modal_data=multi_modal_data) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
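
To illustrate the LLM.generate() change above, a minimal usage sketch (not part of the patch), assuming the tensors downloaded by .buildkite/download-images.sh and assuming "<image>" is the placeholder string that decodes to token id 32000: generate() slices row i out of the stacked tensor and re-wraps it as a single-image MultiModalData per request.

    import torch
    from vllm import LLM
    from vllm.sequence import MultiModalData

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    prompts = [
        "<image>" * 576 + "\nUSER: What's the content of the image?\nASSISTANT:",
        "<image>" * 576 + "\nUSER: What is the season?\nASSISTANT:",
    ]

    # One row per prompt, shape (2, 3, 336, 336); generate() takes row i and
    # unsqueezes it back to a batch of one before handing it to the engine.
    pixel_values = torch.cat([
        torch.load("images/stop_sign_pixel_values.pt"),
        torch.load("images/cherry_blossom_pixel_values.pt"),
    ], dim=0)

    outputs = llm.generate(
        prompts,
        multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
                                        data=pixel_values),
    )
    for o in outputs:
        print(o.outputs[0].text)
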
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index cc6c12edcffe..55180d6110b6 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -24,6 +24,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], ) -> None: raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index a48f0ac7f0e5..90c388244176 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.utils import check_block_size_valid from vllm.logger import init_logger @@ -23,6 +23,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -30,6 +31,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.vision_language_config = vision_language_config # Instantiate the worker and load the model to GPU. 
self._init_worker() @@ -56,6 +58,7 @@ def _init_worker(self): rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, + vision_language_config=self.vision_language_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index e39b881f7cab..f2fc8aec9887 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.utils import check_block_size_valid @@ -40,6 +40,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -47,6 +48,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.vision_language_config = vision_language_config assert self.parallel_config.worker_use_ray placement_group = self.parallel_config.placement_group @@ -181,6 +183,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", driver_rank, distributed_init_method, lora_config=self.lora_config, + vision_language_config=self.vision_language_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=True, ) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index cb64d80c8147..958c9b97f725 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -7,9 +7,14 @@ from vllm.config import DeviceConfig, ModelConfig from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.llava import LlavaForConditionalGeneration from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) +_VISION_MODEL_CLASSES = [ + LlavaForConditionalGeneration, +] + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): @@ -40,6 +45,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: lora_config = kwargs.get("lora_config", None) + vision_language_config = kwargs.get("vision_language_config", None) model_class = _get_model_architecture(model_config) # Get the (maybe quantized) linear method. @@ -76,7 +82,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, "be added in the future. If this is important to you, " "please open an issue on github.") else: - model = model_class(model_config.hf_config, linear_method) + if model_class not in _VISION_MODEL_CLASSES: + model = model_class(model_config.hf_config, linear_method) + else: + model = model_class(model_config.hf_config, + vision_language_config, linear_method) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
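
get_model above forwards vision_language_config into the vision model constructors; for context, a hedged sketch (not part of the patch) of how create_engine_configs earlier in this diff builds that object from the new CLI-style string arguments, via the str_to_int_tuple helper added in vllm/utils.py.

    from vllm.config import VisionLanguageConfig
    from vllm.utils import str_to_int_tuple

    # Equivalent of passing --image-input-type, --image-token-id,
    # --image-input-shape and --image-feature-size on the command line.
    vision_language_config = VisionLanguageConfig(
        image_input_type=VisionLanguageConfig.get_image_input_enum_type(
            "pixel_values"),
        image_token_id=32000,
        image_input_shape=str_to_int_tuple("1,3,336,336"),
        image_feature_size=576,
    )
    assert vision_language_config.image_input_shape == (1, 3, 336, 336)
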
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index efadb1c504ca..d561316886ca 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -29,6 +29,8 @@ "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + "LlavaForConditionalGeneration": + ("llava", "LlavaForConditionalGeneration"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2cd56f0ce59d..57857deb9eb8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -250,14 +250,21 @@ def __init__( ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, - input_ids: torch.Tensor, + input_ids: Optional[torch.Tensor], positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py new file mode 100644 index 000000000000..c824efdf0468 --- /dev/null +++ b/vllm/model_executor/models/llava.py @@ -0,0 +1,246 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn +# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on +# transformers' impl. +from transformers import CLIPVisionModel, LlavaConfig + +from vllm.attention import AttentionMetadata +from vllm.config import VisionLanguageConfig +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + +_KEYS_TO_MODIFY_MAPPING = { + "language_model.lm_head": "lm_head", + "language_model.model": "language_model", +} + + +# TODO(xwjiang): Run benchmark and decide if TP. 
+class LlavaMultiModalProjector(nn.Module): + + def __init__(self, vision_hidden_size: int, text_hidden_size: int, + projector_hidden_act: str): + super().__init__() + + self.linear_1 = nn.Linear(vision_hidden_size, + text_hidden_size, + bias=True) + self.act = get_act_fn(projector_hidden_act) + self.linear_2 = nn.Linear(text_hidden_size, + text_hidden_size, + bias=True) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +def _merge_vision_embeddings(input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + vision_embeddings: torch.Tensor, + image_token_id: int): + """In place merges in vision_embeddings with inputs_embeds.""" + mask = (input_ids == image_token_id) + inputs_embeds[mask] = vision_embeddings.view(-1, + vision_embeddings.shape[-1]) + + +class LlavaForConditionalGeneration(nn.Module): + + def __init__(self, + config: "LlavaConfig", + vision_language_config: VisionLanguageConfig, + linear_method: Optional["LinearMethodBase"] = None) -> None: + super().__init__() + self.config = config + + self.vision_language_config = vision_language_config + + assert self.vision_language_config, ( + "Provide `image_input_type` and other vision " + "related configurations through LLM entrypoint " + "or engine arguments.") + + if self.vision_language_config.image_input_type == ( + VisionLanguageConfig.ImageInputType.PIXEL_VALUES): + self.vision_tower = CLIPVisionModel(config.vision_config) + else: + self.vision_tower = None + + self.multi_modal_projector = LlavaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act) + + self.linear_method = linear_method + self.language_model = LlamaModel(config.text_config, linear_method) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.language_model.org_vocab_size) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + image_input: Optional[torch.Tensor] = None + ) -> SamplerOutput: # noqa: E501 + """Run forward pass for Llava 1.5. + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted image embeddings. + Concretely, consider a text prompt: + "\nUSER: What's the content of the image?\nASSISTANT:". + Tokenizer outputs: + [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278, + 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]. + The to-be-inserted image has a size of 576 (24 * 24) along the context + length dimension. + `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901, + 1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, + 9047, 13566, 29901]. + There will be 576 `32000` in the `input_ids`. + (32000 is the token id for ``.) + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + The model takes two types of image inputs: + PIXEL_VALUES and IMAGE_FEATURES. + The following shows how each maps to huggingface implementation. 
+ PIXEL_VALUES: + - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353 + IMAGE_FEATURES: + - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430 + before going through the multi modal projector. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + image_input: A batch of image inputs. + For PIXEL_VALUES, expecting [1, 3, 336, 336]. + For IMAGE_FEATURES, expecting [1, 576, 1024]. + """ + if image_input is not None: + if list(image_input.shape[1:]) != list( + self.vision_language_config.image_input_shape[1:]): + raise ValueError( + f"The expected image tensor shape is batch dimension " + f"plus " + f"{self.vision_language_config.image_input_shape[1:]}." + f" You supplied {image_input.shape}. " + f"If you are using vLLM's entrypoint, make sure your " + f"supplied image input is consistent with " + f"image_input_shape in engine args.") + if self.vision_tower is not None: + # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. + image_outputs = self.vision_tower(image_input, + output_hidden_states=True) + image_features = image_outputs.hidden_states[ + self.config.vision_feature_layer] + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if self.config.vision_feature_select_strategy == "default": + image_features = image_features[:, 1:] + elif self.config.vision_feature_select_strategy == "full": + image_features = image_features + else: + raise ValueError( + f"Unexpected select feature strategy: " + f"{self.config.vision_feature_select_strategy}") + else: + image_features = image_input + vision_embeddings = self.multi_modal_projector(image_features) + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + _merge_vision_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.vision_language_config.image_token_id) + input_ids = None + else: + inputs_embeds = None + hidden_states = self.language_model(input_ids, + positions, + kv_caches, + attn_metadata, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + # only doing this for language model part for now. 
+ stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) + use_default_weight_loading = False + if "vision" in name: + if self.vision_tower is not None: + # We only do sharding for language model and + # not vision model for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + use_default_weight_loading = True + if use_default_weight_loading: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index 8b2855daa552..bf3679c312dd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -303,6 +303,25 @@ class SequenceGroupState: generator: Optional = None +class MultiModalData: + """Multi modal request. + + Args: + type: The data type. + data: The actual data. + The required shape and semantic meaning of it depends on the vision + language config of the hosted model. + See `VisionLanguageConfig` in `config.py`. + """ + + class Type(enum.Enum): + IMAGE = enum.auto() + + def __init__(self, type: Type, data: "torch.Tensor"): + self.type = type + self.data = data + + class SequenceGroup: """A group of sequences that are generated from the same prompt. @@ -312,6 +331,7 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. + multi_modal_data: Multi modal data associated with the request. """ def __init__( @@ -321,6 +341,7 @@ def __init__( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -333,6 +354,7 @@ def __init__( self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() + self.multi_modal_data = multi_modal_data @property def prompt(self) -> str: @@ -450,6 +472,7 @@ class SequenceGroupMetadata: numbers) state: Internal state tied to this sequence group. lora_request: LoRA request. + multi_modal_data: Multi modal data. 
""" def __init__( @@ -462,6 +485,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -470,6 +494,7 @@ def __init__( self.block_tables = block_tables self.lora_request = lora_request self.computed_block_nums = computed_block_nums + self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state @property diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index dc226248910e..c34ee10bebc6 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -40,3 +40,17 @@ def get_config(model: str, revision=revision, code_revision=code_revision) return config + + +def get_hf_text_config(config: PretrainedConfig): + """Get the "sub" config relevant to llm for multi modal models. + No op for pure text models. + """ + if hasattr(config, "text_config"): + # The code operates under the assumption that text_config should have + # `num_attention_heads` (among others). Assert here to fail early + # if transformers config doesn't align with this assumption. + assert hasattr(config.text_config, "num_attention_heads") + return config.text_config + else: + return config diff --git a/vllm/utils.py b/vllm/utils.py index 4b9558ffe88d..f88c52731b3b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -377,6 +377,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): gc.collect() +def str_to_int_tuple(s: str) -> Tuple[int]: + """Convert a string to a tuple of integers.""" + try: + return tuple(map(int, s.split(","))) + except ValueError as e: + raise ValueError( + "String must be a series of integers separated by commas " + f"(e.g., 1, 2, 3). Given input: {s}") from e + + def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]: assert len(x) <= max_len return x + [pad] * (max_len - len(x)) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fd96e752bb15..8a08c3cbf583 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -8,7 +8,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -21,7 +21,8 @@ from vllm.model_executor.parallel_utils.parallel_state import ( with_cupy_nccl_for_all_reduce) from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, + SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) @@ -49,6 +50,7 @@ def __init__( lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config @@ -83,17 +85,20 @@ def __init__( self.graph_block_tables = None # Set after initial profiling. 
self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype + self.vision_language_config = vision_language_config self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) def load_model(self) -> None: with CudaMemoryProfiler() as m: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + self.model = get_model( + self.model_config, + self.device_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) self.model_memory_usage = m.consumed_memory logger.info(f"Loading model weights took " @@ -130,7 +135,8 @@ def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], List[int], List[int], Set[LoRARequest]]: + List[int], List[int], List[int], Set[LoRARequest], + torch.Tensor]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -143,6 +149,7 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -188,6 +195,10 @@ def _prepare_prompt( (prompt_len - computed_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. @@ -236,6 +247,16 @@ def _prepare_prompt( context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + # Prepare prefix block tables max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) block_tables = make_tensor_with_pad( @@ -291,7 +312,7 @@ def _prepare_prompt( ) return (input_tokens, input_positions, attn_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests) + lora_requests, multi_modal_input) def _prepare_decode( self, @@ -525,7 +546,7 @@ def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[int], LoRAMapping]: + Set[int], LoRAMapping, torch.Tensor]: if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
@@ -534,13 +555,15 @@ def prepare_input_tensors( if is_prompt: (input_tokens, input_positions, attn_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_prompt(seq_group_metadata_list) + lora_requests, multi_modal_input + ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] subquery_lens = None + multi_modal_input = None sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) @@ -561,6 +584,7 @@ def prepare_input_tensors( sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, + "multi_modal_input": multi_modal_input, } metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) @@ -572,6 +596,7 @@ def prepare_input_tensors( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") + multi_modal_input = metadata_dict.pop("multi_modal_input") attn_metadata = self.attn_backend.make_metadata(**metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, @@ -584,7 +609,8 @@ def prepare_input_tensors( ) return (input_tokens, input_positions, attn_metadata, - sampling_metadata, lora_requests, lora_mapping) + sampling_metadata, lora_requests, lora_mapping, + multi_modal_input) @torch.inference_mode() def execute_model( @@ -593,8 +619,8 @@ def execute_model( kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, - lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) + lora_requests, lora_mapping, multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) @@ -605,12 +631,15 @@ def execute_model( model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - hidden_states = model_executable( - input_ids=input_tokens, - positions=input_positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - ) + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) @@ -658,10 +687,22 @@ def profile_run(self) -> None: # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for vision encoding, which needs + # to be accounted for when calculating the GPU blocks for + # the vLLM block manager. + # To exercise the worst-case scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed.
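Two things worth noting in the prepare_input_tensors()/execute_model() changes above: the driver packs multi_modal_input into the same metadata dict that is broadcast to the other tensor-parallel ranks (so non-driver ranks receive the same image tensors), and the model call now builds a kwargs dict so that only vision-language models receive the extra image_input argument. A toy, runnable sketch of that dispatch pattern; the model functions here are stand-ins, not vLLM model classes.

    import torch

    def text_model(input_ids, positions):
        return input_ids.float()

    def vision_model(input_ids, positions, image_input=None):
        return image_input.sum() if image_input is not None else input_ids.float()

    def run(model, vision_language_config, input_ids, positions, image_input):
        kwargs = {"input_ids": input_ids, "positions": positions}
        if vision_language_config is not None:
            # Mirrors execute_model_kwargs.update({"image_input": ...}).
            kwargs["image_input"] = image_input
        return model(**kwargs)

    tokens = torch.tensor([1, 2, 3])
    positions = torch.tensor([0, 1, 2])
    print(run(text_model, None, tokens, positions, None))
    print(run(vision_model, object(), tokens, positions, torch.ones(1, 3)))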
+ if self.vision_language_config: + max_num_seqs = min( + max_num_seqs, + int(max_num_batched_tokens / + self.vision_language_config.image_feature_size)) for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_data = SequenceData([0] * seq_len) + seq_data, fake_multi_modal_input = _prepare_fake_inputs( + seq_len, self.vision_language_config) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -670,6 +711,7 @@ def profile_run(self) -> None: block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, + multi_modal_data=fake_multi_modal_input, ) seqs.append(seq) @@ -831,6 +873,7 @@ def capture( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, memory_pool, + **kwargs, ) -> None: assert self.graph is None # Run the model once without capturing the graph. @@ -842,6 +885,7 @@ def capture( positions, kv_caches, attn_metadata, + **kwargs, ) torch.cuda.synchronize() @@ -856,6 +900,7 @@ def capture( positions, kv_caches, attn_metadata, + **kwargs, ) torch.cuda.synchronize() @@ -877,6 +922,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. del kv_caches @@ -922,3 +968,21 @@ def _get_graph_batch_size(batch_size: int) -> int: else: return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + +def _prepare_fake_inputs( + seq_len: int, vision_language_config: Optional[VisionLanguageConfig]): + """Prepare fake inputs for profile run.""" + if vision_language_config: + prompt_tokens = [ + vision_language_config.image_token_id + ] * vision_language_config.image_feature_size + [0] * ( + seq_len - vision_language_config.image_feature_size) + fake_image_input = MultiModalData( + type=MultiModalData.Type.IMAGE, + data=torch.zeros(vision_language_config.image_input_shape, + dtype=torch.float16)) + else: + prompt_tokens = [0] * seq_len + fake_image_input = None + return SequenceData(prompt_tokens), fake_image_input diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 48c276681e9d..46a62fa69325 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,7 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils import cupy_utils @@ -39,6 +39,7 @@ def __init__( rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, + vision_language_config: Optional[VisionLanguageConfig] = None, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: @@ -54,13 +55,20 @@ def __init__( if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
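Worked numbers for the profiling cap introduced in profile_run() above, using illustrative defaults (256 sequences, 4096 batched tokens) and the LLaVA-1.5 feature size of 576; none of these constants are fixed by this diff.

    max_num_seqs = 256
    max_num_batched_tokens = 4096
    image_feature_size = 576

    max_num_seqs = min(max_num_seqs,
                       int(max_num_batched_tokens / image_feature_size))
    print(max_num_seqs)  # 7: at most 7 images fit in one profiling batch

    # Each fake sequence then front-loads the image placeholder tokens,
    # mirroring _prepare_fake_inputs() (32000 is the LLaVA image token id,
    # used here purely as an example).  The real code also spreads the
    # division remainder over the first few sequences.
    seq_len = max_num_batched_tokens // max_num_seqs  # 585
    image_token_id = 32000
    prompt_tokens = ([image_token_id] * image_feature_size +
                     [0] * (seq_len - image_feature_size))
    print(len(prompt_tokens))  # 585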
- self.model_runner = ModelRunner(model_config, - parallel_config, - scheduler_config, - device_config, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) + self.vision_language_config = vision_language_config + if self.vision_language_config: + assert not self.lora_config, ( + "To be tested: vision language model with LoRA settings.") + + self.model_runner = ModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). self.cache_config = None
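Putting the pieces together at the worker boundary: the scheduler-side SequenceGroupMetadata now carries the image tensor, and ModelRunner consumes it during prompt preparation. The sketch below builds such an object by hand; the constructor arguments not shown in this diff (seq_data, sampling_params) follow the existing vLLM definitions, and the token ids and tensor contents are placeholders.

    import torch

    from vllm.sampling_params import SamplingParams
    from vllm.sequence import (MultiModalData, SequenceData,
                               SequenceGroupMetadata)

    image = torch.zeros(1, 3, 336, 336, dtype=torch.float16)
    seq_group = SequenceGroupMetadata(
        request_id="0",
        is_prompt=True,
        seq_data={0: SequenceData([32000] * 576 + [1, 2, 3])},
        sampling_params=SamplingParams(temperature=0.0),
        block_tables=None,
        multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
                                        data=image),
    )
    # ModelRunner._prepare_prompt() reads seq_group.multi_modal_data.data
    # and batches it as shown earlier.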