diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index da0639678af8..5d52ad5f5328 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -3,16 +3,19 @@ import asyncio from contextlib import ExitStack from typing import Optional +from unittest.mock import MagicMock import pytest from vllm import SamplingParams from vllm.assets.image import ImageAsset +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.metrics.loggers import LoggingStatLogger if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", @@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int, # Assert only the last output has the finished flag set assert all(not out.finished for out in outputs[:-1]) assert outputs[-1].finished + + +class MockLoggingStatLogger(LoggingStatLogger): + + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + super().__init__(vllm_config, engine_index) + self.log = MagicMock() + + +@pytest.mark.asyncio +async def test_customize_loggers(monkeypatch): + """Test that we can customize the loggers. + If a customized logger is provided at the init, it should + be used directly. + """ + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + engine = AsyncLLM.from_engine_args( + TEXT_ENGINE_ARGS, + stat_loggers=[MockLoggingStatLogger], + ) + after.callback(engine.shutdown) + + await engine.do_log_stats() + + assert len(engine.stat_loggers) == 1 + assert len(engine.stat_loggers[0]) == 1 + engine.stat_loggers[0][0].log.assert_called_once() diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index c33535b3d360..a1eb5c8ba185 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import logging from collections.abc import AsyncGenerator, Mapping from copy import copy from typing import Optional, Union @@ -33,8 +32,8 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, - StatLoggerBase) +from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, + setup_default_loggers) from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -52,7 +51,28 @@ def __init__( use_cached_outputs: bool = False, log_requests: bool = True, start_engine_loop: bool = True, + stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> None: + """ + Create an AsyncLLM. + + Args: + vllm_config: global configuration. + executor_class: an Executor impl, e.g. MultiprocExecutor. + log_stats: Whether to log stats. + usage_context: Usage context of the LLM. + mm_registry: Multi-modal registry. + use_cached_outputs: Whether to use cached outputs. + log_requests: Whether to log requests. + start_engine_loop: Whether to start the engine loop. + stat_loggers: customized stat loggers for the engine. + If not provided, default stat loggers will be used. + PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE + IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. + + Returns: + None + """ if not envs.VLLM_USE_V1: raise ValueError( "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " @@ -66,15 +86,12 @@ def __init__( self.log_stats = log_stats # Set up stat loggers; independent set for each DP rank. - self.stat_loggers: list[list[StatLoggerBase]] = [] - if self.log_stats: - for i in range(vllm_config.parallel_config.data_parallel_size): - loggers: list[StatLoggerBase] = [] - if logger.isEnabledFor(logging.INFO): - loggers.append(LoggingStatLogger(engine_index=i)) - loggers.append( - PrometheusStatLogger(vllm_config, engine_index=i)) - self.stat_loggers.append(loggers) + self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( + vllm_config=vllm_config, + log_stats=self.log_stats, + engine_num=vllm_config.parallel_config.data_parallel_size, + custom_stat_loggers=stat_loggers, + ) # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -118,7 +135,7 @@ def from_vllm_config( vllm_config: VllmConfig, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, ) -> "AsyncLLM": @@ -129,17 +146,12 @@ def from_vllm_config( "AsyncLLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") - # FIXME(rob): refactor VllmConfig to include the StatLoggers - # include StatLogger in the Oracle decision. - if stat_loggers is not None: - raise ValueError("Custom StatLoggers are not yet supported on V1. " - "Explicitly set VLLM_USE_V1=0 to disable V1.") - # Create the LLMEngine. return cls( vllm_config=vllm_config, executor_class=Executor.get_class(vllm_config), start_engine_loop=start_engine_loop, + stat_loggers=stat_loggers, log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, @@ -151,6 +163,7 @@ def from_engine_args( engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> "AsyncLLM": """Create an AsyncLLM from the EngineArgs.""" @@ -166,6 +179,7 @@ def from_engine_args( log_stats=not engine_args.disable_log_stats, start_engine_loop=start_engine_loop, usage_context=usage_context, + stat_loggers=stat_loggers, ) def __del__(self): diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a07595a552af..ac2ee065f09f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -10,7 +10,6 @@ from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -28,6 +27,7 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.loggers import StatLoggerFactory logger = init_logger(__name__) @@ -43,7 +43,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, multiprocess_mode: bool = False, @@ -55,6 +55,11 @@ def __init__( "LLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") + if stat_loggers is not None: + raise NotImplementedError( + "Passing StatLoggers to LLMEngine in V1 is not yet supported. " + "Set VLLM_USE_V1=0 and file and issue on Github.") + self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -101,14 +106,9 @@ def from_vllm_config( cls, vllm_config: VllmConfig, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_stats: bool = False, ) -> "LLMEngine": - if stat_loggers is not None: - raise NotImplementedError( - "Passing StatLoggers to V1 is not yet supported. " - "Set VLLM_USE_V1=0 and file and issue on Github.") - return cls(vllm_config=vllm_config, executor_class=Executor.get_class(vllm_config), log_stats=(not disable_log_stats), @@ -121,7 +121,7 @@ def from_engine_args( cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 547e60467632..22d1d9724c8c 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +import logging import time from abc import ABC, abstractmethod -from typing import Optional +from typing import Callable, Optional import numpy as np import prometheus_client @@ -18,8 +19,20 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5.0 +StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] + class StatLoggerBase(ABC): + """Interface for logging metrics. + + API users may define custom loggers that implement this interface. + However, note that the `SchedulerStats` and `IterationStats` classes + are not considered stable interfaces and may change in future versions. + """ + + @abstractmethod + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + ... @abstractmethod def record(self, scheduler_stats: SchedulerStats, @@ -32,7 +45,7 @@ def log(self): # noqa class LoggingStatLogger(StatLoggerBase): - def __init__(self, engine_index: int = 0): + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self._reset(time.monotonic()) self.last_scheduler_stats = SchedulerStats() @@ -462,3 +475,31 @@ def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]: return buckets else: return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] + + +def setup_default_loggers( + vllm_config: VllmConfig, + log_stats: bool, + engine_num: int, + custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, +) -> list[list[StatLoggerBase]]: + """Setup logging and prometheus metrics.""" + if not log_stats: + return [] + + factories: list[StatLoggerFactory] + if custom_stat_loggers is not None: + factories = custom_stat_loggers + else: + factories = [PrometheusStatLogger] + if logger.isEnabledFor(logging.INFO): + factories.append(LoggingStatLogger) + + stat_loggers: list[list[StatLoggerBase]] = [] + for i in range(engine_num): + per_engine_stat_loggers: list[StatLoggerBase] = [] + for logger_factory in factories: + per_engine_stat_loggers.append(logger_factory(vllm_config, i)) + stat_loggers.append(per_engine_stat_loggers) + + return stat_loggers