diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15ef5defff69..d962252eb3dd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -160,6 +160,13 @@ repos: types: [python] pass_filenames: false additional_dependencies: [pathspec, regex] + - id: validate-config + name: Validate configuration has default values and that each field has a docstring + entry: python tools/validate_config.py + language: python + types: [python] + pass_filenames: true + files: vllm/config.py|tests/test_config.py # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/tests/test_config.py b/tests/test_config.py index 5d5c4453d30d..cb7654c26afc 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,49 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import MISSING, Field, asdict, dataclass, field -from typing import Literal, Union import pytest from vllm.compilation.backends import VllmBackend from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, - config, get_field) + get_field) from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform -class _TestConfig1: - pass - - -@dataclass -class _TestConfig2: - a: int - """docstring""" - - -@dataclass -class _TestConfig3: - a: int = 1 - - -@dataclass -class _TestConfig4: - a: Union[Literal[1], Literal[2]] = 1 - """docstring""" - - -@pytest.mark.parametrize(("test_config", "expected_error"), [ - (_TestConfig1, "must be a dataclass"), - (_TestConfig2, "must have a default"), - (_TestConfig3, "must have a docstring"), - (_TestConfig4, "must use a single Literal"), -]) -def test_config(test_config, expected_error): - with pytest.raises(Exception, match=expected_error): - config(test_config) - - def test_compile_config_repr_succeeds(): # setup: VllmBackend mutates the config object config = VllmConfig() diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tools/test_config_validator.py b/tests/tools/test_config_validator.py new file mode 100644 index 000000000000..b0475894a114 --- /dev/null +++ b/tests/tools/test_config_validator.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast + +import pytest + +from tools.validate_config import validate_ast + +_TestConfig1 = ''' +@config +class _TestConfig1: + pass +''' + +_TestConfig2 = ''' +@config +@dataclass +class _TestConfig2: + a: int + """docstring""" +''' + +_TestConfig3 = ''' +@config +@dataclass +class _TestConfig3: + a: int = 1 +''' + +_TestConfig4 = ''' +@config +@dataclass +class _TestConfig4: + a: Union[Literal[1], Literal[2]] = 1 + """docstring""" +''' + + +@pytest.mark.parametrize(("test_config", "expected_error"), [ + (_TestConfig1, "must be a dataclass"), + (_TestConfig2, "must have a default"), + (_TestConfig3, "must have a docstring"), + (_TestConfig4, "must use a single Literal"), +]) +def test_config(test_config, expected_error): + tree = ast.parse(test_config) + with pytest.raises(Exception, match=expected_error): + validate_ast(tree) diff --git a/tools/validate_config.py b/tools/validate_config.py new file mode 100644 index 000000000000..8b1e955c653d --- /dev/null +++ b/tools/validate_config.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Ensures all fields in a config dataclass have default values +and that each field has a docstring. +""" + +import ast +import inspect +import sys + + +def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: + """ + Get any docstrings placed after attribute assignments in a class body. + + Adapted from https://davidism.com/attribute-docstrings/ + https://davidism.com/mit-license/ + """ + + def pairwise(iterable): + """ + Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise + + Can be removed when Python 3.9 support is dropped. + """ + iterator = iter(iterable) + a = next(iterator, None) + + for b in iterator: + yield a, b + a = b + + out = {} + + # Consider each pair of nodes. + for a, b in pairwise(cls_node.body): + # Must be an assignment then a constant string. + if (not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str)): + continue + + doc = inspect.cleandoc(b.value.value) + + # An assignment can have multiple targets (a = b = v), but an + # annotated assignment only has one target. + targets = a.targets if isinstance(a, ast.Assign) else [a.target] + + for target in targets: + # Must be assigning to a plain name. + if not isinstance(target, ast.Name): + continue + + out[target.id] = doc + + return out + + +class ConfigValidator(ast.NodeVisitor): + + def __init__(self): + ... + + def visit_ClassDef(self, node): + # Validate class with both @config and @dataclass decorators + decorators = [ + id for d in node.decorator_list if (isinstance(d, ast.Name) and ( + (id := d.id) == 'config' or id == 'dataclass')) or + (isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and + (id := d.func.id) == 'dataclass')) + ] + + if set(decorators) == {'config', 'dataclass'}: + validate_class(node) + elif set(decorators) == {'config'}: + fail( + f"Class {node.name} with config decorator must be a dataclass.", + node) + + self.generic_visit(node) + + +def validate_class(class_node: ast.ClassDef): + attr_docs = get_attr_docs(class_node) + + for stmt in class_node.body: + # A field is defined as a class variable that has a type annotation. + if isinstance(stmt, ast.AnnAssign): + # Skip ClassVar + # see https://docs.python.org/3/library/dataclasses.html#class-variables + if isinstance(stmt.annotation, ast.Subscript) and isinstance( + stmt.annotation.value, + ast.Name) and stmt.annotation.value.id == "ClassVar": + continue + + if isinstance(stmt.target, ast.Name): + field_name = stmt.target.id + if stmt.value is None: + fail( + f"Field '{field_name}' in {class_node.name} must have " + "a default value.", stmt) + + if field_name not in attr_docs: + fail( + f"Field '{field_name}' in {class_node.name} must have " + "a docstring.", stmt) + + if isinstance(stmt.annotation, ast.Subscript) and \ + isinstance(stmt.annotation.value, ast.Name) \ + and stmt.annotation.value.id == "Union" and \ + isinstance(stmt.annotation.slice, ast.Tuple): + args = stmt.annotation.slice.elts + literal_args = [ + arg for arg in args + if isinstance(arg, ast.Subscript) and isinstance( + arg.value, ast.Name) and arg.value.id == "Literal" + ] + if len(literal_args) > 1: + fail( + f"Field '{field_name}' in {class_node.name} must " + "use a single " + "Literal type. Please use 'Literal[Literal1, " + "Literal2]' instead of 'Union[Literal1, Literal2]'" + ".", stmt) + + +def validate_ast(tree: ast.stmt): + ConfigValidator().visit(tree) + + +def validate_file(file_path: str): + try: + print(f"validating {file_path} config dataclasses ", end="") + with open(file_path, encoding="utf-8") as f: + source = f.read() + + tree = ast.parse(source, filename=file_path) + validate_ast(tree) + except ValueError as e: + print(e) + SystemExit(2) + else: + print("✅") + + +def fail(message: str, node: ast.stmt): + raise ValueError(f"❌ line({node.lineno}): {message}") + + +def main(): + for filename in sys.argv[1:]: + validate_file(filename) + + +if __name__ == "__main__": + main() diff --git a/vllm/config.py b/vllm/config.py index 623ba3aaf109..6e94573bb8e2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -18,7 +18,7 @@ from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, - Protocol, TypeVar, Union, cast, get_args, get_origin) + Protocol, TypeVar, Union, cast, get_args) import regex as re import torch @@ -193,28 +193,10 @@ def config(cls: ConfigT) -> ConfigT: (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` requires custom construction from CLI (i.e. `CompilationConfig`), it can have a `from_cli` method, which will be called instead. - """ - if not is_dataclass(cls): - raise TypeError("The decorated class must be a dataclass.") - attr_docs = get_attr_docs(cls) - for f in fields(cls): - if f.init and f.default is MISSING and f.default_factory is MISSING: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must have a default value." - ) - if f.name not in attr_docs: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must have a docstring.") - - if get_origin(f.type) is Union: - args = get_args(f.type) - literal_args = [arg for arg in args if get_origin(arg) is Literal] - if len(literal_args) > 1: - raise ValueError( - f"Field '{f.name}' in {cls.__name__} must use a single " - "Literal type. Please use 'Literal[Literal1, Literal2]' " - "instead of 'Union[Literal1, Literal2]'.") + Config validation is performed by the tools/validate_config.py + script, which is invoked during the pre-commit checks. + """ return cls @@ -1798,7 +1780,7 @@ class ParallelConfig: eplb_step_interval: int = 3000 """ Interval for rearranging experts in expert parallelism. - + Note that if this is greater than the EPLB window size, only the metrics of the last `eplb_window_size` steps will be used for rearranging experts. """