From bfbdc3915a4170aeb3b66276f78975f31754dc94 Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Thu, 27 Feb 2025 09:51:16 -0500
Subject: [PATCH 1/4] [V1][Metrics] Add comments to organize Prometheus metrics

Not a perfect set of categories, but should make it easier to navigate.

Signed-off-by: Mark McLoughlin
---
 vllm/v1/metrics/loggers.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 5a2a1c30a9d5..1bb6271b88ea 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -105,6 +105,9 @@ def __init__(self, vllm_config: VllmConfig):
 
         max_model_len = vllm_config.model_config.max_model_len
 
+        #
+        # Scheduler state
+        #
         self.gauge_scheduler_running = prometheus_client.Gauge(
             name="vllm:num_requests_running",
             documentation="Number of requests in model execution batches.",
@@ -115,6 +118,9 @@ def __init__(self, vllm_config: VllmConfig):
             documentation="Number of requests waiting to be processed.",
             labelnames=labelnames).labels(*labelvalues)
 
+        #
+        # GPU cache
+        #
         self.gauge_gpu_cache_usage = prometheus_client.Gauge(
             name="vllm:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
@@ -132,6 +138,9 @@ def __init__(self, vllm_config: VllmConfig):
             "GPU prefix cache hits, in terms of number of cached blocks.",
             labelnames=labelnames).labels(*labelvalues)
 
+        #
+        # Counters
+        #
         self.counter_num_preempted_reqs = prometheus_client.Counter(
             name="vllm:num_preemptions_total",
             documentation="Cumulative number of preemption from the engine.",
@@ -158,6 +167,9 @@ def __init__(self, vllm_config: VllmConfig):
                 reason] = counter_request_success_base.labels(*(labelvalues +
                                                                 [str(reason)]))
 
+        #
+        # Histograms of counts
+        #
         self.histogram_num_prompt_tokens_request = \
             prometheus_client.Histogram(
                 name="vllm:request_prompt_tokens",
@@ -179,6 +191,9 @@ def __init__(self, vllm_config: VllmConfig):
                 buckets=build_cudagraph_buckets(vllm_config),
                 labelnames=labelnames).labels(*labelvalues)
 
+        #
+        # Histogram of timing intervals
+        #
         self.histogram_time_to_first_token = \
             prometheus_client.Histogram(
                 name="vllm:time_to_first_token_seconds",
@@ -238,6 +253,9 @@ def __init__(self, vllm_config: VllmConfig):
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
 
+        #
+        # LoRA metrics
+        #
         self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
         if vllm_config.lora_config is not None:
             self.labelname_max_lora = "max_lora"
@@ -254,6 +272,9 @@ def __init__(self, vllm_config: VllmConfig):
                 self.labelname_running_lora_adapters,
             ])
 
+        #
+        # Cache config info metric
+        #
         self.log_metrics_info("cache_config", vllm_config.cache_config)
 
     def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):

From f31736e63def37349377e8a7fc2771a8ca1bf26a Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Thu, 27 Feb 2025 10:36:49 -0500
Subject: [PATCH 2/4] [V1][Metrics] Implement max_num_generation_tokens metrics

This metric tracks the maximum of num_generation_tokens across a set of
identical requests under a parallel sampling parent. It is the last
remaining metric used by the example Grafana dashboard that makes sense
in V1.

Add some additional tracking of child requests to ParentRequest in
order to facilitate this.
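
For illustration, the intended aggregation can be sketched in isolation.
The names below are hypothetical and only mirror the idea, not the
actual ParentRequest/IterationStats API:

    # Hypothetical sketch: track the max generation-token count across
    # the children of one parallel-sampling parent and report it once,
    # when the last child finishes.
    from typing import Optional


    class ParentAggregate:

        def __init__(self, child_ids: set[str]):
            self.pending = set(child_ids)  # children still running
            self.max_tokens = 0            # running max across children

        def child_finished(self, child_id: str,
                           num_tokens: int) -> Optional[int]:
            self.pending.discard(child_id)
            self.max_tokens = max(self.max_tokens, num_tokens)
            # Only report once every child has finished; otherwise None.
            return self.max_tokens if not self.pending else None


    agg = ParentAggregate({"0_req", "1_req", "2_req"})
    assert agg.child_finished("0_req", 12) is None
    assert agg.child_finished("1_req", 40) is None
    assert agg.child_finished("2_req", 7) == 40  # value the histogram records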

Signed-off-by: Mark McLoughlin
---
 vllm/v1/engine/output_processor.py  | 11 +++++++++++
 vllm/v1/engine/parallel_sampling.py | 27 +++++++++++++++++++++++++--
 vllm/v1/metrics/loggers.py          | 11 +++++++++++
 vllm/v1/metrics/stats.py            |  1 +
 4 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 4e1d1e3bf51b..040e74ad874d 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -198,6 +198,8 @@ def abort_requests(
             req_state = self.request_states.pop(request_id, None)
             if req_state is not None:
                 self.lora_states.abort_request(req_state)
+                if req_state.parent_req is not None:
+                    req_state.parent_req.finish_child_request(request_id)
 
     def add_request(
         self,
@@ -310,6 +312,8 @@ def process_outputs(
                     # If req not finished in EngineCore, but Detokenizer
                     # detected stop string, abort needed in EngineCore.
                     reqs_to_abort.append(req_id)
+                    if req_state.parent_req is not None:
+                        req_state.parent_req.finish_child_request(req_id)
 
                 # Track per-request stats
                 self._update_stats_from_finished(req_state, finish_reason,
@@ -352,3 +356,10 @@ def _update_stats_from_finished(self, req_state: RequestState,
             num_prompt_tokens=len(req_state.prompt_token_ids),
             req_stats=req_state.stats)
         self.lora_states.finish_request(req_state)
+
+        if req_state.parent_req is None:
+            iteration_stats.max_num_generation_tokens_iter.append(
+                req_state.stats.num_generation_tokens)
+        else:
+            req_state.parent_req.observe_max_num_generation_tokens(
+                iteration_stats, req_state.stats.num_generation_tokens)
diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py
index adced8973b03..056640efcbb5 100644
--- a/vllm/v1/engine/parallel_sampling.py
+++ b/vllm/v1/engine/parallel_sampling.py
@@ -6,6 +6,7 @@
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
+from vllm.v1.metrics.stats import IterationStats
 
 
 class ParentRequest:
@@ -18,9 +19,15 @@ class ParentRequest:
     request_id: str
     sampling_params: SamplingParams
 
+    # To track the completion of child requests
+    child_requests: set[str]
+
     # To aggregate child completions when not streaming
     output_aggregator: Optional[RequestOutput]
 
+    # To find the max number of generated tokens across all children
+    max_num_generation_tokens: int
+
     # To efficiently obtain child sampling params
     cached_child_sampling_params: Optional[SamplingParams]
 
@@ -29,7 +36,9 @@ def __init__(self, request_id: str,
         self.request_id = request_id
         self.sampling_params = sampling_params
 
+        self.child_requests = set()
         self.output_aggregator = None
+        self.max_num_generation_tokens = 0
         self.cached_child_sampling_params = None
 
     @classmethod
@@ -82,8 +91,12 @@ def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
         Returns:
             (request ID, sampling_params) tuple
         """
-        return (f"{index}_{self.request_id}",
-                self._get_child_sampling_params(index))
+        child_req_id = f"{index}_{self.request_id}"
+        self.child_requests.add(child_req_id)
+        return (child_req_id, self._get_child_sampling_params(index))
+
+    def finish_child_request(self, req_id: str):
+        self.child_requests.remove(req_id)
 
     @property
     def n(self) -> int:
@@ -117,3 +130,13 @@ def make_request_output(
         request_output.outputs = sorted(request_output.outputs,
                                         key=lambda x: x.index)
         return request_output
+
+    def observe_max_num_generation_tokens(self,
+                                          iteration_stats: IterationStats,
+                                          num_generation_tokens: int):
+        self.max_num_generation_tokens = max(num_generation_tokens,
+                                             self.max_num_generation_tokens)
+        if not self.child_requests:
+            # All child requests have finished, we can now record the max
+            iteration_stats.max_num_generation_tokens_iter.append(
+                self.max_num_generation_tokens)
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 1bb6271b88ea..e67431014597 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -191,6 +191,14 @@ def __init__(self, vllm_config: VllmConfig):
                 buckets=build_cudagraph_buckets(vllm_config),
                 labelnames=labelnames).labels(*labelvalues)
 
+        self.histogram_max_num_generation_tokens_request = \
+            prometheus_client.Histogram(
+                name="vllm:request_max_num_generation_tokens",
+                documentation=
+                "Histogram of maximum number of requested generation tokens.",
+                buckets=build_1_2_5_buckets(max_model_len),
+                labelnames=labelnames).labels(*labelvalues)
+
         #
         # Histogram of timing intervals
         #
@@ -316,6 +324,9 @@ def log(self, scheduler_stats: SchedulerStats,
             iteration_stats.num_prompt_tokens + \
             iteration_stats.num_generation_tokens)
 
+        for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
+            self.histogram_max_num_generation_tokens_request.observe(
+                max_gen_tokens)
         for ttft in iteration_stats.time_to_first_tokens_iter:
             self.histogram_time_to_first_token.observe(ttft)
         for tpot in iteration_stats.time_per_output_tokens_iter:
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index abdca95670e1..8639a10275f3 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -81,6 +81,7 @@ def __init__(self):
         self.num_prompt_tokens = 0
         self.num_preempted_reqs = 0
         self.finished_requests: list[FinishedRequestStats] = []
+        self.max_num_generation_tokens_iter: list[int] = []
         self.time_to_first_tokens_iter: list[float] = []
         self.time_per_output_tokens_iter: list[float] = []
         self.waiting_lora_adapters: dict[str, int] = {}

From d66c82bbf863372dd899b9e270522364294f87c8 Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Mon, 3 Mar 2025 11:30:44 -0500
Subject: [PATCH 3/4] [V1][Metrics] Implement vllm:request_params_n

This records SamplingParams.n from every parent request in a histogram.
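
For context, the exported family can be exercised standalone with
prometheus_client; a minimal sketch (the label name and values here are
illustrative, the real metric is registered inside the logger with the
engine's label set):

    # Rough sketch, not vLLM code: one observation of `n` per parent request.
    from prometheus_client import Histogram, generate_latest

    request_params_n = Histogram(
        name="vllm:request_params_n",
        documentation="Histogram of the n request parameter.",
        buckets=[1, 2, 5, 10, 20],
        labelnames=["model_name"],
    )

    for n in (1, 1, 4, 2):  # e.g. four finished requests
        request_params_n.labels(model_name="my-model").observe(n)

    # Exposition includes vllm:request_params_n_bucket/_sum/_count series.
    print(generate_latest().decode())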

Signed-off-by: Mark McLoughlin
---
 tests/entrypoints/openai/test_metrics.py |  3 +++
 vllm/v1/engine/output_processor.py       |  9 +++------
 vllm/v1/engine/parallel_sampling.py      | 24 ++++++++++++++++++------
 vllm/v1/metrics/loggers.py               |  9 +++++++++
 vllm/v1/metrics/stats.py                 |  1 +
 5 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index 39ce4ba23548..e11d49eed04d 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -239,6 +239,9 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "vllm:request_generation_tokens_sum",
     "vllm:request_generation_tokens_bucket",
     "vllm:request_generation_tokens_count",
+    "vllm:request_params_n_sum",
+    "vllm:request_params_n_bucket",
+    "vllm:request_params_n_count",
     "vllm:time_to_first_token_seconds_sum",
     "vllm:time_to_first_token_seconds_bucket",
     "vllm:time_to_first_token_seconds_count",
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 040e74ad874d..23d66f18ec9e 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -357,9 +357,6 @@ def _update_stats_from_finished(self, req_state: RequestState,
             req_stats=req_state.stats)
         self.lora_states.finish_request(req_state)
 
-        if req_state.parent_req is None:
-            iteration_stats.max_num_generation_tokens_iter.append(
-                req_state.stats.num_generation_tokens)
-        else:
-            req_state.parent_req.observe_max_num_generation_tokens(
-                iteration_stats, req_state.stats.num_generation_tokens)
+        ParentRequest.observe_finished_request(
+            req_state.parent_req, iteration_stats,
+            req_state.stats.num_generation_tokens)
diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py
index 056640efcbb5..4e2c78173b51 100644
--- a/vllm/v1/engine/parallel_sampling.py
+++ b/vllm/v1/engine/parallel_sampling.py
@@ -131,12 +131,24 @@ def make_request_output(
                                         key=lambda x: x.index)
         return request_output
 
-    def observe_max_num_generation_tokens(self,
-                                          iteration_stats: IterationStats,
-                                          num_generation_tokens: int):
+    def observe_num_generation_tokens(self, num_generation_tokens: int):
         self.max_num_generation_tokens = max(num_generation_tokens,
                                              self.max_num_generation_tokens)
-        if not self.child_requests:
-            # All child requests have finished, we can now record the max
+        return self.max_num_generation_tokens
+
+    @staticmethod
+    def observe_finished_request(parent_req: Optional['ParentRequest'],
+                                 iteration_stats: IterationStats,
+                                 num_generation_tokens: int):
+
+        n_param = parent_req.n if parent_req is not None else 1
+
+        if parent_req is not None:
+            num_generation_tokens = parent_req.observe_num_generation_tokens(
+                num_generation_tokens)
+
+        # Child requests finished, we can now record to iteration stats
+        if parent_req is None or not parent_req.child_requests:
             iteration_stats.max_num_generation_tokens_iter.append(
-                self.max_num_generation_tokens)
+                num_generation_tokens)
+            iteration_stats.n_params_iter.append(n_param)
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index e67431014597..48d09ff4a0ca 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -199,6 +199,13 @@ def __init__(self, vllm_config: VllmConfig):
                 buckets=build_1_2_5_buckets(max_model_len),
                 labelnames=labelnames).labels(*labelvalues)
 
+        self.histogram_n_request = \
+            prometheus_client.Histogram(
+                name="vllm:request_params_n",
+                documentation="Histogram of the n request parameter.",
+                buckets=[1, 2, 5, 10, 20],
+                labelnames=labelnames).labels(*labelvalues)
+
         #
         # Histogram of timing intervals
         #
@@ -327,6 +334,8 @@ def log(self, scheduler_stats: SchedulerStats,
         for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
             self.histogram_max_num_generation_tokens_request.observe(
                 max_gen_tokens)
+        for n_param in iteration_stats.n_params_iter:
+            self.histogram_n_request.observe(n_param)
         for ttft in iteration_stats.time_to_first_tokens_iter:
             self.histogram_time_to_first_token.observe(ttft)
         for tpot in iteration_stats.time_per_output_tokens_iter:
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 8639a10275f3..289af1585439 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -82,6 +82,7 @@ def __init__(self):
         self.num_preempted_reqs = 0
         self.finished_requests: list[FinishedRequestStats] = []
         self.max_num_generation_tokens_iter: list[int] = []
+        self.n_params_iter: list[int] = []
         self.time_to_first_tokens_iter: list[float] = []
         self.time_per_output_tokens_iter: list[float] = []
         self.waiting_lora_adapters: dict[str, int] = {}

From fdab236319bd0ff6aa9c8edebf99a22d2f4ae57f Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Mon, 3 Mar 2025 11:44:28 -0500
Subject: [PATCH 4/4] [V1][Metrics] Implement vllm:request_params_max_tokens

This just observes SamplingParams.max_tokens values in a histogram.

Signed-off-by: Mark McLoughlin
---
 tests/entrypoints/openai/test_metrics.py | 3 +++
 vllm/v1/engine/output_processor.py       | 5 +++++
 vllm/v1/metrics/loggers.py               | 9 +++++++++
 vllm/v1/metrics/stats.py                 | 3 +++
 4 files changed, 20 insertions(+)

diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index e11d49eed04d..2bffd0ce138e 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -242,6 +242,9 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "vllm:request_params_n_sum",
     "vllm:request_params_n_bucket",
     "vllm:request_params_n_count",
+    "vllm:request_params_max_tokens_sum",
+    "vllm:request_params_max_tokens_bucket",
+    "vllm:request_params_max_tokens_count",
     "vllm:time_to_first_token_seconds_sum",
     "vllm:time_to_first_token_seconds_bucket",
     "vllm:time_to_first_token_seconds_count",
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 23d66f18ec9e..75c638a854f8 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -36,6 +36,7 @@ def __init__(
         prompt_token_ids: list[int],
         logprobs_processor: LogprobsProcessor,
         detokenizer: IncrementalDetokenizer,
+        max_tokens_param: Optional[int],
         arrival_time: float,
         queue: Optional[asyncio.Queue[RequestOutput]],
         log_stats: bool,
@@ -50,6 +51,7 @@ def __init__(
         self.prompt_len = len(prompt_token_ids)
         self.logprobs_processor = logprobs_processor
         self.detokenizer = detokenizer
+        self.max_tokens_param = max_tokens_param
         self.is_prefilling = True
         self.queue = queue
 
@@ -83,6 +85,8 @@ def from_new_request(
                 tokenizer=tokenizer,
                 request=request,
             ),
+            max_tokens_param=(request.sampling_params.max_tokens if
+                              request.sampling_params is not None else None),
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
@@ -354,6 +358,7 @@ def _update_stats_from_finished(self, req_state: RequestState,
         iteration_stats.update_from_finished_request(
             finish_reason=finish_reason,
             num_prompt_tokens=len(req_state.prompt_token_ids),
+            max_tokens_param=req_state.max_tokens_param,
             req_stats=req_state.stats)
         self.lora_states.finish_request(req_state)
 
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 48d09ff4a0ca..a557367cf48a 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -206,6 +206,13 @@ def __init__(self, vllm_config: VllmConfig):
                 buckets=[1, 2, 5, 10, 20],
                 labelnames=labelnames).labels(*labelvalues)
 
+        self.histogram_max_tokens_request = \
+            prometheus_client.Histogram(
+                name="vllm:request_params_max_tokens",
+                documentation="Histogram of the max_tokens request parameter.",
+                buckets=build_1_2_5_buckets(max_model_len),
+                labelnames=labelnames).labels(*labelvalues)
+
         #
         # Histogram of timing intervals
         #
@@ -357,6 +364,8 @@ def log(self, scheduler_stats: SchedulerStats,
                 finished_request.num_prompt_tokens)
             self.histogram_num_generation_tokens_request.observe(
                 finished_request.num_generation_tokens)
+            self.histogram_max_tokens_request.observe(
+                finished_request.max_tokens_param)
 
         if self.gauge_lora_info is not None:
             running_lora_adapters = \
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 289af1585439..14ec7d2d7463 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -66,6 +66,7 @@ class FinishedRequestStats:
     e2e_latency: float = 0.0
     num_prompt_tokens: int = 0
     num_generation_tokens: int = 0
+    max_tokens_param: Optional[int] = None
     queued_time: float = 0.0
     prefill_time: float = 0.0
     inference_time: float = 0.0
@@ -152,6 +153,7 @@ def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
 
     def update_from_finished_request(self, finish_reason: "FinishReason",
                                      num_prompt_tokens: int,
+                                     max_tokens_param: Optional[int],
                                      req_stats: RequestStateStats):
         e2e_latency = self._time_since(req_stats.arrival_time)
 
@@ -175,6 +177,7 @@ def update_from_finished_request(self, finish_reason: "FinishReason",
             e2e_latency=e2e_latency,
             num_prompt_tokens=num_prompt_tokens,
             num_generation_tokens=req_stats.num_generation_tokens,
+            max_tokens_param=max_tokens_param,
             queued_time=queued_time,
             prefill_time=prefill_time,
             inference_time=inference_time,
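
The families added in patches 3 and 4 can be smoke-checked against a
running server, in the same spirit as the test_metrics.py expectations
above; a rough sketch, assuming a local OpenAI-compatible server on
port 8000 (URL and port are illustrative):

    # Hypothetical check: scrape /metrics from a locally running vLLM
    # server and confirm the new histogram families are exported.
    import urllib.request

    EXPECTED_FAMILIES = [
        "vllm:request_params_n",
        "vllm:request_params_max_tokens",
    ]

    with urllib.request.urlopen("http://localhost:8000/metrics") as resp:
        body = resp.read().decode()

    for family in EXPECTED_FAMILIES:
        assert f"{family}_count" in body, f"missing {family}"
    print("all expected metric families present")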