diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 484cd171f5f5..d83b595c9e1e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,11 +9,15 @@ repos: - id: yapf args: [--in-place, --verbose] additional_dependencies: [toml] # TODO: Remove when yapf is upgraded + exclude: "^(csrc|vllm/assets|vllm/inputs|vllm/multimodal|vllm/usage)/.*" - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.3 hooks: - - id: ruff - args: [--output-format, github, --fix] + - id: ruff + args: [--output-format, github, --fix, --exit-non-zero-on-fix, --show-fixes] + - id: ruff-format + types_or: [python] + exclude: "^(?!(csrc|vllm/assets|vllm/inputs|vllm/multimodal|vllm/usage)/).*" - repo: https://github.com/codespell-project/codespell rev: v2.4.0 hooks: @@ -24,6 +28,7 @@ repos: rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0 hooks: - id: isort + exclude: "^(csrc|vllm/assets|vllm/inputs|vllm/multimodal|vllm/usage)/.*" - repo: https://github.com/pre-commit/mirrors-clang-format rev: v19.1.7 hooks: diff --git a/.yapfignore b/.yapfignore deleted file mode 100644 index 2d6dcf8380ca..000000000000 --- a/.yapfignore +++ /dev/null @@ -1 +0,0 @@ -collect_env.py diff --git a/collect_env.py b/collect_env.py index 0ec9d4cae4ba..6e09b0657d58 100644 --- a/collect_env.py +++ b/collect_env.py @@ -277,12 +277,13 @@ def get_vllm_version(): if __version__ == "dev": return "N/A (dev)" - if len(__version_tuple__) == 4: # dev build - git_sha = __version_tuple__[-1][1:] # type: ignore + if len(__version_tuple__) == 4: # dev build + git_sha = __version_tuple__[-1][1:] # type: ignore return f"{__version__} (git sha: {git_sha}" return __version__ + def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( @@ -517,13 +518,12 @@ def is_xnnpack_available(): else: return "N/A" + def get_env_vars(): env_vars = '' - secret_terms=('secret', 'token', 'api', 'access', 'password') - report_prefix = ("TORCH", "NCCL", "PYTORCH", - "CUDA", "CUBLAS", "CUDNN", - "OMP_", "MKL_", - "NVIDIA") + secret_terms = ('secret', 'token', 'api', 'access', 'password') + report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", + "OMP_", "MKL_", "NVIDIA") for k, v in os.environ.items(): if any(term in k.lower() for term in secret_terms): continue @@ -534,6 +534,7 @@ def get_env_vars(): return env_vars + def get_env_info(): run_lambda = run pip_version, pip_list_output = get_pip_packages(run_lambda) diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index d64f0d0a5c2a..c33370bab9cf 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -26,7 +26,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: "u4b8", VLLMDataType.u8b128: "u8b128", - } + }, } VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -34,7 +34,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", - } + }, } VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { @@ -42,7 +42,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: 4, VLLMDataType.u8b128: 8, - } + }, } VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -66,15 +66,13 @@ class MixedInputKernelScheduleType(enum.Enum): DataType.f32: "at::ScalarType::Float", } -VLLMKernelScheduleTag: dict[Union[ - MixedInputKernelScheduleType, KernelScheduleType], str] = { - **KernelScheduleTag, # type: ignore - **{ - MixedInputKernelScheduleType.TmaWarpSpecialized: - "cutlass::gemm::KernelTmaWarpSpecialized", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: - "cutlass::gemm::KernelTmaWarpSpecializedPingpong", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: - "cutlass::gemm::KernelTmaWarpSpecializedCooperative", - } - } +VLLMKernelScheduleTag: dict[ + Union[MixedInputKernelScheduleType, KernelScheduleType], str +] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", + }, +} diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 3114e14baa0c..68c34dad27cb 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -11,18 +11,24 @@ from typing import Optional, Union import jinja2 + # yapf conflicts with isort for this block # yapf: disable -from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag, - EpilogueScheduleType, - MixedInputKernelScheduleType, - TileSchedulerTag, - TileSchedulerType, VLLMDataType, - VLLMDataTypeNames, - VLLMDataTypeSize, VLLMDataTypeTag, - VLLMDataTypeTorchDataTypeTag, - VLLMDataTypeVLLMScalarTypeTag, - VLLMKernelScheduleTag) +from vllm_cutlass_library_extension import ( + DataType, + EpilogueScheduleTag, + EpilogueScheduleType, + 
MixedInputKernelScheduleType, + TileSchedulerTag, + TileSchedulerType, + VLLMDataType, + VLLMDataTypeNames, + VLLMDataTypeSize, + VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, + VLLMKernelScheduleTag, +) # yapf: enable @@ -285,18 +291,25 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str: tile_shape = ( f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" ) - cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + - f"x{schedule_config.cluster_shape_mnk[1]}" + - f"x{schedule_config.cluster_shape_mnk[2]}") - kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\ - .split("::")[-1] + cluster_shape = ( + f"{schedule_config.cluster_shape_mnk[0]}" + + f"x{schedule_config.cluster_shape_mnk[1]}" + + f"x{schedule_config.cluster_shape_mnk[2]}" + ) + kernel_schedule = VLLMKernelScheduleTag[ + schedule_config.kernel_schedule + ].split("::")[-1] epilogue_schedule = EpilogueScheduleTag[ - schedule_config.epilogue_schedule].split("::")[-1] - tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ - .split("::")[-1] - - return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + - f"_{epilogue_schedule}_{tile_scheduler}") + schedule_config.epilogue_schedule + ].split("::")[-1] + tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split( + "::" + )[-1] + + return ( + f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + + f"_{epilogue_schedule}_{tile_scheduler}" + ) # mostly unique shorter sch_sig @@ -315,18 +328,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: # unique type_name def generate_type_signature(kernel_types: TypeConfig): - return str("".join([ - VLLMDataTypeNames[getattr(kernel_types, field.name)] - for field in fields(TypeConfig) - ])) + return str( + "".join( + [ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) + ) def generate_type_option_name(kernel_types: TypeConfig): - return ", ".join([ - f"{field.name.replace('b_', 'with_')+'_type'}=" + - VLLMDataTypeNames[getattr(kernel_types, field.name)] - for field in fields(TypeConfig) - ]) + return ", ".join( + [ + f"{field.name.replace('b_', 'with_') + '_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) def is_power_of_two(n): @@ -334,7 +353,6 @@ def is_power_of_two(n): def to_cute_constant(value: list[int]): - def _to_cute_constant(value: int): if is_power_of_two(value): return f"_{value}" @@ -349,8 +367,10 @@ def _to_cute_constant(value: int): def unique_schedules(impl_configs: list[ImplConfig]): return list( - set(sch for impl_config in impl_configs - for sch in impl_config.schedules)) + set( + sch for impl_config in impl_configs for sch in impl_config.schedules + ) + ) def unsigned_type_with_bitwidth(num_bits): @@ -376,7 +396,7 @@ def unsigned_type_with_bitwidth(num_bits): "gen_type_sig": generate_type_signature, "unique_schedules": unique_schedules, "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, - "gen_type_option_name": generate_type_option_name + "gen_type_option_name": generate_type_option_name, } @@ -394,23 +414,28 @@ def create_template(template_str): def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): sources = [] - sources.append(( - "machete_mm_dispatch", - mm_dispatch_template.render(impl_configs=impl_configs), - )) + sources.append( + ( + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), + ) + ) prepack_types = [] for 
impl_config in impl_configs: - convert_type = impl_config.types.a \ - if impl_config.types.b_group_scale == DataType.void \ - else impl_config.types.b_group_scale + convert_type = ( + impl_config.types.a + if impl_config.types.b_group_scale == DataType.void + else impl_config.types.b_group_scale + ) prepack_types.append( PrepackTypeConfig( a=impl_config.types.a, b_num_bits=VLLMDataTypeSize[impl_config.types.b], convert=convert_type, accumulator=impl_config.types.accumulator, - )) + ) + ) def prepacked_type_key(prepack_type: PrepackTypeConfig): # For now we we can just use the first accumulator type seen since @@ -426,10 +451,14 @@ def prepacked_type_key(prepack_type: PrepackTypeConfig): unique_prepack_types.append(prepack_type) prepack_types_seen.add(key) - sources.append(( - "machete_prepack", - prepack_dispatch_template.render(types=unique_prepack_types, ), - )) + sources.append( + ( + "machete_prepack", + prepack_dispatch_template.render( + types=unique_prepack_types, + ), + ) + ) # Split up impls across files num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) @@ -462,10 +491,12 @@ def prepacked_type_key(prepack_type: PrepackTypeConfig): curr_impl_in_file += len(files_impls[-1][-1].schedules) for part, file_impls in enumerate(files_impls): - sources.append(( - f"machete_mm_impl_part{part+1}", - mm_impl_template.render(impl_configs=file_impls), - )) + sources.append( + ( + f"machete_mm_impl_part{part + 1}", + mm_impl_template.render(impl_configs=file_impls), + ) + ) return sources @@ -510,8 +541,7 @@ def generate(): # For now we use the same heuristic for all types # Heuristic is currently tuned for H100s default_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore + (cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore for cond, tile_config in default_tile_heuristic_config.items() ] @@ -537,14 +567,18 @@ def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): a_token_scale=DataType.void, out=a, accumulator=DataType.f32, - ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for a in (DataType.f16, DataType.bf16)) + ) + for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16) + ) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(GPTQ_kernel_type_configs, - itertools.repeat(get_unique_schedules(default_heuristic)), - itertools.repeat(default_heuristic)) + for x in zip( + GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) ] AWQ_kernel_type_configs = list( @@ -557,14 +591,18 @@ def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): a_token_scale=DataType.void, out=a, accumulator=DataType.f32, - ) for b in (DataType.u4, DataType.u8) - for a in (DataType.f16, DataType.bf16)) + ) + for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16) + ) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(AWQ_kernel_type_configs, - itertools.repeat(get_unique_schedules(default_heuristic)), - itertools.repeat(default_heuristic)) + for x in zip( + AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) ] # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) @@ -592,7 +630,7 @@ def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), # Broken for QQQ types # TODO (LucasWilkinson): Investigate further - #"M > 32 && K 
>= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + # "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), "M > 32": ((128, 64), (2, 1, 1)), #### M = 17-32 "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), @@ -605,39 +643,46 @@ def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): # For now we use the same heuristic for all types # Heuristic is currently tuned for H100s qqq_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore + (cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore for cond, tile_config in qqq_tile_heuristic_config.items() ] QQQ_kernel_types = [ - *(TypeConfig( - a=DataType.s8, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.s32, - ) for b_group_scale in (DataType.f16, DataType.void)), - *(TypeConfig( - a=DataType.e4m3, - b=VLLMDataType.u4b8, - b_group_scale=b_group_scale, - b_group_zeropoint=DataType.void, - b_channel_scale=DataType.f32, - a_token_scale=DataType.f32, - out=DataType.f16, - accumulator=DataType.f32, - ) for b_group_scale in (DataType.f16, DataType.void)), + *( + TypeConfig( + a=DataType.s8, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.s32, + ) + for b_group_scale in (DataType.f16, DataType.void) + ), + *( + TypeConfig( + a=DataType.e4m3, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.f32, + ) + for b_group_scale in (DataType.f16, DataType.void) + ), ] impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(QQQ_kernel_types, - itertools.repeat(get_unique_schedules(qqq_heuristic)), - itertools.repeat(qqq_heuristic)) + for x in zip( + QQQ_kernel_types, + itertools.repeat(get_unique_schedules(qqq_heuristic)), + itertools.repeat(qqq_heuristic), + ) ] output_dir = os.path.join(SCRIPT_DIR, "generated") diff --git a/pyproject.toml b/pyproject.toml index ee4e2ed0b7ce..515b74d5fee1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,15 +57,24 @@ ignore_patterns = [ [tool.ruff] # Allow lines to be as long as 80. line-length = 80 +indent-width = 4 +target-version = "py39" exclude = [ # External file, leaving license intact "examples/other/fp8/quantizer/quantize.py" ] +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +docstring-code-format = true + [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] +"vllm/**/*.py" = ["I"] +"vllm/v1/**/*.py" = [] # Python 3.8 typing. TODO: Remove these excludes after v1.0.0 "vllm/adapter_commons/**/*.py" = ["UP006", "UP035"] "vllm/attention/**/*.py" = ["UP006", "UP035"] @@ -100,7 +109,7 @@ select = [ # flake8-simplify "SIM", # isort - # "I", + "I", "G", ] ignore = [ @@ -108,6 +117,9 @@ ignore = [ "F405", "F403", # lambda expression assignment "E731", + # Ignore line too long.
Most cases are either exception + # messages or inline notes/docstrings + "E501", # Loop control variable not used within loop body "B007", # f-string format diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 0203dc092a71..d23bbbfab6da 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -25,13 +25,15 @@ class AudioAsset: @property def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: - audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", - s3_prefix=ASSET_DIR) + audio_path = get_vllm_public_assets( + filename=f"{self.name}.ogg", s3_prefix=ASSET_DIR + ) return librosa.load(audio_path, sr=None) def get_local_path(self) -> Path: - return get_vllm_public_assets(filename=f"{self.name}.ogg", - s3_prefix=ASSET_DIR) + return get_vllm_public_assets( + filename=f"{self.name}.ogg", s3_prefix=ASSET_DIR + ) @property def url(self) -> str: diff --git a/vllm/assets/base.py b/vllm/assets/base.py index 03f3b9dabf14..5e1da375ae2e 100644 --- a/vllm/assets/base.py +++ b/vllm/assets/base.py @@ -19,8 +19,9 @@ def get_cache_dir() -> Path: @lru_cache -def get_vllm_public_assets(filename: str, - s3_prefix: Optional[str] = None) -> Path: +def get_vllm_public_assets( + filename: str, s3_prefix: Optional[str] = None +) -> Path: """ Download an asset file from ``s3://vllm-public-assets`` and return the path to the downloaded file. @@ -35,6 +36,7 @@ def get_vllm_public_assets(filename: str, global_http_connection.download_file( f"{VLLM_S3_BUCKET_URL}/{filename}", asset_path, - timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT) + timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) return asset_path diff --git a/vllm/assets/image.py b/vllm/assets/image.py index 2b1d258da9c7..355d94375da4 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -17,8 +17,9 @@ class ImageAsset: @property def pil_image(self) -> Image.Image: - image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", - s3_prefix=VLM_IMAGES_DIR) + image_path = get_vllm_public_assets( + filename=f"{self.name}.jpg", s3_prefix=VLM_IMAGES_DIR + ) return Image.open(image_path) @property @@ -26,6 +27,7 @@ def image_embeds(self) -> torch.Tensor: """ Image embeddings, only used for testing purposes with llava 1.5.
""" - image_path = get_vllm_public_assets(filename=f"{self.name}.pt", - s3_prefix=VLM_IMAGES_DIR) + image_path = get_vllm_public_assets( + filename=f"{self.name}.pt", s3_prefix=VLM_IMAGES_DIR + ) return torch.load(image_path, map_location="cpu", weights_only=True) diff --git a/vllm/assets/video.py b/vllm/assets/video.py index e45e1a65f890..7f470d62ab03 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -52,13 +52,16 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: frames = np.stack(frames) frames = sample_frames_from_video(frames, num_frames) if len(frames) < num_frames: - raise ValueError(f"Could not read enough frames from video file {path}" - f" (expected {num_frames} frames, got {len(frames)})") + raise ValueError( + f"Could not read enough frames from video file {path}" + f" (expected {num_frames} frames, got {len(frames)})" + ) return frames -def video_to_pil_images_list(path: str, - num_frames: int = -1) -> list[Image.Image]: +def video_to_pil_images_list( + path: str, num_frames: int = -1 +) -> list[Image.Image]: frames = video_to_ndarrays(path, num_frames) return [ Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 6f8f2cd758f7..523f5080592c 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,13 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonInputsAdapter, SingletonPrompt, - TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - token_inputs, zip_enc_dec_prompts) -from .registry import (DummyData, InputContext, InputProcessingContext, - InputRegistry) +from .data import ( + DecoderOnlyInputs, + EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, + ProcessorInputs, + PromptType, + SingletonInputs, + SingletonInputsAdapter, + SingletonPrompt, + TextPrompt, + TokenInputs, + TokensPrompt, + build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, + token_inputs, + zip_enc_dec_prompts, +) +from .registry import ( + DummyData, + InputContext, + InputProcessingContext, + InputRegistry, +) INPUT_REGISTRY = InputRegistry() """ diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 138a8f61107b..ef17c5bef196 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -9,8 +9,11 @@ from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never if TYPE_CHECKING: - from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict) + from vllm.multimodal import ( + MultiModalDataDict, + MultiModalKwargs, + MultiModalPlaceholderDict, + ) from vllm.multimodal.inputs import MultiModalInputs @@ -80,14 +83,12 @@ class TokensPrompt(TypedDict): more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt` """ -_T1_co = TypeVar("_T1_co", - bound=SingletonPrompt, - default=SingletonPrompt, - covariant=True) -_T2_co = TypeVar("_T2_co", - bound=SingletonPrompt, - default=SingletonPrompt, - covariant=True) +_T1_co = TypeVar( + "_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True +) +_T2_co = TypeVar( + "_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True +) # TODO: Make fields ReadOnly once mypy supports it @@ -224,6 +225,7 @@ class EncoderDecoderInputs(TypedDict): This specifies the required data for encoder-decoder models. 
""" + encoder: Union[TokenInputs, "MultiModalInputs"] """The inputs for the encoder portion.""" @@ -243,6 +245,7 @@ class SingletonInputsAdapter: """ Unified interface to access the components of :class:`SingletonInputs`. """ + inputs: SingletonInputs @cached_property @@ -362,19 +365,21 @@ def build_explicit_enc_dec_prompt( return ExplicitEncoderDecoderPrompt( encoder_prompt=encoder_prompt, decoder_prompt=decoder_prompt, - mm_processor_kwargs=mm_processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs, + ) def zip_enc_dec_prompts( enc_prompts: Iterable[_T1], dec_prompts: Iterable[Optional[_T2]], - mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]], - dict[str, Any]]] = None, + mm_processor_kwargs: Optional[ + Union[Iterable[dict[str, Any]], dict[str, Any]] + ] = None, ) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of :class:`ExplicitEncoderDecoderPrompt` instances. - + ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same dictionary will be used for every encoder/decoder prompt. If an iterable is provided, it will be zipped with the encoder/decoder prompts. @@ -384,22 +389,28 @@ def zip_enc_dec_prompts( if isinstance(mm_processor_kwargs, dict): return [ build_explicit_enc_dec_prompt( - encoder_prompt, decoder_prompt, - cast(dict[str, Any], mm_processor_kwargs)) - for (encoder_prompt, - decoder_prompt) in zip(enc_prompts, dec_prompts) + encoder_prompt, + decoder_prompt, + cast(dict[str, Any], mm_processor_kwargs), + ) + for (encoder_prompt, decoder_prompt) in zip( + enc_prompts, dec_prompts + ) ] return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, - mm_proc_kwargs) - for (encoder_prompt, decoder_prompt, mm_proc_kwargs - ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs) + build_explicit_enc_dec_prompt( + encoder_prompt, decoder_prompt, mm_proc_kwargs + ) + for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip( + enc_prompts, dec_prompts, mm_processor_kwargs + ) ] def to_enc_dec_tuple_list( enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], ) -> list[tuple[_T1, Optional[_T2]]]: - return [(enc_dec_prompt["encoder_prompt"], - enc_dec_prompt["decoder_prompt"]) - for enc_dec_prompt in enc_dec_prompts] + return [ + (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) + for enc_dec_prompt in enc_dec_prompts + ] diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index ed1056948d80..31f9c204ce43 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -7,9 +7,15 @@ from vllm.utils import is_list_of -from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, - ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, - TokensPrompt) +from .data import ( + EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, + ProcessorInputs, + PromptType, + SingletonPrompt, + TextPrompt, + TokensPrompt, +) class ParsedText(TypedDict): @@ -24,14 +30,14 @@ class ParsedTokens(TypedDict): @overload def parse_and_batch_prompt( - prompt: Union[str, list[str]]) -> Sequence[ParsedText]: - ... + prompt: Union[str, list[str]], +) -> Sequence[ParsedText]: ... @overload def parse_and_batch_prompt( - prompt: Union[list[int], list[list[int]]]) -> Sequence[ParsedTokens]: - ... + prompt: Union[list[int], list[list[int]]], +) -> Sequence[ParsedTokens]: ... 
def parse_and_batch_prompt( @@ -67,8 +73,10 @@ def parse_and_batch_prompt( for elem in prompt ] - raise TypeError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") + raise TypeError( + "prompt must be a string, array of strings, " + "array of tokens, or array of token arrays" + ) class ParsedStrPrompt(TypedDict): @@ -93,8 +101,7 @@ def parse_singleton_prompt( return ParsedStrPrompt(type="str", content=prompt) elif isinstance(prompt, dict): if "prompt_token_ids" in prompt: - return ParsedTokensPrompt(type="tokens", - content=prompt) # type: ignore + return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore elif "prompt" in prompt: return ParsedTextPrompt(type="text", content=prompt) @@ -106,10 +113,12 @@ def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: def is_explicit_encoder_decoder_prompt( - prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + prompt: PromptType, +) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt def is_encoder_decoder_inputs( - inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: + inputs: ProcessorInputs, +) -> TypeIs[EncoderDecoderInputs]: return "encoder" in inputs and "decoder" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index af35e43d825a..abea31423ee5 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -10,20 +10,29 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalEncDecInputs, + MultiModalInputs, +) from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, token_inputs) +from .data import ( + DecoderOnlyInputs, + EncoderDecoderInputs, + ProcessorInputs, + PromptType, + SingletonInputs, + SingletonPrompt, + token_inputs, +) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) class InputPreprocessor: - def __init__( self, model_config: ModelConfig, @@ -38,63 +47,73 @@ def __init__( def get_tokenizer_group(self) -> BaseTokenizerGroup: if self.tokenizer is None: - raise ValueError("You cannot pass text prompts when " - "`skip_tokenizer_init` is True") + raise ValueError( + "You cannot pass text prompts when " + "`skip_tokenizer_init` is True" + ) return self.tokenizer - def get_bos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: + def get_bos_token_id( + self, lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: if self.tokenizer is None: - logger.warning("Using None for BOS token id because tokenizer " - "is not initialized") + logger.warning( + "Using None for BOS token id because tokenizer " + "is not initialized" + ) return None return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id - def get_eos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: + def get_eos_token_id( + self, lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: if self.tokenizer is None: - logger.warning("Using None for EOS token id because tokenizer " - "is not initialized") + 
logger.warning( + "Using None for EOS token id because tokenizer " + "is not initialized" + ) return None return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id def get_decoder_start_token_id(self) -> Optional[int]: - ''' + """ Obtain the decoder start token id employed by an encoder/decoder model. Returns None for non-encoder/decoder models or if the model config is unavailable. - ''' + """ if not self.model_config.is_encoder_decoder: logger.warning_once( "Using None for decoder start token id because " - "this is not an encoder/decoder model.") + "this is not an encoder/decoder model." + ) return None - if (self.model_config is None or self.model_config.hf_config is None): + if self.model_config is None or self.model_config.hf_config is None: logger.warning_once( "Using None for decoder start token id because " - "model config is not available.") + "model config is not available." + ) return None - dec_start_token_id = getattr(self.model_config.hf_config, - 'decoder_start_token_id', None) + dec_start_token_id = getattr( + self.model_config.hf_config, "decoder_start_token_id", None + ) if dec_start_token_id is None: logger.warning_once( "Falling back on <BOS> for decoder start token " "id because decoder start token id is not " - "available.") + "available." + ) dec_start_token_id = self.get_bos_token_id() return dec_start_token_id def _get_default_enc_dec_decoder_prompt(self) -> list[int]: - ''' + """ Specifically for encoder/decoder models: generate a default decoder prompt for when the user specifies only the encoder prompt. @@ -123,7 +142,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> list[int]: Returns: * prompt_token_ids - ''' + """ bos_token_id = self.get_bos_token_id() assert bos_token_id is not None @@ -161,8 +180,10 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + if ( + len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id + ): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -175,7 +196,8 @@ def _apply_prompt_adapter( if prompt_adapter_request: prompt_token_ids = ( [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens - + prompt_token_ids) + + prompt_token_ids + ) return prompt_token_ids @@ -197,15 +219,18 @@ def _tokenize_prompt( # appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False - if (self.model_config.encoder_config is not None - and self.model_config.encoder_config.get( - "do_lower_case", False)): + if ( + self.model_config.encoder_config is not None + and self.model_config.encoder_config.get("do_lower_case", False) + ): prompt = prompt.lower() - return tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request, - add_special_tokens=add_special_tokens) + return tokenizer.encode( + request_id=request_id, + prompt=prompt, + lora_request=lora_request, + add_special_tokens=add_special_tokens, + ) async def _tokenize_prompt_async( self, @@ -225,7 +250,8 @@ async def _tokenize_prompt_async( request_id=request_id, prompt=prompt, lora_request=lora_request, - add_special_tokens=add_special_tokens) + add_special_tokens=add_special_tokens, + ) def _can_process_multimodal(self) -> bool: model_config = self.model_config @@ -238,8 +264,10 @@ def _can_process_multimodal(self) -> bool: can_process_multimodal = self.mm_registry.has_processor(model_config) if not can_process_multimodal: from vllm.model_executor.models.registry import _VLLM_MODELS - if not any(arch in _VLLM_MODELS - for arch in model_config.architectures): + + if not any( + arch in _VLLM_MODELS for arch in model_config.architectures + ): logger.warning_once( "Your model uses the legacy input pipeline, which will be " "removed in an upcoming release. " @@ -271,13 +299,15 @@ def _process_multimodal( tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) mm_processor = self.mm_registry.create_processor( - self.model_config, tokenizer) + self.model_config, tokenizer + ) if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) + return mm_processor.apply( + prompt, mm_data, mm_processor_kwargs, return_mm_hashes + ) async def _process_multimodal_async( self, @@ -296,15 +326,18 @@ async def _process_multimodal_async( else: tokenizer_group = self.get_tokenizer_group() tokenizer = await tokenizer_group.get_lora_tokenizer_async( - lora_request) + lora_request + ) mm_processor = self.mm_registry.create_processor( - self.model_config, tokenizer) + self.model_config, tokenizer + ) if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) + return mm_processor.apply( + prompt, mm_data, mm_processor_kwargs, return_mm_hashes + ) def _prompt_to_llm_inputs( self, @@ -478,8 +511,10 @@ def _build_enc_dec_llm_inputs( encoder_inputs: SingletonInputs, decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - if (encoder_inputs["type"] == "token" - or encoder_inputs["type"] == "multimodal"): + if ( + encoder_inputs["type"] == "token" + or encoder_inputs["type"] == "multimodal" + ): pass else: assert_never(encoder_inputs) # type: ignore[arg-type] @@ -493,17 +528,23 @@ def _build_enc_dec_llm_inputs( dec_token_ids = encoder_inputs["prompt_token_ids"].copy() else: dec_token_ids = self._prepare_decoder_input_ids_for_generation( - None) + None + ) decoder_inputs = token_inputs(dec_token_ids) - elif (decoder_inputs["type"] == "token" - or decoder_inputs["type"] == "multimodal"): + elif ( + decoder_inputs["type"] == "token" + or decoder_inputs["type"] == "multimodal" + ): dec_token_ids = self._prepare_decoder_input_ids_for_generation( - decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] + ) decoder_inputs["prompt_token_ids"] = dec_token_ids if "multi_modal_data" in decoder_inputs: - 
raise ValueError("Multi-modal decoder inputs of encoder-" - "decoder models are not supported yet") + raise ValueError( + "Multi-modal decoder inputs of encoder-" + "decoder models are not supported yet" + ) else: assert_never(encoder_inputs) # type: ignore[arg-type] @@ -525,8 +566,10 @@ def _separate_enc_dec_inputs_from_mm_processor_outputs( decoder_inputs: SingletonInputs if inputs["type"] == "multimodal": # Multimodal data inputs - assert ("encoder_prompt" in inputs - and "encoder_prompt_token_ids" in inputs) + assert ( + "encoder_prompt" in inputs + and "encoder_prompt_token_ids" in inputs + ) inputs = cast(MultiModalEncDecInputs, inputs) encoder_inputs = token_inputs( prompt=inputs["encoder_prompt"], @@ -537,7 +580,8 @@ def _separate_enc_dec_inputs_from_mm_processor_outputs( type="multimodal", prompt=decoder_inputs_to_override.get("prompt", ""), prompt_token_ids=decoder_inputs_to_override[ - "prompt_token_ids"], + "prompt_token_ids" + ], mm_kwargs=inputs["mm_kwargs"], mm_placeholders=inputs["mm_placeholders"], ) @@ -611,21 +655,27 @@ def _process_encoder_decoder_prompt( # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + self._can_process_multimodal() + ): encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( - encoder_inputs, decoder_inputs)) + encoder_inputs, decoder_inputs + ) + ) else: inputs = self._prompt_to_llm_inputs( prompt, request_id=request_id, ) if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + self._can_process_multimodal() + ): # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( - inputs)) + inputs + ) + ) else: encoder_inputs = inputs @@ -658,26 +708,33 @@ async def _process_encoder_decoder_prompt_async( ) encoder_inputs, decoder_inputs = await asyncio.gather( - encoder_task, decoder_task) + encoder_task, decoder_task + ) # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. 
if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + self._can_process_multimodal() + ): encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( - encoder_inputs, decoder_inputs)) + encoder_inputs, decoder_inputs + ) + ) else: inputs = await self._prompt_to_llm_inputs_async( prompt, request_id=request_id, ) if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + self._can_process_multimodal() + ): # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( - inputs)) + inputs + ) + ) else: encoder_inputs = inputs @@ -690,8 +747,10 @@ def _build_decoder_only_llm_inputs( prompt_inputs: DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: - if (prompt_inputs["type"] == "token" - or prompt_inputs["type"] == "multimodal"): + if ( + prompt_inputs["type"] == "token" + or prompt_inputs["type"] == "multimodal" + ): prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, @@ -771,7 +830,8 @@ def preprocess( if self.model_config.is_encoder_decoder: assert not return_mm_hashes, ( "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") + "returned until they are supported on vLLM V1.", + ) # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( @@ -780,8 +840,9 @@ def preprocess( ) if is_explicit_encoder_decoder_prompt(prompt): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") + raise ValueError( + "Cannot pass encoder-decoder prompt to decoder-only models" + ) # Decoder-only operation return self._process_decoder_only_prompt( @@ -804,7 +865,8 @@ async def preprocess_async( if self.model_config.is_encoder_decoder: assert not return_mm_hashes, ( "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") + "returned until they are supported on vLLM V1.", + ) # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( @@ -813,8 +875,9 @@ async def preprocess_async( ) if is_explicit_encoder_decoder_prompt(prompt): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") + raise ValueError( + "Cannot pass encoder-decoder prompt to decoder-only models" + ) # Decoder-only operation return await self._process_decoder_only_prompt_async( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index b6ceb5fb82d7..6fef2bda1d86 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -4,8 +4,15 @@ from collections import UserDict from collections.abc import Mapping from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional, - Protocol, Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + NamedTuple, + Optional, + Protocol, + Union, +) from torch import nn from transformers import BatchFeature, PretrainedConfig, ProcessorMixin @@ -13,18 +20,26 @@ from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - cached_tokenizer_from_config) -from vllm.utils import (ClassRegistry, 
get_allowed_kwarg_only_overrides, - resolve_mm_processor_kwargs) +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + cached_tokenizer_from_config, +) +from vllm.utils import ( + ClassRegistry, + get_allowed_kwarg_only_overrides, + resolve_mm_processor_kwargs, +) from .data import ProcessorInputs, SingletonInputs from .parse import is_encoder_decoder_inputs if TYPE_CHECKING: from vllm.config import ModelConfig - from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, - MultiModalRegistry) + from vllm.multimodal import ( + MultiModalDataDict, + MultiModalPlaceholderDict, + MultiModalRegistry, + ) from vllm.sequence import SequenceData logger = init_logger(__name__) @@ -59,9 +74,11 @@ def get_hf_config( """ hf_config = self.model_config.hf_config if not isinstance(hf_config, typ): - raise TypeError("Invalid type of HuggingFace config. " - f"Expected type: {typ}, but " - f"found type: {type(hf_config)}") + raise TypeError( + "Invalid type of HuggingFace config. " + f"Expected type: {typ}, but " + f"found type: {type(hf_config)}" + ) return hf_config @@ -167,8 +184,10 @@ def call_hf_processor( try: return hf_processor(**data, **merged_kwargs, return_tensors="pt") except Exception as exc: - msg = (f"Failed to apply {type(hf_processor).__name__} " - f"on data={data} with kwargs={merged_kwargs}") + msg = ( + f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={merged_kwargs}" + ) raise RuntimeError(msg) from exc @@ -185,7 +204,6 @@ class DummyData(NamedTuple): class DummyDataFactory(Protocol): - def __call__( self, ctx: InputContext, @@ -216,8 +234,10 @@ def __getitem__(self, key: str) -> int: try: return super().__getitem__(key) except KeyError as exc: - msg = (f"There is no multi-modal plugin with the key: {key}. " - f"Available keys: {set(self.keys())}") + msg = ( + f"There is no multi-modal plugin with the key: {key}. " + f"Available keys: {set(self.keys())}" + ) raise KeyError(msg) from exc @@ -232,12 +252,15 @@ class InputRegistry: """ def __init__(self) -> None: - self._dummy_factories_by_model_type = \ - ClassRegistry[nn.Module, DummyDataFactory]() - self._dummy_encoder_factories_by_model_type = \ - ClassRegistry[nn.Module, DummyDataFactory]() - self._input_processors_by_model_type = \ - ClassRegistry[nn.Module, InputProcessor]() + self._dummy_factories_by_model_type = ClassRegistry[ + nn.Module, DummyDataFactory + ]() + self._dummy_encoder_factories_by_model_type = ClassRegistry[ + nn.Module, DummyDataFactory + ]() + self._input_processors_by_model_type = ClassRegistry[ + nn.Module, InputProcessor + ]() def _default_dummy_data_factory( self, @@ -267,12 +290,15 @@ def register_dummy_data(self, factory: DummyDataFactory): """ def wrapper(model_cls: N) -> N: - if self._dummy_factories_by_model_type.contains(model_cls, - strict=True): + if self._dummy_factories_by_model_type.contains( + model_cls, strict=True + ): logger.warning( "Model class %s already has dummy data " "registered to %s. 
It is overwritten by the new one.", - model_cls, self) + model_cls, + self, + ) self._dummy_factories_by_model_type[model_cls] = factory @@ -281,8 +307,9 @@ def wrapper(model_cls: N) -> N: return wrapper def _get_dummy_data_factory(self, model_cls: type[nn.Module]): - return self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) + return self._dummy_factories_by_model_type.get( + model_cls, self._default_dummy_data_factory + ) def register_dummy_encoder_data(self, factory: DummyDataFactory): """ @@ -293,11 +320,14 @@ def register_dummy_encoder_data(self, factory: DummyDataFactory): def wrapper(model_cls: N) -> N: if self._dummy_encoder_factories_by_model_type.contains( - model_cls, strict=True): + model_cls, strict=True + ): logger.warning( "Model class %s already has dummy encoder data " "registered to %s. It is overwritten by the new one.", - model_cls, self) + model_cls, + self, + ) self._dummy_encoder_factories_by_model_type[model_cls] = factory @@ -306,8 +336,9 @@ def wrapper(model_cls: N) -> N: return wrapper def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]): - return self._dummy_encoder_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) + return self._dummy_encoder_factories_by_model_type.get( + model_cls, self._default_dummy_data_factory + ) def dummy_data_for_profiling( self, @@ -332,13 +363,15 @@ def dummy_data_for_profiling( if mm_registry.has_processor(model_config): tokenizer = cached_tokenizer_from_config(model_config) - processor = mm_registry.create_processor(model_config, - tokenizer, - disable_cache=True) + processor = mm_registry.create_processor( + model_config, tokenizer, disable_cache=True + ) profiler = MultiModalProfiler(processor) - dummy_data_factory = (profiler.get_encoder_dummy_data - if is_encoder_data else - profiler.get_decoder_dummy_data) + dummy_data_factory = ( + profiler.get_encoder_dummy_data + if is_encoder_data + else profiler.get_decoder_dummy_data + ) dummy_data = dummy_data_factory(seq_len) else: model_cls, _ = get_model_architecture(model_config) @@ -354,9 +387,12 @@ def dummy_data_for_profiling( allow_var_kwargs=True, ) - dummy_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) + dummy_data = dummy_factory( + InputContext(model_config), + seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs, + ) # Having more tokens is over-conservative but otherwise fine num_tokens = dummy_data.seq_data.prompt_token_ids @@ -364,20 +400,24 @@ def dummy_data_for_profiling( if is_encoder_data: logger.warning_once( f"Expected at least {seq_len} dummy encoder tokens for " - f"profiling, but found {len(num_tokens)} tokens instead.") + f"profiling, but found {len(num_tokens)} tokens instead." + ) else: raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(num_tokens)} tokens instead.") + f"but found {len(num_tokens)} tokens instead." 
+ ) - if (dummy_data.multi_modal_data is not None and - not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)): + if dummy_data.multi_modal_data is not None and not isinstance( + dummy_data.multi_modal_data, MultiModalKwargs + ): for k, v in dummy_data.multi_modal_data.items(): num_items = len(v) if isinstance(v, list) else 1 num_expected = mm_counts[k] assert num_items >= num_expected, ( f"Expected at least {num_expected} dummy '{k}' instances " - f"for profiling, but found {num_items} instances instead.") + f"for profiling, but found {num_items} instances instead." + ) return dummy_data @@ -400,12 +440,15 @@ def register_input_processor(self, processor: InputProcessor): """ def wrapper(model_cls: N) -> N: - if self._input_processors_by_model_type.contains(model_cls, - strict=True): + if self._input_processors_by_model_type.contains( + model_cls, strict=True + ): logger.warning( "Model class %s already has input processor " "registered to %s. It is overwritten by the new one.", - model_cls, self) + model_cls, + self, + ) self._input_processors_by_model_type[model_cls] = processor @@ -414,8 +457,9 @@ def wrapper(model_cls: N) -> N: return wrapper def _get_model_input_processor(self, model_cls: type[nn.Module]): - return self._input_processors_by_model_type \ - .get(model_cls, self._default_input_processor) + return self._input_processors_by_model_type.get( + model_cls, self._default_input_processor + ) def _ensure_mm_kwargs( self, @@ -432,8 +476,9 @@ def _ensure_mm_kwargs( else: assert_never(inputs["type"]) # type: ignore[arg-type] - def process_input(self, model_config: "ModelConfig", - inputs: ProcessorInputs) -> ProcessorInputs: + def process_input( + self, model_config: "ModelConfig", inputs: ProcessorInputs + ) -> ProcessorInputs: """ Apply an input processor to an instance of model inputs. 
@@ -463,10 +508,12 @@ def process_input(self, model_config: "ModelConfig", ) if is_encoder_decoder_inputs(processed_inputs): - self._ensure_mm_kwargs(processed_inputs["encoder"], - mm_processor_kwargs) - self._ensure_mm_kwargs(processed_inputs["decoder"], - mm_processor_kwargs) + self._ensure_mm_kwargs( + processed_inputs["encoder"], mm_processor_kwargs + ) + self._ensure_mm_kwargs( + processed_inputs["decoder"], mm_processor_kwargs + ) else: self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 741bd1a6a1c1..d887e4311de0 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -2,9 +2,15 @@ from .base import MultiModalPlaceholderMap, MultiModalPlugin from .hasher import MultiModalHashDict, MultiModalHasher -from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, - MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict, NestedTensors) +from .inputs import ( + BatchedTensorInputs, + ModalityData, + MultiModalDataBuiltins, + MultiModalDataDict, + MultiModalKwargs, + MultiModalPlaceholderDict, + NestedTensors, +) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index f379ec1682a3..f2e9ab3cad03 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -40,7 +40,8 @@ def _default_input_mapper( def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: raise NotImplementedError( - "There is no default maximum multimodal tokens") + "There is no default maximum multimodal tokens" + ) def resample_audio( @@ -53,7 +54,6 @@ def resample_audio( class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): - def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: return librosa.load(BytesIO(data), sr=None) @@ -74,4 +74,4 @@ def encode_base64(self, media: tuple[npt.NDArray, float]) -> str: soundfile.write(buffer, audio, sr, format="WAV") data = buffer.getvalue() - return base64.b64encode(data).decode('utf-8') + return base64.b64encode(data).decode("utf-8") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 5159b0bca8c1..f2e402d7e405 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -4,27 +4,43 @@ from collections import defaultdict from collections.abc import Sequence from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple, - Optional, TypeVar, Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + NamedTuple, + Optional, + TypeVar, + Union, +) from torch import nn from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, - resolve_mm_processor_kwargs) +from vllm.utils import ( + ClassRegistry, + get_allowed_kwarg_only_overrides, + resolve_mm_processor_kwargs, +) if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs, - PlaceholderRange) +from .inputs import ( + ModalityData, + MultiModalDataDict, + MultiModalKwargs, + PlaceholderRange, +) logger = init_logger(__name__) -MultiModalInputMapper = Callable[[InputContext, ModalityData[object]], - MultiModalKwargs] +MultiModalInputMapper = Callable[ + [InputContext, ModalityData[object]], MultiModalKwargs +] """ Return a dictionary to be passed as keyword arguments to 
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers @@ -104,8 +120,9 @@ def wrapper(model_cls: N) -> N: self, ) - self._input_mappers[model_cls] = (mapper - or self._default_input_mapper) + self._input_mappers[model_cls] = ( + mapper or self._default_input_mapper + ) return model_cls @@ -135,8 +152,10 @@ def map_input( mapper = self._input_mappers.get(model_cls) if mapper is None: - raise KeyError(f"No input mapper in {self} is registered for " - f"model class {model_cls.__name__}.") + raise KeyError( + f"No input mapper in {self} is registered for " + f"model class {model_cls.__name__}." + ) if mm_processor_kwargs is None: mm_processor_kwargs = {} @@ -168,8 +187,10 @@ def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: def _validate_max_multimodal_tokens(self, max_mm_tokens: int): if max_mm_tokens < 1: - raise ValueError("You should set the number of tokens to a " - f"positive integer. Found: {max_mm_tokens}") + raise ValueError( + "You should set the number of tokens to a " + f"positive integer. Found: {max_mm_tokens}" + ) def register_max_multimodal_tokens( self, @@ -196,7 +217,8 @@ def wrapper(model_cls: N) -> N: self._validate_max_multimodal_tokens(max_mm_tokens) self._max_mm_tokens[model_cls] = ( - max_mm_tokens or self._default_max_multimodal_tokens) + max_mm_tokens or self._default_max_multimodal_tokens + ) return model_cls @@ -231,8 +253,9 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: requires_kw_only=False, allow_var_kwargs=True, ) - max_mm_tokens = max_mm_tokens(InputContext(model_config), - **mm_processor_kwargs) + max_mm_tokens = max_mm_tokens( + InputContext(model_config), **mm_processor_kwargs + ) self._validate_max_multimodal_tokens(max_mm_tokens) @@ -279,8 +302,9 @@ def __init__(self): @classmethod def from_seq_group( cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> tuple[Optional[MultiModalDataDict], dict[str, - "MultiModalPlaceholderMap"]]: + ) -> tuple[ + Optional[MultiModalDataDict], dict[str, "MultiModalPlaceholderMap"] + ]: """ Returns the multi-modal items that intersect with the portion of a prompt (``seq_group``) represented by ``positions``, as well as a @@ -346,7 +370,8 @@ def from_seq_group( mm_data = {**seq_mm_data} placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( - MultiModalPlaceholderMap) + MultiModalPlaceholderMap + ) for modality, placeholders in seq_mm_placeholders.items(): mm_items = mm_data.pop(modality) @@ -354,12 +379,13 @@ def from_seq_group( mm_items = [mm_items] if positions: - intersecting_items = placeholder_maps[modality] \ - .append_items_from_seq_group( - positions, - mm_items, - placeholders, - ) + intersecting_items = placeholder_maps[ + modality + ].append_items_from_seq_group( + positions, + mm_items, + placeholders, + ) if intersecting_items: mm_data[modality] = intersecting_items @@ -382,8 +408,9 @@ def append_items_from_seq_group( raise ValueError( "Multi-modal placeholders and items must have the same length." 
) - for placeholder_dict, mm_item in zip(multi_modal_placeholders, - multi_modal_items): + for placeholder_dict, mm_item in zip( + multi_modal_placeholders, multi_modal_items + ): placeholder = range( placeholder_dict["offset"], placeholder_dict["offset"] + placeholder_dict["length"], @@ -424,11 +451,13 @@ def extend(self, other: "MultiModalPlaceholderMap"): self.src_ranges.extend( range(self.src_len + r.start, self.src_len + r.stop) - for r in other.src_ranges) + for r in other.src_ranges + ) self.src_len += other.src_len self.dest_ranges.extend( range(self.dest_len + r.start, self.dest_len + r.stop) - for r in other.dest_ranges) + for r in other.dest_ranges + ) self.dest_len += other.dest_len def index_map(self) -> "IndexMap": @@ -443,14 +472,15 @@ def index_map(self) -> "IndexMap": if len(src_indices) != len(dest_indices): raise ValueError( f"The number of source ({len(src_indices)}) and destination " - f"indices ({len(dest_indices)}) must be the same.") + f"indices ({len(dest_indices)}) must be the same." + ) - return MultiModalPlaceholderMap.IndexMap(src=src_indices, - dest=dest_indices) + return MultiModalPlaceholderMap.IndexMap( + src=src_indices, dest=dest_indices + ) class MediaIO(ABC, Generic[_T]): - @abstractmethod def load_bytes(self, data: bytes) -> _T: raise NotImplementedError diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index 11665ef66753..0311a3f0e07f 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -23,7 +23,6 @@ class MultiModalHasher: - @classmethod def serialize_item(cls, obj: object) -> bytes: # Simple cases @@ -43,8 +42,9 @@ def serialize_item(cls, obj: object) -> bytes: return obj.tobytes() logger.warning( - "No serialization method found for %s. " - "Falling back to pickle.", type(obj)) + "No serialization method found for %s. Falling back to pickle.", + type(obj), + ) return pickle.dumps(obj) @@ -79,7 +79,8 @@ def hash_kwargs(cls, **kwargs: object) -> str: @classmethod def hash_prompt_mm_data( - cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]: + cls, prompt: "TokensPrompt" + ) -> Optional["MultiModalHashDict"]: """Hash multimodal data in the user input prompt if they exist.""" if "multi_modal_data" not in prompt: diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 255fac30bd78..054f752afbb2 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -38,7 +38,8 @@ def _get_hf_image_processor( return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs) + **mm_processor_kwargs, + ) def _default_input_mapper( self, @@ -56,24 +57,27 @@ def _default_input_mapper( ) if image_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") + raise RuntimeError( + "No HuggingFace processor is available " + "to process the image object" + ) try: # NOTE: It may make sense to forward the mm_processor_kwargs # here too. For now, to keep it simple, we only allow it be # used for the initialization call though, just in case the # signatures of the preprocessor initializer don't match # preprocess() - batch_data = image_processor \ - .preprocess(data, return_tensors="pt") \ - .data + batch_data = image_processor.preprocess( + data, return_tensors="pt" + ).data except Exception: logger.error( "Failed to process image (%s) with the default mapper. 
" "This is most likely an edge-case with this model's image " "processor in transformers (type: %s), and not vLLM.", data, - type(image_processor).__name__) + type(image_processor).__name__, + ) raise return MultiModalKwargs(batch_data) @@ -88,9 +92,9 @@ def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: return 3000 -def rescale_image_size(image: Image.Image, - size_factor: float, - transpose: int = -1) -> Image.Image: +def rescale_image_size( + image: Image.Image, size_factor: float, transpose: int = -1 +) -> Image.Image: """Rescale the dimensions of an image by a constant factor.""" new_width = int(image.width * size_factor) new_height = int(image.height * size_factor) @@ -101,7 +105,6 @@ def rescale_image_size(image: Image.Image, class ImageMediaIO(MediaIO[Image.Image]): - def __init__(self, *, image_mode: str = "RGB") -> None: super().__init__() @@ -133,11 +136,10 @@ def encode_base64( image.save(buffer, image_format) data = buffer.getvalue() - return base64.b64encode(data).decode('utf-8') + return base64.b64encode(data).decode("utf-8") class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): - def __init__(self) -> None: super().__init__() @@ -152,4 +154,4 @@ def load_file(self, filepath: Path) -> torch.Tensor: return torch.load(filepath) def encode_base64(self, media: torch.Tensor) -> str: - return base64.b64encode(media.numpy()).decode('utf-8') + return base64.b64encode(media.numpy()).decode("utf-8") diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 3c609fd96765..072e753e5c2f 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -6,8 +6,17 @@ from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, - Union, cast, final) +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Optional, + TypedDict, + TypeVar, + Union, + cast, + final, +) import numpy as np import torch @@ -30,8 +39,9 @@ item, which can be passed to a HuggingFace :code:`ImageProcessor`. """ -HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor, - list[np.ndarray], list[torch.Tensor]] +HfVideoItem: TypeAlias = Union[ + list[Image], np.ndarray, torch.Tensor, list[np.ndarray], list[torch.Tensor] +] """ A :class:`transformers.image_utils.VideoInput` representing a single video item, which can be passed to a HuggingFace :code:`VideoProcessor`. @@ -63,8 +73,9 @@ these are directly passed to the model without HF processing. """ -AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], - torch.Tensor] +AudioItem: TypeAlias = Union[ + HfAudioItem, tuple[np.ndarray, float], torch.Tensor +] """ Represents a single audio item, which can be passed to a HuggingFace :code:`AudioProcessor`. @@ -132,8 +143,12 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, - tuple[torch.Tensor, ...]] +NestedTensors = Union[ + list["NestedTensors"], + list[torch.Tensor], + torch.Tensor, + tuple[torch.Tensor, ...], +] """ Uses a list instead of a tensor if the dimensions of each element do not match. 
""" @@ -147,11 +162,13 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: return isinstance(a, torch.Tensor) and torch.equal(b, a) if isinstance(a, list): - return (isinstance(b, list) - and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))) + return isinstance(b, list) and all( + nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b) + ) if isinstance(b, list): - return (isinstance(a, list) - and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))) + return isinstance(a, list) and all( + nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a) + ) # Both a and b are scalars return a == b @@ -199,9 +216,11 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False - return ((self.modality, self.key) == (other.modality, other.key) - and nested_tensors_equal(self.data, other.data) - and type(self.field) == type(other.field)) # noqa: E721 + return ( + (self.modality, self.key) == (other.modality, other.key) + and nested_tensors_equal(self.data, other.data) + and type(self.field) == type(other.field) # noqa: E721 + ) @dataclass(frozen=True) @@ -235,7 +254,7 @@ def build_elems( """ Construct :class:`MultiModalFieldElem` instances to represent the provided data. - + This is the inverse of :meth:`reduce_data`. """ raise NotImplementedError @@ -294,6 +313,7 @@ class MultiModalFlatField(BaseMultiModalField): :func:`MultiModalFieldConfig.flat` :func:`MultiModalFieldConfig.flat_from_sizes` """ + slices: Sequence[slice] def build_elems( @@ -325,6 +345,7 @@ class MultiModalSharedField(BaseMultiModalField): See also: :func:`MultiModalFieldConfig.shared` """ + batch_size: int def build_elems( @@ -341,7 +362,6 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: class MultiModalFieldConfig: - @staticmethod def batched(modality: str): """ @@ -386,7 +406,7 @@ def flat(modality: str, slices: Sequence[slice]): Example: .. code-block:: - + Given: slices: [slice(0, 3), slice(3, 7), slice(7, 9)] @@ -418,7 +438,7 @@ def flat_from_sizes(modality: str, size_per_item: torch.Tensor): Example: .. code-block:: - + Given: size_per_item: [3, 4, 2] @@ -429,14 +449,16 @@ def flat_from_sizes(modality: str, size_per_item: torch.Tensor): Element 1: [AAA] Element 2: [BBBB] Element 3: [CC] - + See also: :func:`MultiModalFieldConfig.flat` """ if size_per_item.ndim != 1: - raise ValueError("size_per_item should be a 1-D tensor, " - f"but found shape: {size_per_item.shape}") + raise ValueError( + "size_per_item should be a 1-D tensor, " + f"but found shape: {size_per_item.shape}" + ) slice_idxs = [0, *accumulate(size_per_item)] slices = [ @@ -462,7 +484,7 @@ def shared(modality: str, batch_size: int): Example: .. code-block:: - + Given: batch_size: 4 @@ -548,7 +570,8 @@ def from_hf_inputs( if len(set(batch_sizes.values())) > 1: raise ValueError( f"Cannot merge different batch sizes for {modality=}! 
" - f"Found: {batch_sizes=}") + f"Found: {batch_sizes=}" + ) batch_size = next(iter(batch_sizes.values())) for item_idx in range(batch_size): @@ -567,7 +590,8 @@ def from_items(items: Sequence[MultiModalKwargsItem]): data = { key: elems[0].field.reduce_data(elems) - for key, elems in elems_by_key.items() if len(elems) > 0 + for key, elems in elems_by_key.items() + if len(elems) > 0 } return MultiModalKwargs(data, items=items) @@ -672,19 +696,23 @@ def __eq__(self, other: object) -> bool: return False ks = self.keys() - return (ks == other.keys() - and all(nested_tensors_equal(self[k], other[k]) for k in ks)) + return ks == other.keys() and all( + nested_tensors_equal(self[k], other[k]) for k in ks + ) def _validate_modality(self, method_name: str, modality: str) -> None: if not self._items_by_modality: raise RuntimeError( f"`{method_name}` is not supported when " - "MultiModalKwargs is not initialized with `items`") + "MultiModalKwargs is not initialized with `items`" + ) if modality not in self._items_by_modality: available_modalities = set(self._items_by_modality.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") + raise KeyError( + f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}" + ) def get_item_count(self, modality: str) -> int: """Get the number of items belonging to a modality.""" diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 772b1609a9fb..0b0d4ea3212d 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -3,8 +3,15 @@ from abc import ABC, abstractmethod from collections import UserDict from collections.abc import Callable, Iterator, Mapping, Sequence -from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, - Union) +from typing import ( + TYPE_CHECKING, + Any, + Generic, + NamedTuple, + Optional, + TypeVar, + Union, +) import numpy as np import torch @@ -15,9 +22,18 @@ from vllm.utils import is_list_of from .audio import resample_audio -from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, - ImageItem, ModalityData, MultiModalDataDict, - MultiModalFieldConfig, MultiModalKwargs, VideoItem) +from .inputs import ( + AudioItem, + HfAudioItem, + HfImageItem, + HfVideoItem, + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargs, + VideoItem, +) _T = TypeVar("_T") _I = TypeVar("_I") @@ -35,8 +51,10 @@ def __init__(self, data: _T, modality: str) -> None: self.modality = modality def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"len={len(self)})") + return ( + f"{type(self).__name__}(modality={self.modality!r}, " + f"len={len(self)})" + ) def __len__(self) -> int: return self.get_count() @@ -46,8 +64,7 @@ def __getitem__(self, index: int) -> _I: if TYPE_CHECKING: # Auto-generated - def __iter__(self) -> Iterator[_I]: - ... + def __iter__(self) -> Iterator[_I]: ... @abstractmethod def get_count(self) -> int: @@ -90,8 +107,9 @@ def get_passthrough_data(self) -> Mapping[str, object]: return {} -class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], - torch.Tensor]): +class EmbeddingItems( + ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor] +): """ Base class for data items that are expressed as a batched embedding tensor, or a list of embedding tensors (one per item). 
@@ -113,8 +131,9 @@ def get_feature_size(self, item_idx: int) -> int: return len(self.get(item_idx)) -class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], - Mapping[str, torch.Tensor]]): +class DictEmbeddingItems( + ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]] +): """ Base class for data items that are expressed as a dictionary of tensors. @@ -136,8 +155,10 @@ def __init__( missing_required_data_keys = required_fields - data.keys() if missing_required_data_keys: data_keys = set(data.keys()) - msg = (f"The data should contain the fields: {required_fields}, " - f"but only found the following keys: {data_keys}") + msg = ( + f"The data should contain the fields: {required_fields}, " + f"but only found the following keys: {data_keys}" + ) raise ValueError(msg) fields_config = fields_factory(data) @@ -172,7 +193,6 @@ def get_passthrough_data(self) -> Mapping[str, object]: class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): - def __init__(self, data: Sequence[HfAudioItem]) -> None: super().__init__(data, "audio") @@ -182,7 +202,6 @@ def get_audio_length(self, item_idx: int) -> int: class AudioEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "audio") @@ -193,7 +212,6 @@ class ImageSize(NamedTuple): class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): - def __init__(self, data: Sequence[HfImageItem]) -> None: super().__init__(data, "image") @@ -210,13 +228,11 @@ def get_image_size(self, item_idx: int) -> ImageSize: class ImageEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "image") class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): - def __init__(self, data: Sequence[HfVideoItem]) -> None: super().__init__(data, "video") @@ -236,7 +252,6 @@ def get_frame_size(self, item_idx: int) -> ImageSize: class VideoEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "video") @@ -253,15 +268,17 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): def get_count(self, modality: str, *, strict: bool = True) -> int: """ Get the number of data items belonging to a modality. - + If `strict=False`, return `0` instead of raising :exc:`KeyError` even if the modality is not found. """ if modality not in self: if strict: available_modalities = set(self.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") + raise KeyError( + f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}" + ) return 0 @@ -282,20 +299,25 @@ def get_items( """ if modality not in self: available_modalities = set(self.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") + raise KeyError( + f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}" + ) items = self[modality] if not isinstance(items, typ): - raise TypeError(f"Invalid type of data items for {modality=}. " - f"Expected type: {typ}, but " - f"found type: {type(items)}") + raise TypeError( + f"Invalid type of data items for {modality=}. 
" + f"Expected type: {typ}, but " + f"found type: {type(items)}" + ) return items # type: ignore[return-value] -ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], - ModalityDataItems[Any, Any]] +ModalityDataParser: TypeAlias = Callable[ + [ModalityData[Any]], ModalityDataItems[Any, Any] +] class MultiModalDataParser: @@ -314,7 +336,7 @@ def __init__(self, *, target_sr: Optional[float] = None) -> None: self.target_sr = target_sr def _is_embeddings( - self, data: object + self, data: object ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]: if isinstance(data, torch.Tensor): return data.ndim == 3 @@ -345,10 +367,12 @@ def _parse_audio_data( if self._is_embeddings(data): return AudioEmbeddingItems(data) - if (is_list_of(data, float) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 1 - or isinstance(data, tuple)): + if ( + is_list_of(data, float) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 1 + or isinstance(data, tuple) + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -365,11 +389,12 @@ def _parse_audio_data( if target_sr is None: raise RuntimeError( "Audio resampling is not supported when " - "`target_sr` is not provided") + "`target_sr` is not provided" + ) - new_audio = resample_audio(audio, - orig_sr=orig_sr, - target_sr=target_sr) + new_audio = resample_audio( + audio, orig_sr=orig_sr, target_sr=target_sr + ) new_audios.append(new_audio) @@ -382,9 +407,11 @@ def _parse_image_data( if self._is_embeddings(data): return ImageEmbeddingItems(data) - if (isinstance(data, Image) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 3): + if ( + isinstance(data, Image) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 3 + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -400,9 +427,11 @@ def _parse_video_data( if self._is_embeddings(data): return VideoEmbeddingItems(data) - if (is_list_of(data, Image) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 4): + if ( + is_list_of(data, Image) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 4 + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -418,8 +447,7 @@ def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: "video": self._parse_video_data, } - def parse_mm_data(self, - mm_data: MultiModalDataDict) -> MultiModalDataItems: + def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems: subparsers = self._get_subparsers() mm_items = MultiModalDataItems() diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index b400e2701ac3..8dfdcc7d9d02 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,13 +3,27 @@ import sys from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, - Sequence) +from collections.abc import ( + Callable, + Generator, + ItemsView, + Iterable, + Mapping, + Sequence, +) from dataclasses import dataclass, field from enum import Enum from functools import lru_cache -from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, - TypeVar, Union, cast) +from typing import ( + TYPE_CHECKING, + Generic, + NamedTuple, + Optional, + Protocol, + TypeVar, + Union, + cast, +) import torch from cachetools import LRUCache @@ -19,16 
+33,29 @@ from vllm.inputs import InputProcessingContext from vllm.jsontree import json_map_leaves, json_reduce_leaves from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, - encode_tokens) +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + decode_tokens, + encode_tokens, +) from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby from .hasher import MultiModalHasher -from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, - MultiModalKwargsItem, PlaceholderRange) -from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, - MultiModalDataParser) +from .inputs import ( + MultiModalDataDict, + MultiModalEncDecInputs, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargs, + MultiModalKwargsItem, + PlaceholderRange, +) +from .parse import ( + DictEmbeddingItems, + EmbeddingItems, + MultiModalDataItems, + MultiModalDataParser, +) if TYPE_CHECKING: from .profiling import BaseDummyInputsBuilder @@ -44,11 +71,11 @@ @dataclass class PromptIndex: """Resolves to an index in the prompt.""" + get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] class PromptIndexTargets: - @staticmethod def start() -> PromptIndex: """ @@ -77,9 +104,9 @@ def get_match_index( else: if isinstance(prefix, str): # Make both `list[int]` - prefix = encode_tokens(tokenizer, - prefix, - add_special_tokens=False) + prefix = encode_tokens( + tokenizer, prefix, add_special_tokens=False + ) match_idx = len(prefix) return match_idx if prompt[:match_idx] == prefix else None @@ -129,8 +156,7 @@ def from_seq(seq: PromptSeq) -> "PromptUpdateDetails": use :class:`PromptUpdateDetails` to specify which part. """ -PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], - PromptUpdateInfo] +PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], PromptUpdateInfo] """ Given the index of the processed item within :attr:`modality`, output the corresponding token sequence (or text). @@ -272,11 +298,13 @@ class PromptReplacement(PromptUpdate): modality="image", target="", replacement=PromptUpdateDetails( - full="".join([ - "", - "" * image_feature_size, - "", - ]), + full="".join( + [ + "", + "" * image_feature_size, + "", + ] + ), features="" * image_feature_size, ), ) @@ -290,8 +318,11 @@ class PromptReplacement(PromptUpdate): modality="image", target=[image_token_id], replacement=PromptUpdateDetails( - full=([image_bos_id] + [image_token_id] * image_feature_size - + [image_eos_id]), + full=( + [image_bos_id] + + [image_token_id] * image_feature_size + + [image_eos_id] + ), features=[image_token_id] * image_feature_size, ), ) @@ -322,9 +353,7 @@ def _cached_encode( *, add_special_tokens: Optional[bool] = None, ) -> list[int]: - return encode_tokens(tokenizer, - text, - add_special_tokens=add_special_tokens) + return encode_tokens(tokenizer, text, add_special_tokens=add_special_tokens) @lru_cache(maxsize=2048) @@ -334,9 +363,9 @@ def _cached_decode( *, skip_special_tokens: Optional[bool] = None, ) -> str: - return decode_tokens(tokenizer, - list(token_ids), - skip_special_tokens=skip_special_tokens) + return decode_tokens( + tokenizer, list(token_ids), skip_special_tokens=skip_special_tokens + ) class _HasModalityAttr(Protocol): @@ -344,10 +373,8 @@ class _HasModalityAttr(Protocol): class _HasModalityProp(Protocol): - @property - def modality(self) -> str: - ... + def modality(self) -> str: ... 
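
For readers skimming the `PromptReplacement` example reformatted above, the split between `full` and `features` is easiest to see with concrete token IDs. A minimal, standalone sketch (the IDs and the `find_subsequence` helper are illustrative placeholders, not vLLM APIs):

IMAGE_BOS, IMAGE_TOK, IMAGE_EOS = 1001, 1002, 1003
feature_size = 4

features = [IMAGE_TOK] * feature_size     # the feature (placeholder) portion
full = [IMAGE_BOS, *features, IMAGE_EOS]  # the full sequence inserted into the prompt


def find_subsequence(haystack: list[int], needle: list[int]) -> int:
    # Return the start index of the first occurrence of `needle`, or -1.
    for start in range(len(haystack) - len(needle) + 1):
        if haystack[start:start + len(needle)] == needle:
            return start
    return -1


assert find_subsequence(full, features) == 1  # feature span sits inside the wrappers

This is the same subsequence relationship that `_iter_placeholders` later asserts when it locates the feature tokens inside the full replacement.
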
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp]) @@ -364,6 +391,7 @@ class _BoundPromptSequence: A :data:`_PromptSeq` bound to a tokenizer to automatically convert between token sequence and text representations. """ + tokenizer: AnyTokenizer = field(repr=False) _text: Optional[str] @@ -382,8 +410,9 @@ def from_seq( def __post_init__(self) -> None: if self._text is None and self._token_ids is None: - raise ValueError("At least one of 'text' and 'token_ids' must be " - "specified") + raise ValueError( + "At least one of 'text' and 'token_ids' must be specified" + ) @property def text(self) -> str: @@ -397,9 +426,9 @@ def text(self) -> str: def token_ids(self) -> list[int]: if self._token_ids is None: assert self._text is not None - self._token_ids = _cached_encode(self.tokenizer, - self._text, - add_special_tokens=False) + self._token_ids = _cached_encode( + self.tokenizer, self._text, add_special_tokens=False + ) return self._token_ids @@ -417,6 +446,7 @@ class BoundPromptUpdate: :attr:`target` and the result of :meth:`get_content` between token sequence and text representations. """ + _origin: PromptUpdate tokenizer: AnyTokenizer = field(repr=False) @@ -465,12 +495,13 @@ def get_content(self, item_idx: int) -> _BoundPromptContent: if not isinstance(content, PromptUpdateDetails): content = PromptUpdateDetails.from_seq(content) - bound_full = _BoundPromptSequence.from_seq(self.tokenizer, - content.full) - bound_features = _BoundPromptSequence.from_seq(self.tokenizer, - content.features) - bound_content = _BoundPromptContent(full=bound_full, - features=bound_features) + bound_full = _BoundPromptSequence.from_seq(self.tokenizer, content.full) + bound_features = _BoundPromptSequence.from_seq( + self.tokenizer, content.features + ) + bound_content = _BoundPromptContent( + full=bound_full, features=bound_features + ) if cache_key is not None: self._content_cache[cache_key] = bound_content @@ -557,8 +588,10 @@ def end_idx(self) -> int: raise NotImplementedError def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") + return ( + f"{type(self).__name__}(modality={self.modality!r}, " + f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})" + ) @dataclass(repr=False) @@ -639,9 +672,7 @@ def get_matches(update: BoundPromptUpdate): for match in iter_token_matches(prompt, target.token_ids) ] - return [ - match for update in prompt_updates for match in get_matches(update) - ] + return [match for update in prompt_updates for match in get_matches(update)] def find_text_matches( @@ -665,9 +696,7 @@ def get_matches(update: BoundPromptUpdate): for match in re.finditer(re.escape(target.text), prompt) ] - return [ - match for update in prompt_updates for match in get_matches(update) - ] + return [match for update in prompt_updates for match in get_matches(update)] def _resolve_matches( @@ -685,9 +714,11 @@ def _resolve_matches( for match in matches: for idx in range(match.start_idx, match.end_idx): if seen_matches[idx] is not None: - raise ValueError("Found overlapping matches " - f"({seen_matches[idx]} and {match}) " - f"at index={idx} of prompt={prompt}") + raise ValueError( + "Found overlapping matches " + f"({seen_matches[idx]} and {match}) " + f"at index={idx} of prompt={prompt}" + ) seen_matches[idx] = match @@ -730,8 +761,11 @@ def _apply_matches( for item_idx in range(item_start_idx, item_end_idx): content = origin.get_content(item_idx) - insert_seq = (content.full.text if 
isinstance(prompt, str) else - content.full.token_ids) + insert_seq = ( + content.full.text + if isinstance(prompt, str) + else content.full.token_ids + ) out_seqs.append(insert_seq) @@ -811,8 +845,10 @@ def _iter_placeholders( try: match = next( - iter_token_matches(content_tokens_full, - content_tokens_feat)) + iter_token_matches( + content_tokens_full, content_tokens_feat + ) + ) yield PlaceholderFeaturesInfo( modality=modality, item_idx=item_idx, @@ -822,7 +858,8 @@ def _iter_placeholders( except StopIteration: raise AssertionError( f"{content_tokens_feat=} should be a " - f"subsequence of {content_tokens_full=}") from None + f"subsequence of {content_tokens_full=}" + ) from None # Exclude overlapping matches start_idx = end_idx_full @@ -850,13 +887,11 @@ def find_mm_placeholders( class ProcessingCache: - @staticmethod def get_lru_cache( capacity_gb: int, value_type: type[_V], ) -> LRUCache[str, _V]: - def get_size(leaf: object) -> int: if isinstance(leaf, torch.Tensor): return leaf.nbytes # sys.getsizeof doesn't work for tensors @@ -888,8 +923,10 @@ def _maybe_log_cache_stats(self) -> None: total = self.debug_cache_total if total > 0 and total % steps == 0: - logger.debug("ProcessingCache: hit_ratio = %.2f", - self.debug_cache_hits / total) + logger.debug( + "ProcessingCache: hit_ratio = %.2f", + self.debug_cache_hits / total, + ) def get( self, @@ -909,9 +946,9 @@ def get( """ self._maybe_log_cache_stats() - cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: input_item}, - **input_kwargs) + cache_key = MultiModalHasher.hash_kwargs( + model_id=model_id, **{modality: input_item}, **input_kwargs + ) if self.debug_cache_hit_ratio_steps: if cache_key in self._cache: @@ -933,9 +970,9 @@ def put( Put a processed multi-modal item into the cache according to its dependencies (see :meth:`get`). """ - cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: input_item}, - **input_kwargs) + cache_key = MultiModalHasher.hash_kwargs( + model_id=model_id, **{modality: input_item}, **input_kwargs + ) self._cache[cache_key] = output_kwargs @@ -1002,16 +1039,20 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Not to be confused with :class:`transformers.ProcessorMixin`. """ - def __init__(self, - info: _I, - dummy_inputs: "BaseDummyInputsBuilder[_I]", - *, - cache: Optional[ProcessingCache] = None, - enable_sanity_checks: bool = True) -> None: + def __init__( + self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True, + ) -> None: if get_repls := getattr(self, "_get_prompt_replacements", None): - logger.warning_once("`_get_prompt_replacements` has been renamed " - "to `_get_prompt_updates`. The old name will " - "be removed in an upcoming release.") + logger.warning_once( + "`_get_prompt_replacements` has been renamed " + "to `_get_prompt_updates`. The old name will " + "be removed in an upcoming release." + ) self._get_prompt_updates = get_repls # type: ignore[method-assign] super().__init__() @@ -1058,7 +1099,8 @@ def _to_mm_items( raise ValueError( f"You set {modality}={limit} (or defaulted to 1) in " f"`--limit-mm-per-prompt`, but passed {len(items)} " - f"{modality} items in the same prompt.") + f"{modality} items in the same prompt." 
+ ) return mm_items @@ -1099,8 +1141,9 @@ def _find_mm_placeholders( new_token_ids: list[int], mm_item_counts: Mapping[str, int], ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - return find_mm_placeholders(mm_prompt_updates, new_token_ids, - mm_item_counts) + return find_mm_placeholders( + mm_prompt_updates, new_token_ids, mm_item_counts + ) def _get_hf_mm_data( self, @@ -1148,7 +1191,8 @@ def _hf_processor_applies_updates( """ return not any( isinstance(items, (EmbeddingItems, DictEmbeddingItems)) - for items in mm_items.values()) + for items in mm_items.values() + ) def _apply_hf_processor_text_mm( self, @@ -1171,7 +1215,7 @@ def _apply_hf_processor_text_mm( ) processed_data.update(passthrough_data) - prompt_ids, = processed_data.pop("input_ids").tolist() + (prompt_ids,) = processed_data.pop("input_ids").tolist() mm_kwargs = MultiModalKwargs.from_hf_inputs( processed_data, @@ -1313,8 +1357,7 @@ def _cached_apply_hf_processor( } mm_missing_idxs = { - modality: - [idx for idx, item in enumerate(kw_items) if item is None] + modality: [idx for idx, item in enumerate(kw_items) if item is None] for modality, kw_items in mm_maybe_cached_kw_items.items() } mm_missing_data = { @@ -1338,8 +1381,7 @@ def _cached_apply_hf_processor( ) mm_missing_next_idx = { - modality: 0 - for modality in mm_missing_data_items + modality: 0 for modality in mm_missing_data_items } merged_kw_items = list[MultiModalKwargsItem]() @@ -1367,9 +1409,11 @@ def _cached_apply_hf_processor( mm_missing_counts = mm_missing_data_items.get_all_counts() assert all( item_count == mm_missing_counts[modality] - for modality, item_count in mm_missing_next_idx.items()), dict( - mm_missing_next_idx=mm_missing_next_idx, - mm_missing_counts=mm_missing_counts) + for modality, item_count in mm_missing_next_idx.items() + ), dict( + mm_missing_next_idx=mm_missing_next_idx, + mm_missing_counts=mm_missing_counts, + ) mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) @@ -1455,9 +1499,7 @@ def _apply_prompt_updates( mm_item_counts, ) - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) + token_ids = encode_tokens(tokenizer, text, add_special_tokens=False) matched_updates = { modality: [match._origin for match in token_matches] for modality, token_matches in mm_text_matches.items() @@ -1490,7 +1532,8 @@ def _validate_mm_kwargs( "There is likely a problem with your " "implementation of merged multi-modal processor for this " "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_mm_fields_config`).") + "`_call_hf_processor` and `_get_mm_fields_config`)." + ) def _validate_mm_placeholders( self, @@ -1509,7 +1552,8 @@ def _validate_mm_placeholders( "multi-modal inputs, or there is a problem with your " "implementation of merged multi-modal processor for this " "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_prompt_updates`).") + "`_call_hf_processor` and `_get_prompt_updates`)." 
+ ) def apply( self, @@ -1541,9 +1585,11 @@ def apply( model_id = self.info.model_id mm_hashes = { modality: [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs) + MultiModalHasher.hash_kwargs( + model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs, + ) for item in items ] for modality, items in mm_items.items() @@ -1566,8 +1612,7 @@ def apply( hf_processor_mm_kwargs, mm_kwargs, ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) + mm_prompt_updates = self._bind_and_group_updates(unbound_prompt_updates) mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) @@ -1610,7 +1655,6 @@ def apply( class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): - @abstractmethod def create_encoder_prompt( self, @@ -1618,7 +1662,7 @@ def create_encoder_prompt( mm_data: MultiModalDataDict, ) -> Union[str, list[int]]: """ - Create input prompt for the encoder. HF processor will be applied on + Create input prompt for the encoder. HF processor will be applied on this prompt during profiling and generation. """ raise NotImplementedError @@ -1656,9 +1700,9 @@ def apply( tokenizer = self.info.get_tokenizer() decoder_prompt = self.create_decoder_prompt(prompt, mm_data) if isinstance(decoder_prompt, str): - decoder_prompt_ids = encode_tokens(tokenizer, - decoder_prompt, - add_special_tokens=False) + decoder_prompt_ids = encode_tokens( + tokenizer, decoder_prompt, add_special_tokens=False + ) else: decoder_prompt_ids = decoder_prompt decoder_prompt = decode_tokens(tokenizer, decoder_prompt) @@ -1666,9 +1710,9 @@ def apply( mm_inputs = MultiModalEncDecInputs( encoder_prompt=encoder_inputs["prompt"], encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], - **encoder_inputs) - mm_inputs.update({ - "prompt": decoder_prompt, - "prompt_token_ids": decoder_prompt_ids - }) + **encoder_inputs, + ) + mm_inputs.update( + {"prompt": decoder_prompt, "prompt_token_ids": decoder_prompt_ids} + ) return mm_inputs diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 7b4fb5eb598d..0a7d3746603d 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -13,8 +13,7 @@ from vllm.inputs import DummyData from vllm.logger import init_logger -from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs) +from .inputs import MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs from .processing import BaseMultiModalProcessor, BaseProcessingInfo logger = init_logger(__name__) @@ -26,6 +25,7 @@ class ProcessorInputs: Represents the keyword arguments to :meth:`vllm.multimodal.processing.BaseMultiModalProcessor.apply`. """ + prompt_text: str mm_data: MultiModalDataDict hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) @@ -63,7 +63,7 @@ def _get_dummy_audios( length: int, num_audios: int, ) -> list[npt.NDArray]: - audio = np.zeros((length, )) + audio = np.zeros((length,)) return [audio] * num_audios def _get_dummy_images( @@ -124,7 +124,8 @@ def get_mm_limits(self) -> Mapping[str, int]: raise ValueError( f"You set {modality}={limit} (or defaulted to 1) in " f"`--limit-mm-per-prompt`, but this model only supports " - f"at most {supported_limit} {modality} items.") + f"at most {supported_limit} {modality} items." 
+ ) return mm_limits @@ -135,7 +136,8 @@ def _get_dummy_mm_inputs( ) -> MultiModalInputs: factory = self.dummy_inputs processor_inputs = factory.get_dummy_processor_inputs( - seq_len, mm_counts) + seq_len, mm_counts + ) return self.processor.apply( prompt=processor_inputs.prompt_text, @@ -151,14 +153,16 @@ def get_and_validate_mm_inputs( info = self.processing_info mm_max_tokens_per_item = info.get_mm_max_tokens_per_item( - seq_len, mm_counts) + seq_len, mm_counts + ) if mm_counts.keys() != mm_max_tokens_per_item.keys(): raise AssertionError( "The keys returned by `get_supported_mm_limits` " f"({set(mm_counts.keys())}) should be the same as those " "returned by `get_mm_max_tokens_per_item` " - f"({set(mm_max_tokens_per_item.keys())})") + f"({set(mm_max_tokens_per_item.keys())})" + ) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) placeholders_by_modality = mm_inputs["mm_placeholders"] @@ -176,7 +180,8 @@ def get_and_validate_mm_inputs( f"The processed dummy data has a total of " f"{total_placeholders_by_modality} placeholder tokens, which " f"is not the expected {expected_placeholders_by_modality} " - "tokens.") + "tokens." + ) return mm_inputs, total_placeholders_by_modality def get_encoder_dummy_data( @@ -210,8 +215,9 @@ def get_decoder_dummy_data( # Avoid circular import from vllm.sequence import SequenceData - (mm_inputs, total_placeholders_by_modality - ) = self.get_and_validate_mm_inputs(seq_len) + (mm_inputs, total_placeholders_by_modality) = ( + self.get_and_validate_mm_inputs(seq_len) + ) prompt_token_ids = mm_inputs["prompt_token_ids"] total_len = len(prompt_token_ids) @@ -228,8 +234,11 @@ def get_decoder_dummy_data( "multi-modal inputs to fail during inference, even when " "the input text is short. To avoid this, you should " "increase `max_model_len`, reduce `max_num_seqs`, " - "and/or reduce `mm_counts`.", seq_len, total_len, - total_placeholders_by_modality) + "and/or reduce `mm_counts`.", + seq_len, + total_len, + total_placeholders_by_modality, + ) return DummyData( seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 24b835898279..3f00b4441f71 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -11,16 +11,21 @@ from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.inputs import InputProcessingContext from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - cached_tokenizer_from_config) +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + cached_tokenizer_from_config, +) from vllm.utils import ClassRegistry from .audio import AudioPlugin from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - ProcessingCache) +from .processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + ProcessingCache, +) from .profiling import BaseDummyInputsBuilder, MultiModalProfiler from .video import VideoPlugin @@ -40,8 +45,7 @@ class ProcessingInfoFactory(Protocol[_I_co]): def __call__( self, ctx: InputProcessingContext, - ) -> _I_co: - ... + ) -> _I_co: ... class DummyInputsBuilderFactory(Protocol[_I]): @@ -49,8 +53,7 @@ class DummyInputsBuilderFactory(Protocol[_I]): Constructs a :class:`BaseDummyInputsBuilder` instance from the context. """ - def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: - ... 
+ def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ... class MultiModalProcessorFactory(Protocol[_I]): @@ -62,8 +65,7 @@ def __call__( dummy_inputs: BaseDummyInputsBuilder[_I], *, cache: Optional[ProcessingCache] = None, - ) -> BaseMultiModalProcessor[_I]: - ... + ) -> BaseMultiModalProcessor[_I]: ... @dataclass(frozen=True) @@ -93,8 +95,10 @@ def __getitem__(self, key: "ModelConfig") -> dict[str, int]: try: return super().__getitem__(key) except KeyError as exc: - msg = (f"Cannot find `mm_limits` for model={key.model}. Did you " - "forget to call `init_mm_limits_per_prompt`?") + msg = ( + f"Cannot find `mm_limits` for model={key.model}. Did you " + "forget to call `init_mm_limits_per_prompt`?" + ) raise KeyError(msg) from exc @@ -106,13 +110,13 @@ class MultiModalRegistry: DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin()) def __init__( - self, - *, - plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: + self, *, plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS + ) -> None: self._plugins = {p.get_data_key(): p for p in plugins} - self._processor_factories = ClassRegistry[nn.Module, - _ProcessorFactories]() + self._processor_factories = ClassRegistry[ + nn.Module, _ProcessorFactories + ]() # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} @@ -130,8 +134,10 @@ def register_plugin(self, plugin: MultiModalPlugin) -> None: if data_type_key in self._plugins: logger.warning( "A plugin is already registered for data type %s, " - "and will be overwritten by the new plugin %s.", data_type_key, - plugin) + "and will be overwritten by the new plugin %s.", + data_type_key, + plugin, + ) self._plugins[data_type_key] = plugin @@ -195,15 +201,19 @@ def map_input( raise ValueError( f"You set {data_key}={max_items} (or defaulted to 1) in " f"`--limit-mm-per-prompt`, but found {num_items} items " - "in the same prompt.") + "in the same prompt." + ) - input_dict = plugin.map_input(model_config, data_value, - mm_processor_kwargs) + input_dict = plugin.map_input( + model_config, data_value, mm_processor_kwargs + ) for input_key, input_tensor in input_dict.items(): if input_key in merged_dict: - raise ValueError(f"The input mappers (keys={set(data)}) " - f"resulted in a conflicting keyword " - f"argument to `forward()`: {input_key}") + raise ValueError( + f"The input mappers (keys={set(data)}) " + f"resulted in a conflicting keyword " + f"argument to `forward()`: {input_key}" + ) merged_dict[input_key] = input_tensor @@ -234,8 +244,9 @@ def register_max_multimodal_tokens( instance of multimodal data belonging to a specific modality, that are passed to the language model for a model class. """ - return self._get_plugin(data_type_key) \ - .register_max_multimodal_tokens(max_mm_tokens) + return self._get_plugin(data_type_key).register_max_multimodal_tokens( + max_mm_tokens + ) def register_max_image_tokens( self, @@ -252,18 +263,17 @@ def get_max_tokens_per_item_by_modality( model_config: "ModelConfig", ) -> Mapping[str, int]: """ - Get the maximum number of tokens per data item from each modality based + Get the maximum number of tokens per data item from each modality based on underlying model configuration. 
""" if self.has_processor(model_config): tokenizer = cached_tokenizer_from_config(model_config) - processor = self.create_processor(model_config, - tokenizer, - disable_cache=True) + processor = self.create_processor( + model_config, tokenizer, disable_cache=True + ) seq_len = model_config.max_model_len mm_limits = self.get_mm_limits_per_prompt(model_config) - return processor.info.get_mm_max_tokens_per_item( - seq_len, mm_limits) + return processor.info.get_mm_max_tokens_per_item(seq_len, mm_limits) return { key: plugin.get_max_multimodal_tokens(model_config) @@ -276,19 +286,20 @@ def get_max_tokens_per_item_by_nonzero_modality( ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based - on underlying model configuration, excluding modalities that user + on underlying model configuration, excluding modalities that user explicitly disabled via `limit_mm_per_prompt`. Note: - This is currently directly used only in V1 for profiling the memory + This is currently directly used only in V1 for profiling the memory usage of a model. """ mm_limits = self.get_mm_limits_per_prompt(model_config) return { key: max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in self.get_max_tokens_per_item_by_modality( + model_config + ).items() if mm_limits[key] > 0 } @@ -309,8 +320,9 @@ def get_max_tokens_by_modality( return { key: mm_limits[key] * max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in self.get_max_tokens_per_item_by_modality( + model_config + ).items() } def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: @@ -336,7 +348,9 @@ def init_mm_limits_per_prompt( if model_config in self._limits_by_model: logger.warning( "`mm_limits` has already been set for model=%s, and will " - "be overwritten by the new values.", model_config.model) + "be overwritten by the new values.", + model_config.model, + ) multimodal_config = model_config.multimodal_config if multimodal_config is None: @@ -349,7 +363,9 @@ def init_mm_limits_per_prompt( logger.warning( "Detected extra keys in `--limit-mm-per-prompt` which " "are not registered as multi-modal plugins: %s. " - "They will be ignored.", extra_keys) + "They will be ignored.", + extra_keys, + ) # NOTE: Currently the default is set to 1 for each plugin # TODO: Automatically determine the limits based on budget @@ -374,9 +390,9 @@ def get_mm_limits_per_prompt( """ if self.has_processor(model_config): tokenizer = cached_tokenizer_from_config(model_config) - processor = self.create_processor(model_config, - tokenizer, - disable_cache=True) + processor = self.create_processor( + model_config, tokenizer, disable_cache=True + ) profiler = MultiModalProfiler(processor) return profiler.get_mm_limits() @@ -405,7 +421,9 @@ def wrapper(model_cls: N) -> N: logger.warning( "Model class %s already has a multi-modal processor " "registered to %s. 
It is overwritten by the new one.", - model_cls, self) + model_cls, + self, + ) self._processor_factories[model_cls] = _ProcessorFactories( info=info, diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index ad381e1d1d00..9a55273f2051 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -31,7 +31,6 @@ class MediaConnector: - def __init__( self, connection: HTTPConnection = global_http_connection, @@ -48,11 +47,13 @@ def __init__( if not allowed_local_media_path_.exists(): raise ValueError( "Invalid `--allowed-local-media-path`: The path " - f"{allowed_local_media_path_} does not exist.") + f"{allowed_local_media_path_} does not exist." + ) if not allowed_local_media_path_.is_dir(): raise ValueError( "Invalid `--allowed-local-media-path`: The path " - f"{allowed_local_media_path_} must be a directory.") + f"{allowed_local_media_path_} must be a directory." + ) else: allowed_local_media_path_ = None @@ -79,14 +80,16 @@ def _load_file_url( ) -> _M: allowed_local_media_path = self.allowed_local_media_path if allowed_local_media_path is None: - raise RuntimeError("Cannot load local files without " - "`--allowed-local-media-path`.") + raise RuntimeError( + "Cannot load local files without `--allowed-local-media-path`." + ) filepath = Path(url_spec.path) if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( f"The file path {filepath} must be a subpath " - f"of `--allowed-local-media-path` {allowed_local_media_path}.") + f"of `--allowed-local-media-path` {allowed_local_media_path}." + ) return media_io.load_file(filepath) @@ -333,10 +336,14 @@ def repeat_and_pad_placeholder_tokens( new_prompt = None else: placeholder_token_str = tokenizer.decode(placeholder_token_id) - pad_token_str_left = (None if pad_token_left is None else - tokenizer.decode(pad_token_left)) - pad_token_str_right = (None if pad_token_right is None else - tokenizer.decode(pad_token_right)) + pad_token_str_left = ( + None if pad_token_left is None else tokenizer.decode(pad_token_left) + ) + pad_token_str_right = ( + None + if pad_token_right is None + else tokenizer.decode(pad_token_right) + ) placeholder_token_count = prompt.count(placeholder_token_str) # This is an arbitrary number to distinguish between the two cases @@ -344,16 +351,20 @@ def repeat_and_pad_placeholder_tokens( logger.warning( "Please follow the prompt format that is " "documented on HuggingFace which does not involve " - "repeating %s tokens.", placeholder_token_str) + "repeating %s tokens.", + placeholder_token_str, + ) if placeholder_token_count < len(repeat_count): logger.warning( "The number of multi-modal placeholder tokens in the prompt " "is less than the number of multi-modal inputs. 
Extra " - "placeholder tokens will be treated as plain text") + "placeholder tokens will be treated as plain text" + ) repeat_count = repeat_count[:placeholder_token_count] - prompt_parts = prompt.split(placeholder_token_str, - maxsplit=len(repeat_count)) + prompt_parts = prompt.split( + placeholder_token_str, maxsplit=len(repeat_count) + ) new_prompt = "" for i, repeat_count_item in enumerate(repeat_count): replacement_str = "".join( @@ -362,7 +373,8 @@ def repeat_and_pad_placeholder_tokens( repeat_count=repeat_count_item, pad_token_left=pad_token_str_left, pad_token_right=pad_token_str_right, - )) + ) + ) # The image tokens are removed to be consistent with HuggingFace new_prompt += prompt_parts[i] + replacement_str new_prompt += prompt_parts[-1] @@ -382,16 +394,18 @@ def repeat_and_pad_placeholder_tokens( offset = len(new_token_ids) if pad_token_left is not None: offset += 1 - placeholder_ranges.append({ - "offset": offset, - "length": curr_repeat_count, - }) + placeholder_ranges.append( + { + "offset": offset, + "length": curr_repeat_count, + } + ) new_token_ids.extend(replacement_ids) placeholder_token_idx += 1 # No need to further scan the list since we replaced all tokens if placeholder_token_idx >= len(repeat_count): - new_token_ids.extend(prompt_token_ids[i + 1:]) + new_token_ids.extend(prompt_token_ids[i + 1 :]) break else: new_token_ids.append(token) @@ -400,14 +414,15 @@ def repeat_and_pad_placeholder_tokens( def consecutive_placeholder_ranges( - num_items: int, - item_size: int, - initial_offset: int = 0) -> list[PlaceholderRange]: + num_items: int, item_size: int, initial_offset: int = 0 +) -> list[PlaceholderRange]: """Returns a list of consecutive PlaceholderRanges of a fixed size""" return [ - PlaceholderRange(offset=initial_offset + i * item_size, - length=item_size) for i in range(num_items) + PlaceholderRange( + offset=initial_offset + i * item_size, length=item_size + ) + for i in range(num_items) ] @@ -416,23 +431,23 @@ def merge_and_sort_multimodal_metadata( mm_hashes: Optional["MultiModalHashDict"], ) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]: """Given a MultiModalPlaceholderDict, merge all PlaceholderRange - objects from all available modalities into a single list of - PlaceholderRange, sorted by their offset (starting index in the input + objects from all available modalities into a single list of + PlaceholderRange, sorted by their offset (starting index in the input sequence) in the ascending order. - Optionally if a MultiModalHashDict is given, same operation will be + Optionally if a MultiModalHashDict is given, same operation will be applied to the object and the sorted list of hashes will be returned. Raises: ValueError: If the input prompt has interleaved placeholders from - different modalities (e.g, "