
Commit 136a17f

[Chore] Separate out vllm.utils.func (vllm-project#26904)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent: f574383 · commit: 136a17f

File tree

18 files changed (+407, -371 lines)
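
The changes below all follow a single pattern: functional helpers such as identity, supports_kw, deprecate_kwargs, make_async, and run_once move out of the vllm.utils top level into the new vllm.utils.func submodule, their tests move into tests/utils_/test_func_utils.py (with the jsontree test split into its own file), and call sites update their imports. A minimal sketch of the new import surface, assuming the relocated helpers keep the behavior they had under vllm.utils (the asyncio usage below is illustrative, not taken from this diff):

import asyncio

# Functional helpers are now imported from vllm.utils.func
# instead of vllm.utils.
from vllm.utils.func import identity, make_async

# identity simply returns its argument unchanged.
assert identity("hello") == "hello"

async def main() -> None:
    # make_async wraps a blocking callable so it can be awaited;
    # the call is expected to run in an executor behind the scenes.
    async_len = make_async(len)
    assert await async_len([1, 2, 3]) == 3

asyncio.run(main())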

tests/models/multimodal/generation/test_common.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 )

 from vllm.platforms import current_platform
-from vllm.utils import identity
+from vllm.utils.func import identity

 from ....conftest import (
     IMAGE_ASSETS,

tests/utils_/test_func_utils.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa
+
+import pytest
+
+from vllm.utils.func import deprecate_kwargs, supports_kw
+
+from ..utils import error_on_warning
+
+
+def test_deprecate_kwargs_always():
+    @deprecate_kwargs("old_arg", is_deprecated=True)
+    def dummy(*, old_arg: object = None, new_arg: object = None):
+        pass
+
+    with pytest.warns(DeprecationWarning, match="'old_arg'"):
+        dummy(old_arg=1)
+
+    with error_on_warning(DeprecationWarning):
+        dummy(new_arg=1)
+
+
+def test_deprecate_kwargs_never():
+    @deprecate_kwargs("old_arg", is_deprecated=False)
+    def dummy(*, old_arg: object = None, new_arg: object = None):
+        pass
+
+    with error_on_warning(DeprecationWarning):
+        dummy(old_arg=1)
+
+    with error_on_warning(DeprecationWarning):
+        dummy(new_arg=1)
+
+
+def test_deprecate_kwargs_dynamic():
+    is_deprecated = True
+
+    @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
+    def dummy(*, old_arg: object = None, new_arg: object = None):
+        pass
+
+    with pytest.warns(DeprecationWarning, match="'old_arg'"):
+        dummy(old_arg=1)
+
+    with error_on_warning(DeprecationWarning):
+        dummy(new_arg=1)
+
+    is_deprecated = False
+
+    with error_on_warning(DeprecationWarning):
+        dummy(old_arg=1)
+
+    with error_on_warning(DeprecationWarning):
+        dummy(new_arg=1)
+
+
+def test_deprecate_kwargs_additional_message():
+    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
+    def dummy(*, old_arg: object = None, new_arg: object = None):
+        pass
+
+    with pytest.warns(DeprecationWarning, match="abcd"):
+        dummy(old_arg=1)
+
+
+@pytest.mark.parametrize(
+    ("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"),
+    [
+        # Tests for positional argument support
+        (lambda foo: None, "foo", True, True, False),
+        (lambda foo: None, "foo", False, True, True),
+        # Tests for positional or keyword / keyword only
+        (lambda foo=100: None, "foo", True, True, False),
+        (lambda *, foo: None, "foo", False, True, True),
+        # Tests to make sure the names of variadic params are NOT supported
+        (lambda *args: None, "args", False, True, False),
+        (lambda **kwargs: None, "kwargs", False, True, False),
+        # Tests for if we allow var kwargs to add support
+        (lambda foo: None, "something_else", False, True, False),
+        (lambda foo, **kwargs: None, "something_else", False, True, True),
+        (lambda foo, **kwargs: None, "kwargs", True, True, False),
+        (lambda foo, **kwargs: None, "foo", True, True, False),
+    ],
+)
+def test_supports_kw(
+    callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
+):
+    assert (
+        supports_kw(
+            callable=callable,
+            kw_name=kw_name,
+            requires_kw_only=requires_kw_only,
+            allow_var_kwargs=allow_var_kwargs,
+        )
+        == is_supported
+    )
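
For readers skimming the parametrized cases above: supports_kw reports whether a callable accepts a given keyword argument, where requires_kw_only demands that the parameter be keyword-only and allow_var_kwargs lets a **kwargs catch-all count as support. A few of the rows restated with named functions instead of lambdas (the function names here are illustrative, not from the diff):

from vllm.utils.func import supports_kw

def takes_foo(foo):  # positional-or-keyword parameter
    return None

def takes_foo_and_kwargs(foo, **kwargs):  # also accepts arbitrary keywords
    return None

# A positional-or-keyword parameter counts only when keyword-only is not required.
assert not supports_kw(callable=takes_foo, kw_name="foo",
                       requires_kw_only=True, allow_var_kwargs=True)
assert supports_kw(callable=takes_foo, kw_name="foo",
                   requires_kw_only=False, allow_var_kwargs=True)

# An unknown name is supported only through **kwargs, and the name of the
# variadic parameter itself is never treated as supported.
assert supports_kw(callable=takes_foo_and_kwargs, kw_name="something_else",
                   requires_kw_only=False, allow_var_kwargs=True)
assert not supports_kw(callable=takes_foo_and_kwargs, kw_name="kwargs",
                       requires_kw_only=True, allow_var_kwargs=True)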

tests/utils_/test_jsontree.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.utils.jsontree import json_count_leaves
+
+
+def test_json_count_leaves():
+    """Test json_count_leaves function from jsontree utility."""
+
+    # Single leaf values
+    assert json_count_leaves(42) == 1
+    assert json_count_leaves("hello") == 1
+    assert json_count_leaves(None) == 1
+
+    # Empty containers
+    assert json_count_leaves([]) == 0
+    assert json_count_leaves({}) == 0
+    assert json_count_leaves(()) == 0
+
+    # Flat structures
+    assert json_count_leaves([1, 2, 3]) == 3
+    assert json_count_leaves({"a": 1, "b": 2}) == 2
+    assert json_count_leaves((1, 2, 3)) == 3
+
+    # Nested structures
+    nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
+    assert json_count_leaves(nested_dict) == 3
+
+    nested_list = [1, [2, 3], 4]
+    assert json_count_leaves(nested_list) == 4
+
+    mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
+    assert json_count_leaves(mixed_nested) == 4

tests/utils_/test_utils.py

Lines changed: 1 addition & 121 deletions
@@ -30,7 +30,6 @@
     bind_kv_cache,
     common_broadcastable_dtype,
     current_stream,
-    deprecate_kwargs,
     get_open_port,
     get_tcp_uri,
     is_lossless_cast,
@@ -42,12 +41,11 @@
     sha256,
     split_host_port,
     split_zmq_path,
-    supports_kw,
     swap_dict_values,
     unique_filepath,
 )

-from ..utils import create_new_process_for_each_test, error_on_warning
+from ..utils import create_new_process_for_each_test


 @pytest.mark.asyncio
@@ -83,61 +81,6 @@ async def stream_output(generator: AsyncIterator[tuple[int, str]]):
         raise AssertionError() from e


-def test_deprecate_kwargs_always():
-    @deprecate_kwargs("old_arg", is_deprecated=True)
-    def dummy(*, old_arg: object = None, new_arg: object = None):
-        pass
-
-    with pytest.warns(DeprecationWarning, match="'old_arg'"):
-        dummy(old_arg=1)
-
-    with error_on_warning(DeprecationWarning):
-        dummy(new_arg=1)
-
-
-def test_deprecate_kwargs_never():
-    @deprecate_kwargs("old_arg", is_deprecated=False)
-    def dummy(*, old_arg: object = None, new_arg: object = None):
-        pass
-
-    with error_on_warning(DeprecationWarning):
-        dummy(old_arg=1)
-
-    with error_on_warning(DeprecationWarning):
-        dummy(new_arg=1)
-
-
-def test_deprecate_kwargs_dynamic():
-    is_deprecated = True
-
-    @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
-    def dummy(*, old_arg: object = None, new_arg: object = None):
-        pass
-
-    with pytest.warns(DeprecationWarning, match="'old_arg'"):
-        dummy(old_arg=1)
-
-    with error_on_warning(DeprecationWarning):
-        dummy(new_arg=1)
-
-    is_deprecated = False
-
-    with error_on_warning(DeprecationWarning):
-        dummy(old_arg=1)
-
-    with error_on_warning(DeprecationWarning):
-        dummy(new_arg=1)
-
-
-def test_deprecate_kwargs_additional_message():
-    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
-    def dummy(*, old_arg: object = None, new_arg: object = None):
-        pass
-
-    with pytest.warns(DeprecationWarning, match="abcd"):
-        dummy(old_arg=1)
-
-
 def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
         m.setenv("VLLM_PORT", "5678")
@@ -383,39 +326,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
     assert "-O.mode" in caplog_vllm.text


-@pytest.mark.parametrize(
-    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
-    [
-        # Tests for positional argument support
-        (lambda foo: None, "foo", True, True, False),
-        (lambda foo: None, "foo", False, True, True),
-        # Tests for positional or keyword / keyword only
-        (lambda foo=100: None, "foo", True, True, False),
-        (lambda *, foo: None, "foo", False, True, True),
-        # Tests to make sure the names of variadic params are NOT supported
-        (lambda *args: None, "args", False, True, False),
-        (lambda **kwargs: None, "kwargs", False, True, False),
-        # Tests for if we allow var kwargs to add support
-        (lambda foo: None, "something_else", False, True, False),
-        (lambda foo, **kwargs: None, "something_else", False, True, True),
-        (lambda foo, **kwargs: None, "kwargs", True, True, False),
-        (lambda foo, **kwargs: None, "foo", True, True, False),
-    ],
-)
-def test_supports_kw(
-    callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
-):
-    assert (
-        supports_kw(
-            callable=callable,
-            kw_name=kw_name,
-            requires_kw_only=requires_kw_only,
-            allow_var_kwargs=allow_var_kwargs,
-        )
-        == is_supported
-    )
-
-
 @create_new_process_for_each_test()
 def test_memory_profiling():
     # Fake out some model loading + inference memory usage to test profiling
@@ -863,36 +773,6 @@ def test_join_host_port():
     assert join_host_port("::1", 5555) == "[::1]:5555"


-def test_json_count_leaves():
-    """Test json_count_leaves function from jsontree utility."""
-    from vllm.utils.jsontree import json_count_leaves
-
-    # Single leaf values
-    assert json_count_leaves(42) == 1
-    assert json_count_leaves("hello") == 1
-    assert json_count_leaves(None) == 1
-
-    # Empty containers
-    assert json_count_leaves([]) == 0
-    assert json_count_leaves({}) == 0
-    assert json_count_leaves(()) == 0
-
-    # Flat structures
-    assert json_count_leaves([1, 2, 3]) == 3
-    assert json_count_leaves({"a": 1, "b": 2}) == 2
-    assert json_count_leaves((1, 2, 3)) == 3
-
-    # Nested structures
-    nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
-    assert json_count_leaves(nested_dict) == 3
-
-    nested_list = [1, [2, 3], 4]
-    assert json_count_leaves(nested_list) == 4
-
-    mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
-    assert json_count_leaves(mixed_nested) == 4
-
-
 def test_convert_ids_list_to_tokens():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
     token_ids = tokenizer.encode("Hello, world!")

vllm/entrypoints/chat_utils.py

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@
 from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import random_uuid, supports_kw
+from vllm.utils import random_uuid
+from vllm.utils.func import supports_kw

 logger = init_logger(__name__)

vllm/entrypoints/openai/serving_engine.py

Lines changed: 1 addition & 1 deletion
@@ -94,10 +94,10 @@
     AsyncMicrobatchTokenizer,
     collect_from_async_generator,
     is_list_of,
-    make_async,
     merge_async_iterators,
     random_uuid,
 )
+from vllm.utils.func import make_async
 from vllm.v1.engine import EngineCoreRequest

 logger = init_logger(__name__)

vllm/entrypoints/openai/serving_score.py

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,8 @@
 from vllm.lora.request import LoRARequest
 from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import make_async, merge_async_iterators
+from vllm.utils import merge_async_iterators
+from vllm.utils.func import make_async

 logger = init_logger(__name__)

vllm/executor/executor_base.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.sequence import ExecuteModelRequest
 from vllm.tasks import SupportedTask
-from vllm.utils import make_async
+from vllm.utils.func import make_async
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.worker.worker_base import WorkerBase

vllm/executor/ray_distributed_executor.py

Lines changed: 1 addition & 1 deletion
@@ -24,8 +24,8 @@
     get_distributed_init_method,
     get_ip,
     get_open_port,
-    make_async,
 )
+from vllm.utils.func import make_async
 from vllm.v1.outputs import SamplerOutput

 if ray is not None:

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 2 additions & 1 deletion
@@ -27,8 +27,9 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
 )
-from vllm.utils import has_deep_gemm, run_once
+from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.func import run_once

 logger = init_logger(__name__)

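run_once also moves into vllm.utils.func as part of this change. Assuming it keeps its previous vllm.utils behavior as a decorator that lets the wrapped function body execute only on the first call, a hedged sketch (the decorated function here is hypothetical, not the one deep_gemm_moe.py relies on):

from vllm.utils.func import run_once

@run_once
def emit_startup_note() -> None:
    # Hypothetical one-shot helper; deep_gemm_moe.py uses run_once for
    # similar one-time work, but not this function.
    print("this should be printed only once")

emit_startup_note()  # runs the body
emit_startup_note()  # assumed to be a no-op on later calls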