Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ repos:
types: [python]
pass_filenames: false
additional_dependencies: [pathspec, regex]
- id: validate-config
name: Validate configuration has default values and that each field has a docstring
entry: python tools/validate_config.py
language: python
types: [python]
pass_filenames: true
files: vllm/config.py|tests/test_config.py
# Keep `suggestion` last
- id: suggestion
name: Suggestion
Expand Down
35 changes: 1 addition & 34 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import MISSING, Field, asdict, dataclass, field
from typing import Literal, Union

import pytest

from vllm.compilation.backends import VllmBackend
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
config, get_field)
get_field)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform


class _TestConfig1:
pass


@dataclass
class _TestConfig2:
a: int
"""docstring"""


@dataclass
class _TestConfig3:
a: int = 1


@dataclass
class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""


@pytest.mark.parametrize(("test_config", "expected_error"), [
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
])
def test_config(test_config, expected_error):
with pytest.raises(Exception, match=expected_error):
config(test_config)


def test_compile_config_repr_succeeds():
# setup: VllmBackend mutates the config object
config = VllmConfig()
Expand Down
Empty file added tests/tools/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/tools/test_config_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast

import pytest

from tools.validate_config import validate_ast

_TestConfig1 = '''
@config
class _TestConfig1:
pass
'''

_TestConfig2 = '''
@config
@dataclass
class _TestConfig2:
a: int
"""docstring"""
'''

_TestConfig3 = '''
@config
@dataclass
class _TestConfig3:
a: int = 1
'''

_TestConfig4 = '''
@config
@dataclass
class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""
'''


@pytest.mark.parametrize(("test_config", "expected_error"), [
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
])
def test_config(test_config, expected_error):
tree = ast.parse(test_config)
with pytest.raises(Exception, match=expected_error):
validate_ast(tree)
158 changes: 158 additions & 0 deletions tools/validate_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Ensures all fields in a config dataclass have default values
and that each field has a docstring.
"""

import ast
import inspect
import sys


def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
"""
Get any docstrings placed after attribute assignments in a class body.

Adapted from https://davidism.com/attribute-docstrings/
https://davidism.com/mit-license/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this file need a "Adapted from https://github.com/..." message? Not sure what the intent of linking this URL is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. The link is needed because of the MIT license:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

"""

def pairwise(iterable):
"""
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise

Can be removed when Python 3.9 support is dropped.
"""
iterator = iter(iterable)
a = next(iterator, None)

for b in iterator:
yield a, b
a = b

out = {}

# Consider each pair of nodes.
for a, b in pairwise(cls_node.body):
# Must be an assignment then a constant string.
if (not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)):
continue

doc = inspect.cleandoc(b.value.value)

# An assignment can have multiple targets (a = b = v), but an
# annotated assignment only has one target.
targets = a.targets if isinstance(a, ast.Assign) else [a.target]

for target in targets:
# Must be assigning to a plain name.
if not isinstance(target, ast.Name):
continue

out[target.id] = doc

return out


class ConfigValidator(ast.NodeVisitor):

def __init__(self):
...

def visit_ClassDef(self, node):
# Validate class with both @config and @dataclass decorators
decorators = [
id for d in node.decorator_list if (isinstance(d, ast.Name) and (
(id := d.id) == 'config' or id == 'dataclass')) or
(isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and
(id := d.func.id) == 'dataclass'))
]

if set(decorators) == {'config', 'dataclass'}:
validate_class(node)
elif set(decorators) == {'config'}:
fail(
f"Class {node.name} with config decorator must be a dataclass.",
node)

self.generic_visit(node)


def validate_class(class_node: ast.ClassDef):
attr_docs = get_attr_docs(class_node)

for stmt in class_node.body:
# A field is defined as a class variable that has a type annotation.
if isinstance(stmt, ast.AnnAssign):
# Skip ClassVar
# see https://docs.python.org/3/library/dataclasses.html#class-variables
if isinstance(stmt.annotation, ast.Subscript) and isinstance(
stmt.annotation.value,
ast.Name) and stmt.annotation.value.id == "ClassVar":
continue

if isinstance(stmt.target, ast.Name):
field_name = stmt.target.id
if stmt.value is None:
fail(
f"Field '{field_name}' in {class_node.name} must have "
"a default value.", stmt)

if field_name not in attr_docs:
fail(
f"Field '{field_name}' in {class_node.name} must have "
"a docstring.", stmt)

if isinstance(stmt.annotation, ast.Subscript) and \
isinstance(stmt.annotation.value, ast.Name) \
and stmt.annotation.value.id == "Union" and \
isinstance(stmt.annotation.slice, ast.Tuple):
args = stmt.annotation.slice.elts
literal_args = [
arg for arg in args
if isinstance(arg, ast.Subscript) and isinstance(
arg.value, ast.Name) and arg.value.id == "Literal"
]
if len(literal_args) > 1:
fail(
f"Field '{field_name}' in {class_node.name} must "
"use a single "
"Literal type. Please use 'Literal[Literal1, "
"Literal2]' instead of 'Union[Literal1, Literal2]'"
".", stmt)


def validate_ast(tree: ast.stmt):
ConfigValidator().visit(tree)


def validate_file(file_path: str):
try:
print(f"validating {file_path} config dataclasses ", end="")
with open(file_path, encoding="utf-8") as f:
source = f.read()

tree = ast.parse(source, filename=file_path)
validate_ast(tree)
except ValueError as e:
print(e)
SystemExit(2)
else:
print("✅")


def fail(message: str, node: ast.stmt):
raise ValueError(f"❌ line({node.lineno}): {message}")


def main():
for filename in sys.argv[1:]:
validate_file(filename)


if __name__ == "__main__":
main()
28 changes: 5 additions & 23 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)
Protocol, TypeVar, Union, cast, get_args)

import regex as re
import torch
Expand Down Expand Up @@ -193,28 +193,10 @@ def config(cls: ConfigT) -> ConfigT:
(i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
requires custom construction from CLI (i.e. `CompilationConfig`), it can
have a `from_cli` method, which will be called instead.
"""
if not is_dataclass(cls):
raise TypeError("The decorated class must be a dataclass.")
attr_docs = get_attr_docs(cls)
for f in fields(cls):
if f.init and f.default is MISSING and f.default_factory is MISSING:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a default value."
)

if f.name not in attr_docs:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a docstring.")

if get_origin(f.type) is Union:
args = get_args(f.type)
literal_args = [arg for arg in args if get_origin(arg) is Literal]
if len(literal_args) > 1:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must use a single "
"Literal type. Please use 'Literal[Literal1, Literal2]' "
"instead of 'Union[Literal1, Literal2]'.")
Config validation is performed by the tools/validate_config.py
script, which is invoked during the pre-commit checks.
"""
return cls


Expand Down Expand Up @@ -1798,7 +1780,7 @@ class ParallelConfig:
eplb_step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.

Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""
Expand Down