54 changes: 46 additions & 8 deletions examples/offline_inference_vision_language.py
@@ -5,6 +5,7 @@
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import random
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
@@ -23,7 +24,9 @@

prompt = f"USER: <image>\n{question}\nASSISTANT:"

llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
mm_cache_preprocessor=args.mm_cache_preprocessor)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -507,14 +510,35 @@

else:
# Batch inference
inputs = [{
"prompt": prompt,
"multi_modal_data": {
modality: data
},
} for _ in range(args.num_prompts)]

if args.image_repeat_ratio is not None:
assert (args.image_repeat_ratio <= 1.0
and args.image_repeat_ratio >= 0)
no_yes = [0, 1]
probs = [1.0 - args.image_repeat_ratio, args.image_repeat_ratio]

inputs = []
cur_image = data
for i in range(args.num_prompts):
if args.image_repeat_ratio is not None:
res = random.choices(no_yes, probs)[0]
if res == 0:
# No repeat => Modify one pixel
cur_image = cur_image.copy()
new_val = (i // 256 // 256, i // 256, i % 256)
cur_image.putpixel((0, 0), new_val)

inputs.append({
"prompt": prompt,
"multi_modal_data": {
modality: cur_image
}
})

import time
start_time = time.time()
outputs = llm.generate(inputs, sampling_params=sampling_params)
elapsed_time = time.time() - start_time
print("-- generate time = {}".format(elapsed_time))

for o in outputs:
generated_text = o.outputs[0].text
@@ -544,5 +568,19 @@
type=int,
default=16,
help='Number of frames to extract from the video.')

parser.add_argument(
'--image-repeat-ratio',
type=float,
default=None,
        help='Simulates the hit ratio of the multi-modal preprocessor '
        'cache (if enabled).')

parser.add_argument(
'--mm-cache-preprocessor',
action='store_true',
help='If True, enable caching of multi-modal preprocessor/mapper.')

args = parser.parse_args()
main(args)
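
For reference, a minimal standalone sketch of the hit-ratio simulation performed by the loop above; the image size and the two settings are illustrative stand-ins for the script's arguments, not values taken from the diff:

import random

from PIL import Image

num_prompts = 8           # stands in for args.num_prompts
image_repeat_ratio = 0.5  # stands in for args.image_repeat_ratio

cur_image = Image.new("RGB", (336, 336))  # stands in for the example's `data`
images = []
for i in range(num_prompts):
    # With probability (1 - ratio), perturb one pixel so the preprocessor
    # cache sees a distinct image (a forced cache miss).
    repeat = random.choices([0, 1],
                            [1.0 - image_repeat_ratio, image_repeat_ratio])[0]
    if repeat == 0:
        cur_image = cur_image.copy()
        cur_image.putpixel((0, 0), (i // 256 // 256, i // 256, i % 256))
    images.append(cur_image)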
10 changes: 8 additions & 2 deletions vllm/config.py
@@ -131,6 +131,8 @@ class ModelConfig:
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
mm_cache_preprocessor: If True, enable caching of multi-modal
preprocessor/mapper.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
@@ -169,6 +171,7 @@ def __init__(
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
self.model = model
@@ -235,6 +238,7 @@ def __init__(
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.mm_cache_preprocessor = mm_cache_preprocessor

# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
@@ -2593,7 +2597,8 @@ def __str__(self):
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s") % \
"use_async_output_proc=%s, mm_processor_kwargs=%s, "
"mm_cache_preprocessor=%s") % \
(self.model_config.model, self.speculative_config,
self.model_config.tokenizer,
self.model_config.skip_tokenizer_init,
@@ -2619,7 +2624,8 @@ def __str__(self):
self.scheduler_config.num_scheduler_steps,
self.cache_config.enable_prefix_caching,
self.model_config.use_async_output_proc,
self.model_config.mm_processor_kwargs)
self.model_config.mm_processor_kwargs,
self.model_config.mm_cache_preprocessor)


_current_vllm_config: Optional[VllmConfig] = None
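
Since mm_cache_preprocessor ends up as a plain ModelConfig field, offline users enable the cache through the LLM constructor, as the updated example script does. A minimal sketch (the checkpoint is just the one the example already uses; keyword arguments to LLM are forwarded to EngineArgs):

from vllm import LLM

# mm_cache_preprocessor=True propagates via EngineArgs into
# ModelConfig.mm_cache_preprocessor, enabling the preprocessor/mapper cache.
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
          max_model_len=4096,
          mm_cache_preprocessor=True)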
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -143,6 +143,7 @@ class EngineArgs:
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
mm_cache_preprocessor: bool = False
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
@@ -592,6 +593,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=json.loads,
help=('Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: {"num_crops": 4}.'))
parser.add_argument(
'--mm-cache-preprocessor',
action='store_true',
help='If True, enable caching of multi-modal preprocessor/mapper.')

# LoRA related configs
parser.add_argument('--enable-lora',
@@ -964,6 +969,7 @@ def create_model_config(self) -> ModelConfig:
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_cache_preprocessor=self.mm_cache_preprocessor,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
)
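
For completeness, the same flag can be exercised programmatically through EngineArgs; a sketch (note that create_model_config() loads the HuggingFace config, so the model must be reachable):

from vllm.engine.arg_utils import EngineArgs

# --mm-cache-preprocessor on the CLI sets this boolean field;
# create_model_config() then copies it onto ModelConfig as shown above.
engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf",
                         mm_cache_preprocessor=True)
model_config = engine_args.create_model_config()
print(model_config.mm_cache_preprocessor)  # True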
2 changes: 2 additions & 0 deletions vllm/v1/core/scheduler.py
@@ -510,6 +510,7 @@ class NewRequestData:
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"]
mm_hash: List[str]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
block_ids: List[int]
@@ -527,6 +528,7 @@ def from_request(
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
mm_inputs=request.mm_inputs,
mm_hash=request.mm_hash,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
@@ -36,6 +36,7 @@ class EngineCoreRequest:
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[MultiModalKwargs]]
mm_hash: Optional[List[str]]
mm_placeholders: Optional[MultiModalPlaceholderDict]
sampling_params: SamplingParams
eos_token_id: Optional[int]
4 changes: 0 additions & 4 deletions vllm/v1/engine/core.py
@@ -19,7 +19,6 @@
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
@@ -55,9 +54,6 @@ def __init__(
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

# Set up multimodal input mapper (e.g., convert PIL images to tensors).
self.mm_input_mapper = MMInputMapper(vllm_config.model_config)

# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
85 changes: 79 additions & 6 deletions vllm/v1/engine/mm_input_mapper.py
@@ -1,8 +1,13 @@
from typing import Any, Dict, List, Optional

import PIL
from blake3 import blake3

from vllm.config import ModelConfig
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
from vllm.v1.utils import LRUDictCache


class MMInputMapper:
@@ -11,29 +16,97 @@
self,
model_config: ModelConfig,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_cache_size: int = 128,
):
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)

self.mm_cache = LRUDictCache(size=mm_cache_size)
self.mm_cache_hits = 0
self.mm_cache_misses = 0

        # Log the cache hit ratio every N steps; set to None to disable this
        # debug output (TODO: disable by default).
        self.mm_debug_cache_hit_ratio_steps = 32

    def cache_hit_ratio(self, steps: int) -> None:
total_steps = self.mm_cache_hits + self.mm_cache_misses

if total_steps > 0 and total_steps % steps == 0:
print("[debug] MMInputMapper: cache_hit_ratio = {}".format(
self.mm_cache_hits / total_steps))

def process_inputs(
self,
mm_data: MultiModalDataDict,
mm_hash: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]],
) -> List[MultiModalKwargs]:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]

        use_hash = mm_hash is not None
if use_hash:
assert len(image_inputs) == len(mm_hash) # Sanity

# Process each image input separately so that later we can schedule
# them in a fine-grained manner.
# Utilize caching (if enabled)
mm_inputs: List[MultiModalKwargs] = []
num_images = len(image_inputs)
for i in range(num_images):
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[i]]},
mm_processor_kwargs=mm_processor_kwargs,
)
for i in range(len(image_inputs)):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)

mm_input = self.mm_cache.get(mm_hash[i]) if use_hash else None
if mm_input is None:
self.mm_cache_misses += 1
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[i]]},
mm_processor_kwargs=mm_processor_kwargs,
)
if use_hash:
self.mm_cache.put(mm_hash[i], mm_input)
else:
self.mm_cache_hits += 1

mm_inputs.append(mm_input)

return mm_inputs


class MMHasher:

def __init__(self):
pass

def hash(self, mm_data: MultiModalDataDict) -> List[str]:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]

ret = []
for image in image_inputs:
assert isinstance(image, PIL.Image.Image)

            # Hash the raw image bytes with blake3; the hex digest is used as
            # the preprocessor cache key for this image.
            hasher = blake3()
            hasher.update(image.tobytes())
            ret.append(hasher.hexdigest())

return ret
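
A short usage sketch of the hasher: two pixel-identical images yield the same blake3 digest, so the second request can reuse the cached mapper output instead of re-running the HuggingFace image processor (the image size and color are arbitrary; the LRUDictCache imported above lives in vllm.v1.utils and is not part of this diff):

from PIL import Image

from vllm.v1.engine.mm_input_mapper import MMHasher

image = Image.new("RGB", (336, 336), color=(127, 0, 255))
hashes_a = MMHasher().hash({"image": image})           # single image
hashes_b = MMHasher().hash({"image": [image.copy()]})  # list form
assert hashes_a == hashes_b  # identical pixels -> identical cache key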
11 changes: 9 additions & 2 deletions vllm/v1/engine/processor.py
@@ -14,7 +14,7 @@
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapper


class Processor:
@@ -43,6 +43,10 @@ def __init__(
# Multi-modal (huggingface) input mapper
self.mm_input_mapper = MMInputMapper(model_config)

# Multi-modal hasher (for images)
        self.mm_hasher = (MMHasher()
                          if model_config.mm_cache_preprocessor else None)

# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# This ideally should releases the GIL, so we should not block the
# asyncio loop while this is running.
@@ -101,8 +105,10 @@ def process_inputs(
self.generation_config_fields, eos_token_id)

# Preprocess multi-modal data
        mm_hash = (self.mm_hasher.hash(decoder_inputs.multi_modal_data)
                   if self.mm_hasher is not None else None)
mm_inputs = self.mm_input_mapper.process_inputs(
decoder_inputs.multi_modal_data,
decoder_inputs.multi_modal_data, mm_hash,
decoder_inputs.mm_processor_kwargs) if len(
decoder_inputs.multi_modal_data) > 0 else None

@@ -124,6 +130,7 @@ def process_inputs(
decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
mm_inputs,
mm_hash,
decoder_inputs.multi_modal_placeholders,
sampling_params,
eos_token_id,
3 changes: 3 additions & 0 deletions vllm/v1/request.py
@@ -20,6 +20,7 @@ def __init__(
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
mm_hash: Optional[List[str]] = None,
) -> None:
self.request_id = request_id
self.inputs = SingletonInputsAdapter(inputs)
@@ -56,6 +57,7 @@ def __init__(
self.mm_inputs = self.inputs.multi_modal_inputs
else:
self.mm_inputs: List[MultiModalKwargs] = []
self.mm_hash = mm_hash

@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
@@ -73,6 +75,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
mm_hash=request.mm_hash,
)

@property