From f547f71bf6e3b93a50cb3b44ec3ca986f70382f0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 27 May 2025 06:41:57 +0000 Subject: [PATCH 01/15] [Core] Refactor dtype resolution Signed-off-by: DarkLight1337 --- vllm/config.py | 123 ++++++++++++++++++++++++------------------ vllm/platforms/cpu.py | 2 +- 2 files changed, 72 insertions(+), 53 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4196684639ee..62fdcbf259eb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3043,13 +3043,31 @@ def compute_hash(self) -> str: "bfloat16": torch.bfloat16, } -_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] # +# model_type -> reason +_FLOAT16_NOT_SUPPORTED_MODELS = { + "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", + "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", + "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", +} -def _get_and_verify_dtype( - config: PretrainedConfig, - dtype: Union[str, torch.dtype], -) -> torch.dtype: +def _is_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: + return False + + return True + + +def _check_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: + reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] + raise ValueError(f"The model type {model_type!r} " + f"does not support float16. Reason: {reason}") + + return True + + +def _find_config_dtype(config: PretrainedConfig) -> torch.dtype: # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. config_dtype = getattr(config, "torch_dtype", None) @@ -3064,72 +3082,73 @@ def _get_and_verify_dtype( if config_dtype is None: config_dtype = torch.float32 - if isinstance(dtype, str): - dtype = dtype.lower() - if dtype == "auto": - # Set default dtype from model config - if config_dtype == torch.float32: - # Following common practice, we use float16 for float32 models - torch_dtype = torch.float16 - else: - torch_dtype = config_dtype + return config_dtype - if config.model_type == "plamo2": - logger.warning( - "For PLaMo2, we cast models to bfloat16 instead of using " - "float16 by default. This is because float16 does not work." - ) - torch_dtype = torch.bfloat16 - # Deal with torch dtype fallback for device compatibility. - from vllm.platforms import current_platform - if torch_dtype not in current_platform.supported_dtypes: - device_name = current_platform.get_device_name() +def _resolve_auto_dtype(model_type: str, config_dtype: torch.dtype): + from vllm.platforms import current_platform - if ((capability := current_platform.get_device_capability()) - is None): - compute_str = "" - else: - version_str = capability.as_version_str() - compute_str = f" (with compute capability {version_str})" - fallback_dtype = current_platform.supported_dtypes[0] - logger.warning( - "Your %s device%s doesn't support %s. " \ - "Falling back to %s for compatibility.", - device_name, compute_str, torch_dtype, fallback_dtype - ) - torch_dtype = fallback_dtype + platform_dtype = next(dtype for dtype in current_platform.supported_dtypes + if _is_valid_dtype(model_type, dtype)) - if current_platform.is_hpu() and torch_dtype == torch.float16: - logger.warning( - "For HPU, we cast models to bfloat16 instead of " - "using float16 by default. 
Please specify `dtype` if you " - "want to use float16.") - torch_dtype = torch.bfloat16 - elif dtype == "float16" and config.model_type == "plamo2": - logger.warning( - "For PLaMo2, using float16 is unstable and might cause " - "unexpected behavior. Please use bfloat16 or float32 instead.") - torch_dtype = torch.float16 + # Downcast to platform's default for float32 models + if config_dtype == torch.float32: + return platform_dtype + + # Ensure device compatibility + if config_dtype in current_platform.supported_dtypes: + return config_dtype + + device_name = current_platform.get_device_name() + device_capability = current_platform.get_device_capability() + + if device_capability is None: + device_str = f"{device_name!r}" + else: + version_str = device_capability.as_version_str() + device_str = f"{device_name!r} (with compute capability {version_str})" + + logger.warning( + "Your device %s doesn't support %s. " + "Falling back to %s for compatibility.", + device_str, + config_dtype, + platform_dtype, + ) + + return platform_dtype + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + config_dtype = _find_config_dtype(config) + model_type = config.model_type + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + # Set default dtype from model config + torch_dtype = _resolve_auto_dtype(model_type, config_dtype) else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: - raise ValueError(f"Unknown dtype: {dtype}") + raise ValueError(f"Unknown dtype: {dtype!r}") torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] elif isinstance(dtype, torch.dtype): torch_dtype = dtype else: raise ValueError(f"Unknown dtype: {dtype}") - # Verify the dtype. + _check_valid_dtype(model_type, torch_dtype) + if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) - pass elif config_dtype == torch.float32: # Downcasting from float32 to float16 or bfloat16 is allowed. logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) - pass else: # Casting between float16 and bfloat16 is allowed with a warning. 
logger.warning("Casting %s to %s.", config_dtype, torch_dtype) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index c79c603c02eb..eaffaac78cce 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -28,7 +28,7 @@ class CpuPlatform(Platform): dispatch_key: str = "CPU" @property - def supported_dtypes(self) -> list: + def supported_dtypes(self) -> list[torch.dtype]: if self.get_cpu_architecture() == CpuArchEnum.POWERPC: return [torch.bfloat16, torch.float32] elif sys.platform.startswith( From 185db656dff74ba7bc179b56af5c81e5bd27c7d8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 27 May 2025 07:01:23 +0000 Subject: [PATCH 02/15] Add noqa Signed-off-by: DarkLight1337 --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 62fdcbf259eb..170e24943eff 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3052,7 +3052,7 @@ def compute_hash(self) -> str: def _is_valid_dtype(model_type: str, dtype: torch.dtype): - if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103 return False return True From efd2d516b84ad96c936422bffeb6cdf143eb5334 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 27 May 2025 13:42:41 +0000 Subject: [PATCH 03/15] Try to read dtype of safetensors weights Signed-off-by: DarkLight1337 --- tests/conftest.py | 6 +- .../multimodal/processing/test_common.py | 2 +- tests/test_utils.py | 90 ++++++++++++++----- vllm/config.py | 46 ++++++++-- vllm/transformers_utils/config.py | 40 +++++++-- vllm/utils.py | 36 ++++++++ 6 files changed, 181 insertions(+), 39 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 19c2c6247129..c99faf11ad62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -323,7 +323,11 @@ def __init__( trust_remote_code=True, ) self.device = self.get_default_device() - self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype) + self.dtype = torch_dtype = _get_and_verify_dtype( + self.model_name, + self.config, + dtype, + ) model_kwargs = model_kwargs if model_kwargs is not None else {} model_kwargs.setdefault("torch_dtype", torch_dtype) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 572fa366d332..d7f950c23d95 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -40,7 +40,7 @@ def _test_processing_correctness( tokenizer_mode=model_info.tokenizer_mode, trust_remote_code=model_info.trust_remote_code, seed=0, - dtype="float16", + dtype="auto", revision=None, hf_overrides=model_info.hf_overrides, ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0b88d05efeaa..879767c8cadf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,9 +18,9 @@ from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, bind_kv_cache, deprecate_kwargs, get_open_port, - make_zmq_path, make_zmq_socket, memory_profiling, - merge_async_iterators, sha256, split_zmq_path, - supports_kw, swap_dict_values) + is_lossless_cast, make_zmq_path, make_zmq_socket, + memory_profiling, merge_async_iterators, sha256, + split_zmq_path, supports_kw, swap_dict_values) from .utils import create_new_process_for_each_test, error_on_warning @@ -567,12 +567,50 @@ def test_lru_cache(): assert 6 in cache +# yapf: disable +@pytest.mark.parametrize( + 
("src_dtype", "tgt_dtype", "expected_result"), + [ + # Different safe levels + (torch.bool, torch.int8, True), + (torch.bool, torch.float16, True), + (torch.bool, torch.complex32, True), + (torch.int64, torch.bool, False), + (torch.int64, torch.float16, True), + (torch.int64, torch.complex32, True), + (torch.float64, torch.bool, False), + (torch.float64, torch.int8, False), + (torch.float64, torch.complex32, True), + (torch.complex128, torch.bool, False), + (torch.complex128, torch.int8, False), + (torch.complex128, torch.float16, False), + # precision_level=0 + (torch.bool, torch.bool, True), + # precision_level=1 + (torch.int8, torch.int16, True), + (torch.int16, torch.int8, False), + (torch.uint8, torch.int8, False), + (torch.int8, torch.uint8, False), + # precision_level=2 + (torch.float16, torch.float32, True), + (torch.float32, torch.float16, False), + (torch.bfloat16, torch.float32, True), + (torch.float32, torch.bfloat16, False), + # precision_level=3 + (torch.complex32, torch.complex64, True), + (torch.complex64, torch.complex32, False), + ], +) +# yapf: enable +def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): + assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result + + def test_placeholder_module_error_handling(): placeholder = PlaceholderModule("placeholder_1234") def build_ctx(): - return pytest.raises(ModuleNotFoundError, - match="No module named") + return pytest.raises(ModuleNotFoundError, match="No module named") with build_ctx(): int(placeholder) @@ -608,6 +646,7 @@ def build_ctx(): _ = placeholder_attr.module +# yapf: disable @pytest.mark.parametrize( "obj,key1,key2", [ @@ -618,6 +657,7 @@ def build_ctx(): # Tests for both keys do not exist ({1: "a", 2: "b"}, 3, 4), ]) +# yapf: enable def test_swap_dict_values(obj, key1, key2): original_obj = obj.copy() swap_dict_values(obj, key1, key2) @@ -631,19 +671,19 @@ def test_swap_dict_values(obj, key1, key2): assert key1 not in obj -def test_model_specification(parser_with_config, - cli_config_file, +def test_model_specification(parser_with_config, cli_config_file, cli_config_file_with_model): # Test model in CLI takes precedence over config - args = parser_with_config.parse_args([ - 'serve', 'cli-model', '--config', cli_config_file_with_model - ]) + args = parser_with_config.parse_args( + ['serve', 'cli-model', '--config', cli_config_file_with_model]) assert args.model_tag == 'cli-model' assert args.served_model_name == 'mymodel' # Test model from config file works args = parser_with_config.parse_args([ - 'serve', '--config', cli_config_file_with_model, + 'serve', + '--config', + cli_config_file_with_model, ]) assert args.model == 'config-model' assert args.served_model_name == 'mymodel' @@ -654,17 +694,19 @@ def test_model_specification(parser_with_config, # Test using --model option raises error with pytest.raises( - ValueError, - match=( - "With `vllm serve`, you should provide the model as a positional " - "argument or in a config file instead of via the `--model` option." 
- ), + ValueError, + match= + ("With `vllm serve`, you should provide the model as a positional " + "argument or in a config file instead of via the `--model` option."), ): parser_with_config.parse_args(['serve', '--model', 'my-model']) # Test other config values are preserved args = parser_with_config.parse_args([ - 'serve', 'cli-model', '--config', cli_config_file_with_model, + 'serve', + 'cli-model', + '--config', + cli_config_file_with_model, ]) assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True @@ -673,7 +715,7 @@ def test_model_specification(parser_with_config, @pytest.mark.parametrize("input", [(), ("abc", ), (None, ), - (None, bool, [1, 2, 3])]) + (None, bool, [1, 2, 3])]) @pytest.mark.parametrize("output", [0, 1, 2]) def test_sha256(input: tuple, output: int): hash = sha256(input) @@ -682,7 +724,8 @@ def test_sha256(input: tuple, output: int): assert hash != 0 bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big") + assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), + byteorder="big") # hashing again, returns the same value assert hash == sha256(input) @@ -698,8 +741,7 @@ def test_sha256(input: tuple, output: int): ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address ("inproc://some_identifier", ("inproc", "some_identifier", "")), - ] -) + ]) def test_split_zmq_path(path, expected): assert split_zmq_path(path) == expected @@ -711,8 +753,7 @@ def test_split_zmq_path(path, expected): "tcp://127.0.0.1", # Missing port "tcp://[::1]", # Missing port for IPv6 "tcp://:5555", # Missing host - ] -) + ]) def test_split_zmq_path_invalid(invalid_path): with pytest.raises(ValueError): split_zmq_path(invalid_path) @@ -734,7 +775,8 @@ def test_make_zmq_socket_ipv6(): zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) # Verify that the IPV6 option is set - assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" + assert zsock.getsockopt( + zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" # Clean up zsock.close() diff --git a/vllm/config.py b/vllm/config.py index 170e24943eff..0f853bac67bc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -21,6 +21,7 @@ import regex as re import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated @@ -39,15 +40,16 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, - try_get_generation_config, uses_mrope) + try_get_generation_config, try_get_safetensors_metadata, uses_mrope) from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, get_open_port, is_torch_equal_or_newer, - random_uuid, resolve_obj_by_qualname) + get_cpu_memory, get_open_port, is_lossless_cast, + is_torch_equal_or_newer, random_uuid, + resolve_obj_by_qualname) if TYPE_CHECKING: from _typeshed import DataclassInstance @@ -531,7 +533,12 @@ def __post_init__(self) -> None: self.encoder_config = 
self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision) - self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype) + self.dtype = _get_and_verify_dtype( + self.model, + self.hf_config, + self.dtype, + revision=self.revision, + ) # Workaround for Gemma 2 which uses interleaved sliding window # attention, but it's not specified in its config. TODO: remove this @@ -3067,7 +3074,12 @@ def _check_valid_dtype(model_type: str, dtype: torch.dtype): return True -def _find_config_dtype(config: PretrainedConfig) -> torch.dtype: +def _find_dtype( + model_id: str, + config: PretrainedConfig, + *, + revision: Optional[str], +): # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. config_dtype = getattr(config, "torch_dtype", None) @@ -3079,6 +3091,25 @@ def _find_config_dtype(config: PretrainedConfig) -> torch.dtype: if config_dtype is None and hasattr(config, "vision_config"): config_dtype = getattr(config.vision_config, "torch_dtype", None) + # Try to read the dtype of the weights if they are in safetensors format + if config_dtype is None: + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + + if repo_mt and (files_mt := repo_mt.files_metadata): + param_dtypes = set[torch.dtype]().union( + *(_SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE)) + + if param_dtypes: + # Use the safest dtype out of the available ones + return max( + param_dtypes, + key=lambda dtype: sum( + is_lossless_cast(dtype, dt) for dt in param_dtypes), + ) + if config_dtype is None: config_dtype = torch.float32 @@ -3120,10 +3151,13 @@ def _resolve_auto_dtype(model_type: str, config_dtype: torch.dtype): def _get_and_verify_dtype( + model_id: str, config: PretrainedConfig, dtype: Union[str, torch.dtype], + *, + revision: Optional[str] = None, ) -> torch.dtype: - config_dtype = _find_config_dtype(config) + config_dtype = _find_dtype(model_id, config, revision=revision) model_type = config.model_type if isinstance(dtype, str): diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 69e7207cc350..89f565515e9e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,12 +4,12 @@ import json import os import time -from functools import cache +from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, TypeVar, Union import huggingface_hub -from huggingface_hub import hf_hub_download +from huggingface_hub import get_safetensors_metadata, hf_hub_download from huggingface_hub import list_repo_files as hf_list_repo_files from huggingface_hub import try_to_load_from_cache from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, @@ -93,10 +93,15 @@ class ConfigFormat(str, enum.Enum): MISTRAL = "mistral" -def with_retry(func: Callable[[], Any], - log_msg: str, - max_retries: int = 2, - retry_delay: int = 2): +_R = TypeVar("_R") + + +def with_retry( + func: Callable[[], _R], + log_msg: str, + max_retries: int = 2, + retry_delay: int = 2, +) -> _R: for attempt in range(max_retries): try: return func() @@ -109,6 +114,8 @@ def with_retry(func: Callable[[], Any], time.sleep(retry_delay) retry_delay *= 2 + raise AssertionError("Should not be reached") + # @cache doesn't 
cache exceptions @cache @@ -833,3 +840,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): return resolve_obj_by_qualname(function_name)() else: return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() + + +def try_get_safetensors_metadata( + model: str, + *, + revision: Optional[str] = None, +): + get_safetensors_metadata_partial = partial( + get_safetensors_metadata, + model, + revision=revision, + token=os.getenv('HF_TOKEN', None), + ) + + try: + return with_retry(get_safetensors_metadata_partial, + "Error retrieving safetensors") + except Exception: + return None diff --git a/vllm/utils.py b/vllm/utils.py index 86873ff75817..2eddd728ea47 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -980,6 +980,42 @@ def get_dtype_size(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() +# bool = 0, int = 1, float = 2, complex = 3 +def _get_precision_level(dtype: torch.dtype) -> int: + # NOTE: Complex dtypes return `is_floating_point=False` + return ((dtype != torch.bool) + dtype.is_floating_point + + dtype.is_complex * 2) + + +def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): + """ + Test whether it is lossless to cast a tensor from + `src_dtype` to `tgt_dtype`. + """ + if src_dtype == tgt_dtype: + return True + + src_level = _get_precision_level(src_dtype) + tgt_level = _get_precision_level(tgt_dtype) + + if src_level < tgt_level: + return True + if src_level > tgt_level: + return False + + # Compare integral types + if not src_dtype.is_floating_point and not src_dtype.is_complex: + src_info = torch.iinfo(src_dtype) + tgt_info = torch.iinfo(tgt_dtype) + return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + + # Compare floating-point types + src_info = torch.finfo(src_dtype) + tgt_info = torch.finfo(tgt_dtype) + return (src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution) + + # `collections` helpers def is_list_of( value: object, From 8300ff336e903ccd951097cce8ad25554a4b24ff Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 27 May 2025 14:13:45 +0000 Subject: [PATCH 04/15] Update Signed-off-by: DarkLight1337 --- tests/test_utils.py | 24 ++++++++++++++++++++---- vllm/config.py | 24 ++++++++++-------------- vllm/utils.py | 15 +++++++++++++-- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 879767c8cadf..a408304339b4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,10 +17,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, - bind_kv_cache, deprecate_kwargs, get_open_port, - is_lossless_cast, make_zmq_path, make_zmq_socket, - memory_profiling, merge_async_iterators, sha256, - split_zmq_path, supports_kw, swap_dict_values) + bind_kv_cache, common_broadcastable_dtype, + deprecate_kwargs, get_open_port, is_lossless_cast, + make_zmq_path, make_zmq_socket, memory_profiling, + merge_async_iterators, sha256, split_zmq_path, + supports_kw, swap_dict_values) from .utils import create_new_process_for_each_test, error_on_warning @@ -606,6 +607,21 @@ def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result +# yapf: disable +@pytest.mark.parametrize( + ("dtypes", "expected_result"), + [ + ([torch.bool], torch.bool), + ([torch.bool, torch.int8], torch.int8), + 
([torch.bool, torch.int8, torch.float16], torch.float16), + ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501 + ], +) +# yapf: enable +def test_common_broadcastable_dtype(dtypes, expected_result): + assert common_broadcastable_dtype(dtypes) == expected_result + + def test_placeholder_module_error_handling(): placeholder = PlaceholderModule("placeholder_1234") diff --git a/vllm/config.py b/vllm/config.py index 0f853bac67bc..dc30f9be257a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -46,9 +46,9 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, - LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, get_open_port, is_lossless_cast, - is_torch_equal_or_newer, random_uuid, + LayerBlockType, common_broadcastable_dtype, + cuda_device_count_stateless, get_cpu_memory, + get_open_port, is_torch_equal_or_newer, random_uuid, resolve_obj_by_qualname) if TYPE_CHECKING: @@ -3096,19 +3096,15 @@ def _find_dtype( repo_mt = try_get_safetensors_metadata(model_id, revision=revision) if repo_mt and (files_mt := repo_mt.files_metadata): - param_dtypes = set[torch.dtype]().union( - *(_SAFETENSORS_TO_TORCH_DTYPE[dtype_str] - for file_mt in files_mt.values() - for dtype_str in file_mt.parameter_count - if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE)) + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + } if param_dtypes: - # Use the safest dtype out of the available ones - return max( - param_dtypes, - key=lambda dtype: sum( - is_lossless_cast(dtype, dt) for dt in param_dtypes), - ) + return common_broadcastable_dtype(param_dtypes) if config_dtype is None: config_dtype = torch.float32 diff --git a/vllm/utils.py b/vllm/utils.py index 2eddd728ea47..cb58264b3a13 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -37,8 +37,8 @@ _ArgumentGroup) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict -from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, - Iterable, Iterator, KeysView, Mapping) +from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, + Hashable, Iterable, Iterator, KeysView, Mapping) from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps @@ -1016,6 +1016,17 @@ def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): and src_info.resolution >= tgt_info.resolution) +def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): + """ + Get the common `dtype` where all of the other `dtypes` can be + cast to it without losing any information. 
+ """ + return max( + dtypes, + key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), + ) + + # `collections` helpers def is_list_of( value: object, From 960255e9966b42797ca6a9c185982178f6d86721 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 27 May 2025 14:15:00 +0000 Subject: [PATCH 05/15] Update Signed-off-by: DarkLight1337 --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index a408304339b4..dd8777f06888 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -572,7 +572,7 @@ def test_lru_cache(): @pytest.mark.parametrize( ("src_dtype", "tgt_dtype", "expected_result"), [ - # Different safe levels + # Different precision_levels (torch.bool, torch.int8, True), (torch.bool, torch.float16, True), (torch.bool, torch.complex32, True), From ee0f1a67419d228a483e76eaa0a3310ef5162d44 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 30 May 2025 05:21:21 +0000 Subject: [PATCH 06/15] Also include gemma2 Signed-off-by: DarkLight1337 --- vllm/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/config.py b/vllm/config.py index dc30f9be257a..d4d0960643d7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3052,6 +3052,7 @@ def compute_hash(self) -> str: # model_type -> reason _FLOAT16_NOT_SUPPORTED_MODELS = { + "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", From e00f1aa88143abf9b61e595243b75597ffc5d875 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 30 May 2025 07:55:02 +0000 Subject: [PATCH 07/15] Fix tests Signed-off-by: DarkLight1337 --- tests/basic_correctness/test_basic_correctness.py | 5 +---- tests/models/language/pooling/test_baai.py | 14 ++++++++++++++ .../models/language/pooling/test_classification.py | 13 ++----------- tests/models/language/pooling/test_embedding.py | 6 +----- tests/models/language/pooling/test_gte.py | 4 ++++ .../pooling/test_snowflake_arctic_embed.py | 8 ++++++++ tests/models/multimodal/generation/test_whisper.py | 1 + 7 files changed, 31 insertions(+), 20 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 86b5e1e0ab7c..11c8e7a4b9d1 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -60,7 +60,6 @@ def _fix_prompt_embed_outputs( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) -@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @@ -69,7 +68,6 @@ def test_models( hf_runner, model: str, backend: str, - dtype: str, max_tokens: int, enforce_eager: bool, enable_prompt_embeds: bool, @@ -97,7 +95,7 @@ def test_models( str(i) for i in range(1024)) + " are:" example_prompts = [prompt] - with hf_runner(model, dtype=dtype) as hf_model: + with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) if enable_prompt_embeds: with torch.no_grad(): @@ -106,7 +104,6 @@ def test_models( with VllmRunner(model, max_model_len=8192, - dtype=dtype, enforce_eager=enforce_eager, enable_prompt_embeds=enable_prompt_embeds, 
gpu_memory_utilization=0.7) as vllm_model: diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py index fc0e8207954f..8e435e79ee31 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling/test_baai.py @@ -8,46 +8,60 @@ ########## BertModel EmbedModelInfo("BAAI/bge-base-en", architecture="BertModel", + dtype="half", enable_test=True), EmbedModelInfo("BAAI/bge-base-zh", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-small-en", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-small-zh", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-en", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-zh", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-zh-noinstruct", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-base-en-v1.5", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-base-zh-v1.5", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-small-en-v1.5", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-small-zh-v1.5", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-en-v1.5", architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-zh-v1.5", architecture="BertModel", + dtype="half", enable_test=False), ########## XLMRobertaModel EmbedModelInfo("BAAI/bge-m3", architecture="XLMRobertaModel", + dtype="half", enable_test=True), ########## Qwen2Model EmbedModelInfo("BAAI/bge-code-v1", diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 44af3df08a86..a46461168237 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -13,14 +13,11 @@ marks=[pytest.mark.core_model, pytest.mark.cpu_model]), ], ) -@pytest.mark.parametrize("dtype", - ["half"] if current_platform.is_rocm() else ["float"]) def test_models( hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, monkeypatch, ) -> None: if current_platform.is_rocm(): @@ -28,21 +25,15 @@ def test_models( # switch to use ROCm CK FA backend monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) with hf_runner(model, - dtype=dtype, auto_cls=AutoModelForSequenceClassification) as hf_model: hf_outputs = hf_model.classify(example_prompts) - # check logits difference for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) - # the tolerance value of 1e-2 is selected based on the - # half datatype tests in - # tests/models/embedding/language/test_embedding.py - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + assert torch.allclose(hf_output, vllm_output, 1e-2) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 306cfdf37707..8f82c8091af3 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -30,13 +30,11 @@ 
pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) -@pytest.mark.parametrize("dtype", ["half"]) def test_models( hf_runner, vllm_runner, example_prompts, model, - dtype: str, monkeypatch, ) -> None: @@ -58,13 +56,11 @@ def test_models( # So we need to strip the input texts to avoid test failing. example_prompts = [str(s).strip() for s in example_prompts] - with hf_runner(model, dtype=dtype, - is_sentence_transformer=True) as hf_model: + with hf_runner(model, is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) with vllm_runner(model, task="embed", - dtype=dtype, max_model_len=None, **vllm_extra_kwargs) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 725e3d168408..83b37f588f60 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -35,12 +35,15 @@ ########### NewModel EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", architecture="GteNewModel", + dtype="half", enable_test=True), EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", + dtype="half", enable_test=True), EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", + dtype="half", enable_test=True), ########### Qwen2ForCausalLM EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", @@ -50,6 +53,7 @@ ########## ModernBertModel EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", architecture="ModernBertModel", + dtype="half", enable_test=True), ] diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index c6c2d1e7a679..ea492d23cb69 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -9,34 +9,42 @@ EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", is_matryoshka=False, architecture="BertModel", + dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-s", is_matryoshka=False, architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m", is_matryoshka=False, architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", is_matryoshka=False, architecture="NomicBertModel", + dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", is_matryoshka=False, architecture="BertModel", + dtype="half", enable_test=False), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", is_matryoshka=True, architecture="BertModel", + dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", is_matryoshka=True, architecture="XLMRobertaModel", + dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", is_matryoshka=True, architecture="GteModel", + dtype="half", enable_test=True), ] diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 4e48bdbd0428..d0b85842a3d8 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -100,6 +100,7 @@ def run_test( with vllm_runner( model, + dtype="half", max_model_len=448, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, From 721b427fe1b2b7c77605d1873a09be8d7ad96c94 Mon Sep 17 
00:00:00 2001 From: DarkLight1337 Date: Fri, 30 May 2025 07:56:58 +0000 Subject: [PATCH 08/15] Fix Signed-off-by: DarkLight1337 --- tests/samplers/test_no_bad_words.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index 355e3adcf5f3..f9688b4b9b27 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -103,7 +103,7 @@ def setup_method(self, method): add_special_tokens=False)[0] def test_two_token_bad_word(self, vllm_runner): - with vllm_runner(self.MODEL) as llm: + with vllm_runner(self.MODEL, dtype="half") as llm: output_token_ids = self._generate(llm) assert output_token_ids[:2] == [ self.target_token_id1, self.target_token_id2 From 447bea24650572014a3f63f0e272078cf064b62a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 30 May 2025 09:47:25 +0000 Subject: [PATCH 09/15] Fix Signed-off-by: DarkLight1337 --- tests/models/language/pooling/test_classification.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index a46461168237..955daf6ff712 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -13,11 +13,13 @@ marks=[pytest.mark.core_model, pytest.mark.cpu_model]), ], ) +@pytest.mark.parametrize("dtype", ["float"]) def test_models( hf_runner, vllm_runner, example_prompts, model: str, + dtype: str, monkeypatch, ) -> None: if current_platform.is_rocm(): @@ -25,10 +27,11 @@ def test_models( # switch to use ROCm CK FA backend monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - with vllm_runner(model) as vllm_model: + with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) with hf_runner(model, + dtype=dtype, auto_cls=AutoModelForSequenceClassification) as hf_model: hf_outputs = hf_model.classify(example_prompts) @@ -36,4 +39,4 @@ def test_models( hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) - assert torch.allclose(hf_output, vllm_output, 1e-2) + assert torch.allclose(hf_output, vllm_output, 1e-3) From 82a88532ab313606406550907720383f944ed511 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 30 May 2025 11:03:11 +0000 Subject: [PATCH 10/15] Fix test Signed-off-by: DarkLight1337 --- tests/models/language/pooling/test_snowflake_arctic_embed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index ea492d23cb69..112976c4ddcf 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -24,7 +24,7 @@ EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", is_matryoshka=False, architecture="NomicBertModel", - dtype="half", + dtype="float32", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", is_matryoshka=False, From bfa796d1a46611330b7be4fcf9234569cb2b742d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 30 May 2025 13:23:31 +0000 Subject: [PATCH 11/15] Use the exact same dtype for both impls Signed-off-by: DarkLight1337 --- tests/models/language/pooling/mteb_utils.py | 11 ++++------- .../language/pooling/test_snowflake_arctic_embed.py | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git 
a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index f4837ae952c3..f45168bc0f1d 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -102,21 +102,18 @@ def mteb_test_embed_models(hf_runner, vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) vllm_dtype = vllm_model.model.llm_engine.model_config.dtype - model_dtype = getattr( - vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype", - vllm_dtype) - with set_default_torch_dtype(model_dtype) and hf_runner( + with set_default_torch_dtype(vllm_dtype) and hf_runner( model_info.name, is_sentence_transformer=True, - dtype=model_dtype) as hf_model: + dtype=vllm_dtype) as hf_model: if hf_model_callback is not None: hf_model_callback(hf_model) st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) - print("VLLM:", vllm_dtype, vllm_main_score) - print("SentenceTransformer:", model_dtype, st_main_score) + print("VLLM:", vllm_main_score) + print("SentenceTransformers:", st_main_score) print("Difference:", st_main_score - vllm_main_score) assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index 112976c4ddcf..ea492d23cb69 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -24,7 +24,7 @@ EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", is_matryoshka=False, architecture="NomicBertModel", - dtype="float32", + dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", is_matryoshka=False, From 49119fca8614eb44ca7eec6613c8f45e6922a1d0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 31 May 2025 05:59:15 +0000 Subject: [PATCH 12/15] Revert changes to pooling tests Signed-off-by: DarkLight1337 --- tests/models/language/pooling/test_baai.py | 14 -------------- .../models/language/pooling/test_classification.py | 10 ++++++++-- tests/models/language/pooling/test_embedding.py | 6 +++++- tests/models/language/pooling/test_gte.py | 4 ---- .../pooling/test_snowflake_arctic_embed.py | 8 -------- 5 files changed, 13 insertions(+), 29 deletions(-) diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py index 8e435e79ee31..fc0e8207954f 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling/test_baai.py @@ -8,60 +8,46 @@ ########## BertModel EmbedModelInfo("BAAI/bge-base-en", architecture="BertModel", - dtype="half", enable_test=True), EmbedModelInfo("BAAI/bge-base-zh", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-small-en", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-small-zh", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-en", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-zh", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-zh-noinstruct", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-base-en-v1.5", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-base-zh-v1.5", architecture="BertModel", - dtype="half", enable_test=False), 
EmbedModelInfo("BAAI/bge-small-en-v1.5", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-small-zh-v1.5", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-en-v1.5", architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("BAAI/bge-large-zh-v1.5", architecture="BertModel", - dtype="half", enable_test=False), ########## XLMRobertaModel EmbedModelInfo("BAAI/bge-m3", architecture="XLMRobertaModel", - dtype="half", enable_test=True), ########## Qwen2Model EmbedModelInfo("BAAI/bge-code-v1", diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 955daf6ff712..57b3cb58d88b 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -13,7 +13,8 @@ marks=[pytest.mark.core_model, pytest.mark.cpu_model]), ], ) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", + ["half"] if current_platform.is_rocm() else ["float"]) def test_models( hf_runner, vllm_runner, @@ -35,8 +36,13 @@ def test_models( auto_cls=AutoModelForSequenceClassification) as hf_model: hf_outputs = hf_model.classify(example_prompts) + # check logits difference for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) - assert torch.allclose(hf_output, vllm_output, 1e-3) + # the tolerance value of 1e-2 is selected based on the + # half datatype tests in + # tests/models/language/pooling/test_embedding.py + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 8f82c8091af3..306cfdf37707 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -30,11 +30,13 @@ pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) +@pytest.mark.parametrize("dtype", ["half"]) def test_models( hf_runner, vllm_runner, example_prompts, model, + dtype: str, monkeypatch, ) -> None: @@ -56,11 +58,13 @@ def test_models( # So we need to strip the input texts to avoid test failing. 
example_prompts = [str(s).strip() for s in example_prompts] - with hf_runner(model, is_sentence_transformer=True) as hf_model: + with hf_runner(model, dtype=dtype, + is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) with vllm_runner(model, task="embed", + dtype=dtype, max_model_len=None, **vllm_extra_kwargs) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 83b37f588f60..725e3d168408 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -35,15 +35,12 @@ ########### NewModel EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", architecture="GteNewModel", - dtype="half", enable_test=True), EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", - dtype="half", enable_test=True), EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", - dtype="half", enable_test=True), ########### Qwen2ForCausalLM EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", @@ -53,7 +50,6 @@ ########## ModernBertModel EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", architecture="ModernBertModel", - dtype="half", enable_test=True), ] diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py index ea492d23cb69..c6c2d1e7a679 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -9,42 +9,34 @@ EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", is_matryoshka=False, architecture="BertModel", - dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-s", is_matryoshka=False, architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m", is_matryoshka=False, architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", is_matryoshka=False, architecture="NomicBertModel", - dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-l", is_matryoshka=False, architecture="BertModel", - dtype="half", enable_test=False), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", is_matryoshka=True, architecture="BertModel", - dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", is_matryoshka=True, architecture="XLMRobertaModel", - dtype="half", enable_test=True), EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", is_matryoshka=True, architecture="GteModel", - dtype="half", enable_test=True), ] From a9bfe1cc70e889649e40e690c2799c6b79ed6f92 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 31 May 2025 06:18:52 +0000 Subject: [PATCH 13/15] Don't downcast dtype for pooling models Signed-off-by: DarkLight1337 --- tests/conftest.py | 3 ++- vllm/config.py | 48 +++++++++++++++++++++++++++++------------------ 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 73aaa21d5758..6336c6c2ce01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -327,7 +327,8 @@ def __init__( self.dtype = torch_dtype = _get_and_verify_dtype( self.model_name, self.config, - dtype, + dtype=dtype, + is_pooling_model=is_sentence_transformer or is_cross_encoder, ) model_kwargs = model_kwargs if model_kwargs is not None else {} diff --git a/vllm/config.py b/vllm/config.py index 
c7543428717b..4ec770a5794b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -542,10 +542,22 @@ def __post_init__(self) -> None: self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision) + + supported_tasks, task = self._resolve_task(self.task) + self.supported_tasks = supported_tasks + self.task = task + if self.task in ("draft", "generate"): + self.truncation_side = "left" + else: + self.truncation_side = "right" + + self.pooler_config = self._init_pooler_config() + self.dtype = _get_and_verify_dtype( self.model, self.hf_config, self.dtype, + is_pooling_model=self.runner_type == "pooling", revision=self.revision, ) @@ -604,16 +616,6 @@ def __post_init__(self) -> None: raise ValueError( "`override_neuron_config` is only supported on Neuron.") - supported_tasks, task = self._resolve_task(self.task) - self.supported_tasks = supported_tasks - self.task = task - if self.task in ("draft", "generate"): - self.truncation_side = "left" - else: - self.truncation_side = "right" - - self.pooler_config = self._init_pooler_config() - self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -699,7 +701,6 @@ def _get_encoder_config(self): self.model, self.revision) def _init_pooler_config(self) -> Optional["PoolerConfig"]: - if self.runner_type == "pooling": if isinstance(self.override_pooler_config, dict): self.override_pooler_config = PoolerConfig( @@ -3134,20 +3135,26 @@ def _find_dtype( return config_dtype -def _resolve_auto_dtype(model_type: str, config_dtype: torch.dtype): +def _resolve_auto_dtype( + model_type: str, + config_dtype: torch.dtype, + *, + is_pooling_model: bool, +): from vllm.platforms import current_platform - platform_dtype = next(dtype for dtype in current_platform.supported_dtypes + platform_supported_dtypes = current_platform.supported_dtypes + platform_dtype = next(dtype for dtype in platform_supported_dtypes if _is_valid_dtype(model_type, dtype)) - # Downcast to platform's default for float32 models + # Downcast for float32 models if config_dtype == torch.float32: - return platform_dtype + config_dtype = torch.float16 if is_pooling_model else platform_dtype - # Ensure device compatibility - if config_dtype in current_platform.supported_dtypes: + if config_dtype in platform_supported_dtypes: return config_dtype + # Ensure device compatibility device_name = current_platform.get_device_name() device_capability = current_platform.get_device_capability() @@ -3173,6 +3180,7 @@ def _get_and_verify_dtype( config: PretrainedConfig, dtype: Union[str, torch.dtype], *, + is_pooling_model: bool, revision: Optional[str] = None, ) -> torch.dtype: config_dtype = _find_dtype(model_id, config, revision=revision) @@ -3182,7 +3190,11 @@ def _get_and_verify_dtype( dtype = dtype.lower() if dtype == "auto": # Set default dtype from model config - torch_dtype = _resolve_auto_dtype(model_type, config_dtype) + torch_dtype = _resolve_auto_dtype( + model_type, + config_dtype, + is_pooling_model=is_pooling_model, + ) else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype!r}") From 7e5ffeb13a50fd4c9a2f6f2a6681d25ffe03e8cb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 31 May 2025 10:48:29 +0000 Subject: [PATCH 14/15] Fix Signed-off-by: DarkLight1337 --- vllm/config.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4ec770a5794b..5776fc5f3531 
100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3143,15 +3143,21 @@ def _resolve_auto_dtype( ): from vllm.platforms import current_platform - platform_supported_dtypes = current_platform.supported_dtypes - platform_dtype = next(dtype for dtype in platform_supported_dtypes - if _is_valid_dtype(model_type, dtype)) + supported_dtypes = [ + dtype for dtype in current_platform.supported_dtypes + if _is_valid_dtype(model_type, dtype) + ] + + if is_pooling_model and torch.float16 in supported_dtypes: + preferred_dtype = torch.float16 + else: + preferred_dtype = supported_dtypes[0] # Downcast for float32 models if config_dtype == torch.float32: - config_dtype = torch.float16 if is_pooling_model else platform_dtype + config_dtype = preferred_dtype - if config_dtype in platform_supported_dtypes: + if config_dtype in supported_dtypes: return config_dtype # Ensure device compatibility @@ -3169,10 +3175,10 @@ def _resolve_auto_dtype( "Falling back to %s for compatibility.", device_str, config_dtype, - platform_dtype, + preferred_dtype, ) - return platform_dtype + return preferred_dtype def _get_and_verify_dtype( From 48c5c2b8fe52acb0446f1fc1729cc22dfdcd1ff6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 31 May 2025 14:40:53 +0000 Subject: [PATCH 15/15] Fix Signed-off-by: DarkLight1337 --- tests/models/language/pooling/test_embedding.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 306cfdf37707..8f82c8091af3 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -30,13 +30,11 @@ pytest.param("sentence-transformers/stsb-roberta-base-v2"), ], ) -@pytest.mark.parametrize("dtype", ["half"]) def test_models( hf_runner, vllm_runner, example_prompts, model, - dtype: str, monkeypatch, ) -> None: @@ -58,13 +56,11 @@ def test_models( # So we need to strip the input texts to avoid test failing. example_prompts = [str(s).strip() for s in example_prompts] - with hf_runner(model, dtype=dtype, - is_sentence_transformer=True) as hf_model: + with hf_runner(model, is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) with vllm_runner(model, task="embed", - dtype=dtype, max_model_len=None, **vllm_extra_kwargs) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts)
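
Notes on the resulting dtype-resolution behavior follow; the snippets are illustrative sketches against the state after PATCH 15/15, not part of the series itself.

For model types listed in `_FLOAT16_NOT_SUPPORTED_MODELS` (gemma2 was added in PATCH 06), an explicit `dtype="float16"` request now fails fast in `_check_valid_dtype` instead of merely logging a warning. A minimal demonstration -- note that `_get_and_verify_dtype` is a private helper and `google/gemma-2-2b` is a gated checkpoint used purely as an example:

    import torch
    from transformers import AutoConfig
    from vllm.config import _get_and_verify_dtype

    # "gemma2" is in _FLOAT16_NOT_SUPPORTED_MODELS, so an explicit
    # float16 request raises instead of silently proceeding.
    config = AutoConfig.from_pretrained("google/gemma-2-2b")
    try:
        _get_and_verify_dtype("google/gemma-2-2b", config, "float16",
                              is_pooling_model=False)
    except ValueError as exc:
        print(exc)  # The model type 'gemma2' does not support float16. ...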
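
The cast-safety helper introduced in PATCH 03 ranks dtypes by kind (bool < int < float < complex) and compares numeric range and resolution only within a kind, so a kind promotion such as int64 -> float16 counts as lossless even though large integers are not exactly representable in half precision. A few checks mirroring the parametrized cases in tests/test_utils.py, assuming the series is applied:

    import torch
    from vllm.utils import is_lossless_cast

    assert is_lossless_cast(torch.int64, torch.float16)        # kind promotion
    assert not is_lossless_cast(torch.float16, torch.int64)    # kind demotion
    assert not is_lossless_cast(torch.float32, torch.float16)  # range/resolution shrink
    assert not is_lossless_cast(torch.uint8, torch.int8)       # 128..255 would overflow
    assert is_lossless_cast(torch.bfloat16, torch.float32)     # strictly wider float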
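
When the config carries no torch_dtype, PATCH 03 falls back to the checkpoint's safetensors metadata on the Hub, and PATCH 04 then picks the single dtype that every observed parameter dtype can be losslessly cast to. The same lookup can be reproduced outside vLLM as sketched below; the repo id is a placeholder, and `_TYPES` is the private safetensors mapping that the patch itself imports as `_SAFETENSORS_TO_TORCH_DTYPE`:

    import torch
    from huggingface_hub import get_safetensors_metadata
    from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
    from vllm.utils import common_broadcastable_dtype

    repo_mt = get_safetensors_metadata("some-org/some-model")  # placeholder id
    param_dtypes = {
        _SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
        for file_mt in repo_mt.files_metadata.values()
        for dtype_str in file_mt.parameter_count  # keys like "F16", "BF16"
        if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
    }
    if param_dtypes:
        print(common_broadcastable_dtype(param_dtypes))

    # common_broadcastable_dtype is a max() over "how many of the given
    # dtypes cast losslessly into this one", e.g.:
    assert common_broadcastable_dtype({torch.float16, torch.float32}) == torch.float32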
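
After PATCH 13 and 14, `dtype="auto"` no longer unconditionally downcasts float32 checkpoints to float16: generative models take the platform's first supported dtype, while pooling models keep float16 whenever the platform offers it. A condensed mirror of that decision, with a made-up supported-dtype list standing in for `current_platform.supported_dtypes`:

    import torch

    def resolve_auto_dtype_sketch(config_dtype: torch.dtype,
                                  supported: list[torch.dtype],
                                  is_pooling_model: bool) -> torch.dtype:
        # Assumes the model type already passed the float16 validity
        # check, and omits the compatibility warning for brevity.
        if is_pooling_model and torch.float16 in supported:
            preferred = torch.float16
        else:
            preferred = supported[0]
        if config_dtype == torch.float32:  # downcast float32 checkpoints
            config_dtype = preferred
        return config_dtype if config_dtype in supported else preferred

    # A float32 pooling checkpoint on a bfloat16-first platform resolves
    # to float16 rather than the platform default:
    assert resolve_auto_dtype_sketch(
        torch.float32, [torch.bfloat16, torch.float16, torch.float32],
        is_pooling_model=True) == torch.float16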