Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/v1/engine/test_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import asyncio
from contextlib import ExitStack
from typing import Optional
from unittest.mock import MagicMock

import pytest

from vllm import SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger

if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
Expand Down Expand Up @@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
# Assert only the last output has the finished flag set
assert all(not out.finished for out in outputs[:-1])
assert outputs[-1].finished


class MockLoggingStatLogger(LoggingStatLogger):

def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
super().__init__(vllm_config, engine_index)
self.log = MagicMock()


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers.
If a customized logger is provided at the init, it should
be used directly.
"""

with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
)
after.callback(engine.shutdown)

await engine.do_log_stats()

assert len(engine.stat_loggers) == 1
assert len(engine.stat_loggers[0]) == 1
engine.stat_loggers[0][0].log.assert_called_once()
52 changes: 33 additions & 19 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Optional, Union
Expand Down Expand Up @@ -33,8 +32,8 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase)
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.stats import IterationStats, SchedulerStats

logger = init_logger(__name__)
Expand All @@ -52,7 +51,28 @@ def __init__(
use_cached_outputs: bool = False,
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> None:
"""
Create an AsyncLLM.

Args:
vllm_config: global configuration.
executor_class: an Executor impl, e.g. MultiprocExecutor.
log_stats: Whether to log stats.
usage_context: Usage context of the LLM.
mm_registry: Multi-modal registry.
use_cached_outputs: Whether to use cached outputs.
log_requests: Whether to log requests.
start_engine_loop: Whether to start the engine loop.
stat_loggers: customized stat loggers for the engine.
If not provided, default stat loggers will be used.
PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.

Returns:
None
"""
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
Expand All @@ -66,15 +86,12 @@ def __init__(
self.log_stats = log_stats

# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = []
if self.log_stats:
for i in range(vllm_config.parallel_config.data_parallel_size):
loggers: list[StatLoggerBase] = []
if logger.isEnabledFor(logging.INFO):
loggers.append(LoggingStatLogger(engine_index=i))
loggers.append(
PrometheusStatLogger(vllm_config, engine_index=i))
self.stat_loggers.append(loggers)
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
)

# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
Expand Down Expand Up @@ -118,7 +135,7 @@ def from_vllm_config(
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
) -> "AsyncLLM":
Expand All @@ -129,17 +146,12 @@ def from_vllm_config(
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")

# FIXME(rob): refactor VllmConfig to include the StatLoggers
# include StatLogger in the Oracle decision.
if stat_loggers is not None:
raise ValueError("Custom StatLoggers are not yet supported on V1. "
"Explicitly set VLLM_USE_V1=0 to disable V1.")

# Create the LLMEngine.
return cls(
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
stat_loggers=stat_loggers,
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
Expand All @@ -151,6 +163,7 @@ def from_engine_args(
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""

Expand All @@ -166,6 +179,7 @@ def from_engine_args(
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)

def __del__(self):
Expand Down
18 changes: 9 additions & 9 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
Expand All @@ -28,6 +27,7 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory

logger = init_logger(__name__)

Expand All @@ -43,7 +43,7 @@ def __init__(
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
Expand All @@ -55,6 +55,11 @@ def __init__(
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")

if stat_loggers is not None:
raise NotImplementedError(
"Passing StatLoggers to LLMEngine in V1 is not yet supported. "
"Set VLLM_USE_V1=0 and file and issue on Github.")

self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
Expand Down Expand Up @@ -101,14 +106,9 @@ def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
if stat_loggers is not None:
raise NotImplementedError(
"Passing StatLoggers to V1 is not yet supported. "
"Set VLLM_USE_V1=0 and file and issue on Github.")

return cls(vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=(not disable_log_stats),
Expand All @@ -121,7 +121,7 @@ def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
enable_multiprocessing: bool = False,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
Expand Down
45 changes: 43 additions & 2 deletions vllm/v1/metrics/loggers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

import logging
import time
from abc import ABC, abstractmethod
from typing import Optional
from typing import Callable, Optional

import numpy as np
import prometheus_client
Expand All @@ -18,8 +19,20 @@

_LOCAL_LOGGING_INTERVAL_SEC = 5.0

StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]


class StatLoggerBase(ABC):
"""Interface for logging metrics.

API users may define custom loggers that implement this interface.
However, note that the `SchedulerStats` and `IterationStats` classes
are not considered stable interfaces and may change in future versions.
"""

@abstractmethod
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
...

@abstractmethod
def record(self, scheduler_stats: SchedulerStats,
Expand All @@ -32,7 +45,7 @@ def log(self): # noqa

class LoggingStatLogger(StatLoggerBase):

def __init__(self, engine_index: int = 0):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.engine_index = engine_index
self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats()
Expand Down Expand Up @@ -462,3 +475,31 @@ def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]:
return buckets
else:
return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]


def setup_default_loggers(
vllm_config: VllmConfig,
log_stats: bool,
engine_num: int,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> list[list[StatLoggerBase]]:
"""Setup logging and prometheus metrics."""
if not log_stats:
return []

factories: list[StatLoggerFactory]
if custom_stat_loggers is not None:
factories = custom_stat_loggers
else:
factories = [PrometheusStatLogger]
if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger)

stat_loggers: list[list[StatLoggerBase]] = []
for i in range(engine_num):
per_engine_stat_loggers: list[StatLoggerBase] = []
for logger_factory in factories:
per_engine_stat_loggers.append(logger_factory(vllm_config, i))
stat_loggers.append(per_engine_stat_loggers)

return stat_loggers