diff --git a/components/metrics/src/bin/mock_worker.rs b/components/metrics/src/bin/mock_worker.rs index 6278de73ce..a2238ea5b1 100644 --- a/components/metrics/src/bin/mock_worker.rs +++ b/components/metrics/src/bin/mock_worker.rs @@ -14,7 +14,7 @@ // limitations under the License. use dynamo_llm::kv_router::{ - protocols::ForwardPassMetrics, scheduler::KVHitRateEvent, KV_HIT_RATE_SUBJECT, + protocols::ForwardPassMetrics, protocols::KVHitRateEvent, KV_HIT_RATE_SUBJECT, }; use dynamo_runtime::{ component::{service::EndpointStats, Namespace}, @@ -89,7 +89,7 @@ async fn mock_event_publisher(namespace: Namespace) { let overlap_blocks = rand::rng().random_range(0..=isl_blocks); let event = KVHitRateEvent { - worker_id, + worker: worker_id, isl_blocks, overlap_blocks, }; diff --git a/components/metrics/src/lib.rs b/components/metrics/src/lib.rs index b928938490..023dd08a6a 100644 --- a/components/metrics/src/lib.rs +++ b/components/metrics/src/lib.rs @@ -84,8 +84,7 @@ use std::net::SocketAddr; use std::time::Duration as StdDuration; use dynamo_llm::kv_router::protocols::ForwardPassMetrics; -use dynamo_llm::kv_router::scheduler::Endpoint; -use dynamo_llm::kv_router::scoring::ProcessedEndpoints; +use dynamo_llm::kv_router::scoring::{Endpoint, ProcessedEndpoints}; use dynamo_runtime::{ distributed::Component, error, service::EndpointInfo, utils::Duration, Result, @@ -451,6 +450,8 @@ impl PrometheusMetrics { let worker_id = worker_id.to_string(); let metrics = endpoint.data.clone(); + // NOTE: using metrics[0] just to get the first dp_rank for now + // to not change the existing behavior self.set_worker_gauge( &self.kv_blocks_active, config, diff --git a/components/metrics/src/main.rs b/components/metrics/src/main.rs index fa8186d07a..c6a0996280 100644 --- a/components/metrics/src/main.rs +++ b/components/metrics/src/main.rs @@ -27,7 +27,7 @@ //! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events //! 
- Overlap Blocks: Cumulative count of blocks that were already in the KV cache use clap::Parser; -use dynamo_llm::kv_router::scheduler::KVHitRateEvent; +use dynamo_llm::kv_router::protocols::{KVHitRateEvent, WorkerDp}; use dynamo_llm::kv_router::KV_HIT_RATE_SUBJECT; use dynamo_runtime::{ error, logging, @@ -180,14 +180,15 @@ async fn app(runtime: Runtime) -> Result<()> { tracing::debug!("Successfully subscribed to KV hit rate events"); while let Some(msg) = subscriber.next().await { - match serde_json::from_slice::<KVHitRateEvent>(&msg.payload) { + match serde_json::from_slice::<KVHitRateEvent<WorkerDp>>(&msg.payload) { Ok(event) => { // TODO: Lower to debug let cache_hit_pct = (event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0; tracing::debug!( - "Received KV hit rate event: worker_id={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%", - event.worker_id, + "Received KV hit rate event: worker_id={}, dp_rank={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%", + event.worker.worker_id, + event.worker.dp_rank.unwrap_or(0), event.isl_blocks, event.overlap_blocks, cache_hit_pct @@ -197,7 +198,8 @@ async fn app(runtime: Runtime) -> Result<()> { let mut metrics = metrics_collector_clone.lock().await; metrics.update_kv_hit_rate( &config_clone, - event.worker_id, + // TODO: this will not take care of dp ranks + event.worker.worker_id, event.isl_blocks, event.overlap_blocks, ); diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 3546a9bb30..e8e7d08c72 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use clap::Parser; use dynamo_llm::kv_router::{ - protocols::WorkerSelectionResult, + protocols::{WorkerDp, WorkerSelectionResult}, scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, KvRouter, WorkerSelector, @@ -89,7 +89,7 @@ impl WorkerSelector for CustomWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result<WorkerSelectionResult, KvSchedulerError> { + ) -> Result<WorkerSelectionResult<WorkerDp>, KvSchedulerError> { // customize logic here // F12 into [DefaultWorkerSelector] to see the original logic self.0.select_worker(workers, request, block_size) diff --git a/examples/vllm_v1/README.md b/examples/vllm_v1/README.md index 39d6c0e1db..9434d06a99 100644 --- a/examples/vllm_v1/README.md +++ b/examples/vllm_v1/README.md @@ -17,16 +17,15 @@ limitations under the License. # vLLM Deployment Examples -This directory contains examples for deploying vLLM models in both aggregated and disaggregated configurations. +This directory contains examples for deploying vLLM models in an aggregated configuration with data parallelism (DP). ## Prerequisites 1. Install vLLM: ```bash -# Note: Currently requires installation from main branch -# From vLLM 0.8.6 onwards, you can install directly from wheel git clone https://github.com/vllm-project/vllm.git -VLLM_USE_PRECOMPILED=1 uv pip install --editable ./vllm/ +cd vllm && git checkout d459fae0a2c464e28680bc6d564c1de1b295029e +VLLM_USE_PRECOMPILED=1 uv pip install --editable . ``` 2. 
Start required services: @@ -36,78 +35,46 @@ docker compose -f deploy/metrics/docker-compose.yml up -d ## Running the Server -### Aggregated Deployment +### Aggregated Deployment with Multiple Disconnected DP Engines + +Serves the leader AsyncLLM engine plus the number of DP ranks you specify. ```bash cd examples/vllm_v1 dynamo serve graphs.agg:Frontend -f configs/agg.yaml ``` -### Disaggregated Deployment -```bash -cd examples/vllm_v1 -dynamo serve graphs.disagg:Frontend -f configs/disagg.yaml +To run the other DP ranks headless on the same node or on other nodes, run: + +``` +VLLM_LOGGING_LEVEL=DEBUG CUDA_VISIBLE_DEVICES=1 VLLM_USE_V1=1 vllm serve Qwen/Qwen3-0.6B -dp 1 -dpr 1 --data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 --data-parallel-size-local 1 --enforce-eager --headless --kv-events-config '{"enable_kv_cache_events": true, "publisher": "zmq"}' --enable-prefix-caching ``` -## Testing the API +To test, you can run the curl request below. KV routing means an identical request will keep going to the same worker, so vary it to see requests routed to different DP workers. -Send a test request using curl: -```bash -curl localhost:8000/v1/completions \ +``` +curl localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", - "prompt": "In the heart of Eldoria...", - "stream": false, + "model": "Qwen/Qwen3-0.6B", + "messages": [ + { + "role": "user", + "content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden." + } + ], + "stream":false, "max_tokens": 30 }' -``` - -For more detailed explenations, refer to the main [LLM examples README](../llm/README.md). - - - -## Deepseek R1 - -To run DSR1 model please first follow the Ray setup from the [multinode documentation](../../docs/examples/multinode.md). - -### Aggregated Deployment - -```bash -cd examples/vllm_v1 -dynamo serve graphs.agg:Frontend -f configs/deepseek_r1/agg.yaml -``` - - -### Disaggregated Deployment + ``` -To create frontend with a single decode worker: -```bash -cd examples/vllm_v1 -dynamo serve graphs.agg:Frontend -f configs/deepseek_r1/disagg.yaml -``` - -To create a single decode worker: -```bash -cd examples/vllm_v1 -dynamo serve components.worker:VllmDecodeWorker -f configs/deepseek_r1/disagg.yaml +TODO: +- Currently, running more than one instance or worker on the same node will fail because the ZmqKvPublishers' ports will overlap; some port offsetting is needed to manage that (see the sketch after this list). ``` - -To create a single prefill worker: -```bash -cd examples/vllm_v1 -dynamo serve components.worker:VllmPrefillWorker -f configs/deepseek_r1/disagg.yaml + ServiceArgs: + workers: 1 # 2 workers not supported ``` +- It would be best to distill the vLLM serve path into a VllmHeadlessWorker using run_headless(self.engine_args). This is relatively simple; the main difficulty is that if you want to add the ZmqKvEventPublisher to these nodes (which would be easier for multi-node, since then you only need to set up NATS and not worry about ports), they will have a different lease_id than the leader worker. This is a problem because we don't actually route requests to these dp_ranks directly, yet the KV Router and KV Indexer will see these KVEvents as coming from a separate "worker". We still need to route the KVEvents through the leader AsyncLLM engine, and that engine will take care of routing to the dp ranks. + - To address this, could we create a concept of worker groups, i.e., components whose lease_ids are tied to a single leader worker?
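One possible shape for the port offsetting mentioned in the first TODO item is to give every worker instance on a node its own port range and keep the existing per-dp-rank offsetting inside that range. Below is a minimal sketch; the `DYNAMO_ZMQ_BASE_PORT` environment variable and the `worker_index` parameter are assumptions for illustration (neither exists in this change), and only `ZmqEventPublisher.offset_endpoint_port` is taken from the current worker code:

```python
import os

from vllm.distributed.kv_events import ZmqEventPublisher


def per_worker_zmq_endpoint(worker_index: int, dp_rank: int, dp_size: int) -> str:
    # Hypothetical scheme: each worker instance reserves a contiguous block of
    # dp_size ports starting at DYNAMO_ZMQ_BASE_PORT, so two instances on the
    # same node never hand the same endpoint to their ZmqKvEventPublishers.
    base_port = int(os.environ.get("DYNAMO_ZMQ_BASE_PORT", "5557"))
    base_endpoint = f"tcp://127.0.0.1:{base_port + worker_index * dp_size}"
    # Reuse the per-dp-rank offsetting already used in worker.py.
    return ZmqEventPublisher.offset_endpoint_port(
        base_endpoint, data_parallel_rank=dp_rank
    )
```

Each worker instance would then build its `ZmqKvEventPublisherConfig` from this endpoint instead of the shared `tcp://127.0.0.1:5557` base.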
-## Testing -Send a test request using curl: -```bash -curl localhost:8000/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "deepseek-ai/DeepSeek-R1", - "prompt": "In the heart of Eldoria...", - "stream": false, - "max_tokens": 30 - }' -``` \ No newline at end of file +For more detailed explanations, refer to the main [LLM examples README](../llm/README.md). diff --git a/examples/vllm_v1/components/frontend.py b/examples/vllm_v1/components/frontend.py index a0f86e72db..5c58aa08f8 100644 --- a/examples/vllm_v1/components/frontend.py +++ b/examples/vllm_v1/components/frontend.py @@ -17,7 +17,7 @@ import subprocess from pathlib import Path -from components.simple_load_balancer import SimpleLoadBalancer +from components.worker import VllmDecodeWorker from fastapi import FastAPI from pydantic import BaseModel @@ -42,9 +42,8 @@ def get_dynamo_run_binary(): class FrontendConfig(BaseModel): """Configuration for the Frontend service including model and HTTP server settings.""" - served_model_name: str - endpoint: str port: int = 8080 + router_mode: str = "round-robin" # TODO: move these to common for all LLMs once we adopt dynamo-run @@ -58,7 +57,7 @@ class FrontendConfig(BaseModel): app=FastAPI(title="LLM Example"), ) class Frontend: - worker = depends(SimpleLoadBalancer) + worker = depends(VllmDecodeWorker) def __init__(self): """Initialize Frontend service with HTTP server and model configuration.""" @@ -74,20 +73,20 @@ def start_ingress_and_processor(self): f"Starting HTTP server and processor on port {self.frontend_config.port}" ) dynamo_run_binary = get_dynamo_run_binary() - endpoint = f"dyn://{self.frontend_config.endpoint}" logger.info( f"Starting HTTP server and processor on port {self.frontend_config.port}" ) - logger.info(f"Endpoint: {endpoint}") self.process = subprocess.Popen( [ dynamo_run_binary, "in=http", - f"out={endpoint}", + "out=dyn", "--http-port", str(self.frontend_config.port), + "--router-mode", + self.frontend_config.router_mode, ], stdout=None, stderr=None, diff --git a/examples/vllm_v1/components/headless_worker.py b/examples/vllm_v1/components/headless_worker.py new file mode 100644 index 0000000000..1403cd62d5 --- /dev/null +++ b/examples/vllm_v1/components/headless_worker.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Work In Progress. This is not usable currently + +import asyncio +import logging +import os +import signal +import socket +from typing import Optional + +from utils.args import parse_vllm_args +from vllm import run_headless +from vllm.distributed.kv_events import KVEventsConfig + +from dynamo.sdk import service + +logger = logging.getLogger(__name__) + +BLOCK_SIZE = 16 + + +@service( + dynamo={ + "enabled": True, + "namespace": "dynamo", + }, + resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, + workers=1, +) +class VllmHeadlessWorker: + def __init__(self): + class_name = self.__class__.__name__ + self.engine_args = parse_vllm_args(class_name, "") + self.engine_args.kv_events_config = KVEventsConfig( + enable_kv_cache_events=True, publisher="zmq" + ) + if not self.engine_args.block_size: + logger.info(f"block_size not set, default to {BLOCK_SIZE}") + self.engine_args.block_size = BLOCK_SIZE + + os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests + + model_config = self.engine_args.create_model_config() + self.default_sampling_params = model_config.get_diff_sampling_param() + + self.kv_publishers = [] + + signal.signal(signal.SIGTERM, self.shutdown_vllm_engine) + signal.signal(signal.SIGINT, self.shutdown_vllm_engine) + + self.set_side_channel_host_and_port() + + async def async_init(self): + run_headless(self.engine_args) + + def shutdown_vllm_engine(self, signum, frame): + """Shutdown the background loop""" + logger.info(f"Received signal {signum}, shutting down") + loop = asyncio.get_event_loop() + try: + self.engine_client.shutdown() + for publisher in self.kv_publishers: + publisher.shutdown() + logger.info("VllmWorker shutdown complete") + except Exception as e: + logger.error(f"Error during shutdown: {e}") + finally: + loop.stop() + + def set_side_channel_host_and_port( + self, hostname: Optional[str] = None, port: Optional[int] = None + ): + """vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors. + This sets the port number for the side channel. + """ + if hostname is None: + hostname = socket.gethostname() + if port is None: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to a free port provided by the host. + port = s.getsockname()[1] # Get the port number assigned. + logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_HOST to %s", hostname) + os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname + logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_PORT to %s", port) + os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port) diff --git a/examples/vllm_v1/components/simple_load_balancer.py b/examples/vllm_v1/components/simple_load_balancer.py deleted file mode 100644 index 9a0d3bfb87..0000000000 --- a/examples/vllm_v1/components/simple_load_balancer.py +++ /dev/null @@ -1,199 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import uuid -from typing import AsyncGenerator, Optional - -from components.worker import VllmDecodeWorker, VllmPrefillWorker -from utils.args import parse_vllm_args -from utils.protocol import MyRequestOutput, PreprocessedRequest, vLLMGenerateRequest -from vllm.inputs import TokensPrompt -from vllm.sampling_params import SamplingParams - -from dynamo.llm import ModelType, register_llm -from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service - -logger = logging.getLogger(__name__) - - -@service( - dynamo={ - "enabled": True, - "namespace": "dynamo", - }, - resources={"cpu": "10", "memory": "20Gi"}, - workers=1, -) -class SimpleLoadBalancer: - prefill_worker = depends(VllmPrefillWorker) - decode_worker = depends(VllmDecodeWorker) - - def __init__(self): - class_name = self.__class__.__name__ - self.engine_args = parse_vllm_args(class_name, "") - model_config = self.engine_args.create_model_config() - self.default_sampling_params = model_config.get_diff_sampling_param() - self.enable_disagg = self.engine_args.enable_disagg - - @async_on_start - async def async_init(self): - runtime = dynamo_context["runtime"] - logger.info("Registering LLM for discovery") - comp_ns, comp_name = SimpleLoadBalancer.dynamo_address() # type: ignore - endpoint_name = "generate" - for served_model_name in self.engine_args.served_model_name: - logger.info( - f"Registering endpoint {endpoint_name} with model {self.engine_args.model} and served_model_name {served_model_name}" - ) - endpoint = ( - runtime.namespace(comp_ns).component(comp_name).endpoint(endpoint_name) - ) - await register_llm( - ModelType.Backend, - endpoint, - self.engine_args.model, - served_model_name, - ) - - comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore - self.decode_worker_client = ( - await runtime.namespace(comp_ns) - .component(comp_name) - .endpoint("generate") - .client() - ) - - comp_ns, comp_name = VllmPrefillWorker.dynamo_address() # type: ignore - self.prefill_worker_client = ( - await runtime.namespace(comp_ns) - .component(comp_name) - .endpoint("generate") - .client() - ) - - logger.info("SimpleLoadBalancer has been initialized") - - async def send_request_to_prefill( - self, request: vLLMGenerateRequest - ) -> MyRequestOutput: - logger.debug("Sending request to prefill") - - prefill_request = copy.deepcopy(request) - extra_args = prefill_request.sampling_params.extra_args or {} - extra_args["kv_transfer_params"] = { - "do_remote_decode": True, - } - prefill_request.sampling_params.extra_args = extra_args - prefill_request.sampling_params.max_tokens = 1 - prefill_request.sampling_params.min_tokens = 1 - - logger.debug("Prefill request: %s", prefill_request.model_dump_json()) - - async for prefill_response in await self.prefill_worker_client.round_robin( - prefill_request.model_dump_json() - ): - return MyRequestOutput.model_validate_json(prefill_response.data()) - - async 
def send_request_to_decode( - self, - request: vLLMGenerateRequest, - prefill_response: Optional[MyRequestOutput] = None, - ) -> AsyncGenerator[MyRequestOutput, None]: - logger.debug("Sending request to decode") - - decode_request = copy.deepcopy(request) - - if prefill_response: - extra_args = decode_request.sampling_params.extra_args or {} - extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params - decode_request.sampling_params.extra_args = extra_args - - logger.debug("Decode request: %s", decode_request.model_dump_json()) - - async for decode_response in await self.decode_worker_client.round_robin( - decode_request.model_dump_json() - ): - yield MyRequestOutput.model_validate_json(decode_response.data()) - - @endpoint() - async def generate(self, request: PreprocessedRequest): - logger.debug( - "Processor received completion request: %s", request.model_dump_json() - ) - - vllm_request = self._create_vllm_request(request) - - logger.debug("VLLM request: %s", vllm_request.model_dump_json()) - - if self.enable_disagg: - prefill_response = await self.send_request_to_prefill(vllm_request) - - logger.debug("Prefill response: %s", prefill_response.model_dump_json()) - else: - prefill_response = None - - gen = self.send_request_to_decode(vllm_request, prefill_response) - async for res in self._stream_response(gen): - yield res - - def _create_vllm_request(self, request: PreprocessedRequest) -> vLLMGenerateRequest: - request_id = str(uuid.uuid4().hex) - - prompt = TokensPrompt(prompt_token_ids=request.token_ids) - - sampling_params = SamplingParams(**self.default_sampling_params) - for key, value in request.sampling_options.model_dump().items(): - if not value: - continue - if hasattr(sampling_params, key): - setattr(sampling_params, key, value) - - max_tokens = request.stop_conditions.max_tokens - if max_tokens: - sampling_params.max_tokens = max_tokens - - return vLLMGenerateRequest( - prompt=prompt, - sampling_params=sampling_params, - request_id=request_id, - ) - - async def _stream_response(self, gen: AsyncGenerator[MyRequestOutput, None]): - num_output_tokens_so_far = 0 - async for res in gen: - logger.debug("Decode response: %s", res.model_dump_json()) - # res is our MyRequestOutput - - # This is the expected way for a request to end. - # The new token ID will be eos, don't forward it. 
- if res.finished: - yield {"finish_reason": "stop", "token_ids": []} - break - - if not res.outputs: - yield {"finish_reason": "error", "token_ids": []} - break - - output = res.outputs[0] - next_total_toks = len(output.token_ids) - out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} - if output.finish_reason: - out["finish_reason"] = output.finish_reason - if output.stop_reason: - out["stop_reason"] = output.stop_reason - yield out - num_output_tokens_so_far = next_total_toks diff --git a/examples/vllm_v1/components/worker.py b/examples/vllm_v1/components/worker.py index c26cdbbce2..ef7834873d 100644 --- a/examples/vllm_v1/components/worker.py +++ b/examples/vllm_v1/components/worker.py @@ -19,23 +19,103 @@ import os import signal import socket +import uuid from typing import Optional from utils.args import parse_vllm_args -from utils.protocol import MyRequestOutput, vLLMGenerateRequest -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args, -) +from utils.protocol import PreprocessedRequest +from vllm.config import VllmConfig +from vllm.distributed.kv_events import KVEventsConfig, ZmqEventPublisher +from vllm.inputs import TokensPrompt +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.metrics.loggers import StatLoggerBase +from vllm.v1.metrics.stats import IterationStats, SchedulerStats -from dynamo.sdk import async_on_start, endpoint, service +from dynamo.llm import ( + ModelType, + WorkerMetricsPublisher, + ZmqKvEventPublisher, + ZmqKvEventPublisherConfig, + register_llm, +) +from dynamo.runtime import Component +from dynamo.sdk import async_on_start, dynamo_context, endpoint, service logger = logging.getLogger(__name__) +class DynamoStatLoggerPublisher(StatLoggerBase): + """Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface.""" + + def __init__(self, component: Component, dp_rank: int) -> None: + self.inner = WorkerMetricsPublisher() + self.inner.create_endpoint(component, dp_rank=dp_rank) + self.dp_rank = dp_rank + + def record( + self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats] + ): + # request_total_slots and kv_total_blocks are properties of model + gpu + # we should only publish them once, not every metric update + # they should be part of some runtime metadata tied to MDC or put in etcd ? + hit_rate = 0 + if scheduler_stats.prefix_cache_stats.queries > 0: + hit_rate = ( + scheduler_stats.prefix_cache_stats.hits + / scheduler_stats.prefix_cache_stats.queries + ) + + # TODO Manage DP Ranks in metrics aggregation. + self.inner.publish( + request_active_slots=scheduler_stats.num_running_reqs, + request_total_slots=0, # TODO - remove from metrics + kv_active_blocks=0, # TODO - need to calculate this + kv_total_blocks=0, # TODO - remove from metrics + num_requests_waiting=scheduler_stats.num_waiting_reqs, # used in current cost function + gpu_cache_usage_perc=scheduler_stats.gpu_cache_usage, # used in current cost function + gpu_prefix_cache_hit_rate=hit_rate, + data_parallel_rank=self.dp_rank, + ) + + def log_engine_initialized(self) -> None: + pass + + +class StatLoggerFactory: + """Factory for creating stat logger publishers. 
Required by vLLM.""" + + def __init__(self, component: Component) -> None: + self.component = component + + def create_stat_logger(self, dp_rank: int) -> StatLoggerBase: + return DynamoStatLoggerPublisher(self.component, dp_rank) + + def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase: + return self.create_stat_logger(dp_rank=dp_rank) + + +BLOCK_SIZE = 16 + + class VllmBaseWorker: def __init__(self): class_name = self.__class__.__name__ self.engine_args = parse_vllm_args(class_name, "") + self.engine_args.kv_events_config = KVEventsConfig( + enable_kv_cache_events=True, publisher="zmq" + ) + if not self.engine_args.block_size: + logger.info(f"block_size not set, default to {BLOCK_SIZE}") + self.engine_args.block_size = BLOCK_SIZE + + os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests + + model_config = self.engine_args.create_model_config() + self.default_sampling_params = model_config.get_diff_sampling_param() + + self.kv_publishers = [] signal.signal(signal.SIGTERM, self.shutdown_vllm_engine) signal.signal(signal.SIGINT, self.shutdown_vllm_engine) @@ -43,22 +123,67 @@ def __init__(self): self.set_side_channel_host_and_port() async def async_init(self): - self._engine_context = build_async_engine_client_from_engine_args( - self.engine_args + # Taken from build_async_engine_client_from_engine_args() + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = self.engine_args.create_engine_config(usage_context=usage_context) + + await register_llm( + ModelType.Backend, + dynamo_context["endpoints"][0], + self.engine_args.model, + self.engine_args.served_model_name[0], + context_length=self.engine_args.max_model_len, + kv_cache_block_size=self.engine_args.block_size, + ) + + # Explicitly pass our custom stat logger for metrics + self.engine_client = AsyncLLM.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + stat_loggers=[StatLoggerFactory(dynamo_context["component"])], + disable_log_requests=self.engine_args.disable_log_requests, + disable_log_stats=self.engine_args.disable_log_stats, ) - if self._engine_context is not None: - self.engine_client = await self._engine_context.__aenter__() - else: - raise RuntimeError("Failed to initialize engine client") logger.info("VllmWorker has been initialized") + base_zmq_endpoint = "tcp://127.0.0.1:5557" + dp_rank_size = vllm_config.parallel_config.data_parallel_size + + # Store references to prevent garbage collection + + for dp_rank in range(dp_rank_size): + zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( + base_zmq_endpoint, data_parallel_rank=dp_rank + ) + zmq_config = ZmqKvEventPublisherConfig( + worker_id=dynamo_context["endpoints"][0].lease_id(), + kv_block_size=self.engine_args.block_size, + zmq_endpoint=zmq_endpoint, + ) + + try: + publisher = ZmqKvEventPublisher( + component=dynamo_context["component"], config=zmq_config + ) + self.kv_publishers.append(publisher) + except Exception as e: + logger.error( + f"Failed to create ZmqKvEventPublisher for dp_rank {dp_rank}: {e}" + ) + + logger.debug( + f"Successfully created {len(self.kv_publishers)} ZmqKvEventPublishers out of {dp_rank_size} expected" + ) + def shutdown_vllm_engine(self, signum, frame): """Shutdown the background loop""" logger.info(f"Received signal {signum}, shutting down") loop = asyncio.get_event_loop() try: - self.engine_client.close() + self.engine_client.shutdown() + for publisher in self.kv_publishers: + publisher.shutdown() logger.info("VllmWorker shutdown complete") except Exception as e: 
logger.error(f"Error during shutdown: {e}") @@ -66,25 +191,51 @@ def shutdown_vllm_engine(self, signum, frame): loop.stop() @endpoint() - async def generate(self, request: vLLMGenerateRequest): + async def generate(self, request: PreprocessedRequest): + request_id = str(uuid.uuid4().hex) + + prompt = TokensPrompt(prompt_token_ids=request.token_ids) + + sampling_params = SamplingParams(**self.default_sampling_params) + for key, value in request.sampling_options.model_dump().items(): + if not value: + continue + if hasattr(sampling_params, key): + setattr(sampling_params, key, value) + + max_tokens = request.stop_conditions.max_tokens + if max_tokens: + sampling_params.max_tokens = max_tokens + gen = self.engine_client.generate( - prompt=request.prompt, - sampling_params=request.sampling_params, - request_id=request.request_id, + prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + data_parallel_rank=request.dp_rank, ) + num_output_tokens_so_far = 0 + async for res in gen: + # res is vllm's RequestOutput + + # This is the expected way for a request to end. + # The new token ID will be eos, don't forward it. + if res.finished: + yield {"finish_reason": "stop", "token_ids": []} + break + + if not res.outputs: + yield {"finish_reason": "error", "token_ids": []} + break - async for response in gen: - logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}") - yield MyRequestOutput( - request_id=response.request_id, - prompt=response.prompt, - prompt_token_ids=response.prompt_token_ids, - prompt_logprobs=response.prompt_logprobs, - outputs=response.outputs, - finished=response.finished, - metrics=response.metrics, - kv_transfer_params=response.kv_transfer_params, - ).model_dump_json() + output = res.outputs[0] + next_total_toks = len(output.token_ids) + out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} + if output.finish_reason: + out["finish_reason"] = output.finish_reason + if output.stop_reason: + out["stop_reason"] = output.stop_reason + yield out + num_output_tokens_so_far = next_total_toks def set_side_channel_host_and_port( self, hostname: Optional[str] = None, port: Optional[int] = None diff --git a/examples/vllm_v1/configs/agg.yaml b/examples/vllm_v1/configs/agg.yaml index e1a10870c4..ddc1d83a37 100644 --- a/examples/vllm_v1/configs/agg.yaml +++ b/examples/vllm_v1/configs/agg.yaml @@ -1,8 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. +# you may not use this file except in compliance with the License.More actions # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 @@ -13,22 +9,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
Common: - model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B - served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + model: Qwen/Qwen3-0.6B + + block-size: 16 + max-model-len: 16384 + served_model_name: Qwen/Qwen3-0.6B Frontend: - endpoint: dynamo.SimpleLoadBalancer.generate_agg port: 8000 - served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B - -SimpleLoadBalancer: - enable_disagg: false - common-configs: [model, served_model_name] + router_mode: kv VllmDecodeWorker: enforce-eager: true + max-num-batched-tokens: 16384 + enable-prefix-caching: true + data-parallel-address: 127.0.0.1 + data-parallel-rpc-port: 62300 + data-parallel-size: 2 + data-parallel-size-local: 1 + # api-server-count: 2 + ServiceArgs: - workers: 1 + workers: 1 # 2 workers resources: - gpu: '1' - common-configs: [model, served_model_name] + gpu: 1 # 2 dp ranks + common-configs: [model, served_model_name, block-size, max-model-len] diff --git a/examples/vllm_v1/graphs/agg.py b/examples/vllm_v1/graphs/agg.py index b7428756b3..95e02efab1 100644 --- a/examples/vllm_v1/graphs/agg.py +++ b/examples/vllm_v1/graphs/agg.py @@ -14,8 +14,6 @@ # limitations under the License. from components.frontend import Frontend -from components.simple_load_balancer import SimpleLoadBalancer from components.worker import VllmDecodeWorker -load_balancer = Frontend.link(SimpleLoadBalancer) -load_balancer.link(VllmDecodeWorker) +Frontend.link(VllmDecodeWorker) diff --git a/examples/vllm_v1/utils/args.py b/examples/vllm_v1/utils/args.py index f05976c8b8..6780b72a78 100644 --- a/examples/vllm_v1/utils/args.py +++ b/examples/vllm_v1/utils/args.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + # TODO: rename to avoid ambiguity with vllm package from vllm.engine.arg_utils import AsyncEngineArgs from vllm.utils import FlexibleArgumentParser @@ -23,6 +24,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: config = ServiceConfig.get_instance() vllm_args = config.as_args(service_name, prefix=prefix) + parser = FlexibleArgumentParser() parser.add_argument( "--enable-disagg", action="store_true", help="Enable disaggregation" @@ -31,4 +33,5 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: args = parser.parse_args(vllm_args) engine_args = AsyncEngineArgs.from_cli_args(args) engine_args.enable_disagg = args.enable_disagg + return engine_args diff --git a/examples/vllm_v1/utils/protocol.py b/examples/vllm_v1/utils/protocol.py index 0d83dda371..4a131c6d7b 100644 --- a/examples/vllm_v1/utils/protocol.py +++ b/examples/vllm_v1/utils/protocol.py @@ -61,6 +61,8 @@ class PreprocessedRequest(BaseModel): eos_token_ids: List[TokenIdType] = Field(default_factory=list) mdc_sum: Optional[str] = None annotations: List[str] = Field(default_factory=list) + estimated_prefix_hit_num_blocks: Optional[int] = None + dp_rank: Optional[int] = None # Hack to override the type of multi_modal_data in TokensPrompt diff --git a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py index 04732e11f5..3cd5318a34 100644 --- a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py +++ b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py @@ -23,7 +23,7 @@ import uvloop from vllm.config import VllmConfig -from vllm.distributed.kv_events import KVEventsConfig +from vllm.distributed.kv_events import KVEventsConfig, ZmqEventPublisher from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import TokensPrompt from vllm.sampling_params 
import SamplingParams @@ -68,7 +68,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase): def __init__(self, component: Component, dp_rank: int) -> None: self.inner = WorkerMetricsPublisher() - self.inner.create_endpoint(component) + self.inner.create_endpoint(component, dp_rank=dp_rank) self.dp_rank = dp_rank def record( @@ -246,12 +246,33 @@ async def init(runtime: DistributedRuntime, config: Config): ) logger.info("VllmWorker has been initialized") + base_zmq_endpoint = "tcp://127.0.0.1:5557" + dp_rank_size = vllm_config.parallel_config.data_parallel_size - zmq_config = ZmqKvEventPublisherConfig( - worker_id=endpoint.lease_id(), kv_block_size=engine_args.block_size - ) + # Store references to prevent garbage collection + kv_publishers = [] + + for dp_rank in range(dp_rank_size): + zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( + base_zmq_endpoint, data_parallel_rank=dp_rank + ) + zmq_config = ZmqKvEventPublisherConfig( + worker_id=endpoint.lease_id(), + kv_block_size=engine_args.block_size, + zmq_endpoint=zmq_endpoint, + ) - _ = ZmqKvEventPublisher(component=component, config=zmq_config) + try: + publisher = ZmqKvEventPublisher(component=component, config=zmq_config) + kv_publishers.append(publisher) + except Exception as e: + logger.error( + f"Failed to create ZmqKvEventPublisher for dp_rank {dp_rank}: {e}" + ) + + logger.debug( + f"Successfully created {len(kv_publishers)} ZmqKvEventPublishers out of {dp_rank_size} expected" + ) handler = RequestHandler(component, engine_client, default_sampling_params) @@ -313,7 +334,7 @@ def cmd_line_args(): endpoint_str = args.endpoint.replace("dyn://", "", 1) endpoint_parts = endpoint_str.split(".") if len(endpoint_parts) != 3: - logging.error( + logger.error( f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'." ) sys.exit(1) diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs index 1c50f4aa8e..742815230d 100644 --- a/lib/bindings/c/src/lib.rs +++ b/lib/bindings/c/src/lib.rs @@ -14,6 +14,7 @@ // limitations under the License. 
use async_once_cell::OnceCell as AsyncOnceCell; +use dynamo_llm::kv_router::publisher::KvCacheEventWithDp; use libc::c_char; use once_cell::sync::OnceCell; use std::ffi::CStr; @@ -284,7 +285,12 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( }; let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size()); - match publisher.publish(event) { + // NOTE: dummy dp_rank for now + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank: None, + }; + match publisher.publish(event_with_dp) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Error publishing stored kv event {:?}", e); @@ -301,7 +307,12 @@ pub extern "C" fn dynamo_kv_event_publish_removed( ) -> DynamoLlmResult { let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks); - match publisher.publish(event) { + // NOTE: dummy dp_rank for now + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank: None, + }; + match publisher.publish(event_with_dp) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Error publishing removed kv event {:?}", e); diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 39cc1ea46e..da04c5b5bf 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -61,6 +61,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 2d7b3d92b5..7f192930e9 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -22,7 +22,25 @@ use rs::traits::events::EventSubscriber; use tracing; use llm_rs::kv_router::protocols::*; -use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig}; +use llm_rs::kv_router::publisher::{create_stored_blocks, KvCacheEventWithDp, KvEventSourceConfig}; + +#[pyclass] +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct WorkerDp { + #[pyo3(get, set)] + pub worker_id: i64, + #[pyo3(get, set)] + pub dp_rank: Option, +} + +impl From for WorkerDp { + fn from(value: llm_rs::kv_router::protocols::WorkerDp) -> Self { + Self { + worker_id: value.worker_id, + dp_rank: value.dp_rank, + } + } +} #[pyclass] pub(crate) struct KvRouter { @@ -57,7 +75,7 @@ impl KvRouter { .schedule(&token_ids, lora_id) .await .map_err(to_pyerr)?; - Ok(worker_id) + Ok(WorkerDp::from(worker_id)) }) } } @@ -78,17 +96,21 @@ impl WorkerMetricsPublisher { }) } - #[pyo3(signature = (component))] + #[pyo3(signature = (component, dp_rank = None))] fn create_endpoint<'p>( &self, py: Python<'p>, component: Component, + dp_rank: Option, ) -> PyResult> { let rs_publisher = self.inner.clone(); let rs_component = component.inner.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { rs_publisher - .create_endpoint(rs_component) + .create_endpoint( + rs_component, + dp_rank.as_ref().map(|v| v.to_string()).as_deref(), + ) .await .map_err(to_pyerr)?; Ok(()) @@ -107,7 +129,7 @@ impl WorkerMetricsPublisher { num_requests_waiting: u64, gpu_cache_usage_perc: f32, gpu_prefix_cache_hit_rate: f32, - data_parallel_rank: u32, + data_parallel_rank: DpRank, ) -> PyResult<()> { self.inner .publish( @@ -218,7 +240,7 @@ impl KvEventPublisher { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, 
lora_id, parent_hash=None))] + #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, dp_rank=None))] fn publish_stored( &mut self, _py: Python, @@ -228,6 +250,7 @@ impl KvEventPublisher { block_hashes: Vec, lora_id: u64, parent_hash: Option, + dp_rank: Option, ) -> PyResult<()> { let event = KvCacheEvent { event_id, @@ -243,11 +266,22 @@ impl KvEventPublisher { ), }), }; + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank, + }; - self.inner.publish(event).map_err(to_pyerr) + self.inner.publish(event_with_dp).map_err(to_pyerr) } - fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec) -> PyResult<()> { + #[pyo3(signature = (event_id, block_hashes, dp_rank=None))] + fn publish_removed( + &self, + _py: Python, + event_id: u64, + block_hashes: Vec, + dp_rank: Option, + ) -> PyResult<()> { let block_hashes: Vec = block_hashes .iter() .map(|&h| ExternalSequenceBlockHash::from(h)) @@ -256,22 +290,30 @@ impl KvEventPublisher { event_id, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), }; + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank, + }; - self.inner.publish(event).map_err(to_pyerr) + self.inner.publish(event_with_dp).map_err(to_pyerr) } } #[pyclass] #[derive(Clone)] pub(crate) struct OverlapScores { - inner: llm_rs::kv_router::indexer::OverlapScores, + inner: llm_rs::kv_router::indexer::OverlapScores, } #[pymethods] impl OverlapScores { #[getter] - fn scores(&self) -> HashMap { - self.inner.scores.clone() + fn scores(&self) -> HashMap { + self.inner + .scores + .iter() + .map(|(k, v)| (WorkerDp::from(*k), *v)) + .collect() } #[getter] @@ -282,7 +324,7 @@ impl OverlapScores { #[pyclass] pub(crate) struct KvIndexer { - inner: Arc, + inner: Arc>, } #[pymethods] @@ -291,12 +333,13 @@ impl KvIndexer { fn new(component: Component, kv_block_size: usize) -> PyResult { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { - let inner: Arc = - llm_rs::kv_router::indexer::KvIndexer::new( - component.inner.drt().runtime().child_token(), - kv_block_size, - ) - .into(); + let inner: Arc< + llm_rs::kv_router::indexer::KvIndexer, + > = llm_rs::kv_router::indexer::KvIndexer::new( + component.inner.drt().runtime().child_token(), + kv_block_size, + ) + .into(); // [gluo TODO] try subscribe_with_type::, // error checking below will be different. let mut kv_events_rx = component @@ -310,8 +353,9 @@ impl KvIndexer { // should have been made to a trait and implemented here? i.e. 
AsyncEngine style tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::indexer::RouterEvent = - serde_json::from_slice(&event.payload).unwrap(); + let event: llm_rs::kv_router::protocols::RouterEvent< + llm_rs::kv_router::protocols::WorkerDp, + > = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("received kv event: {:?}", event); if let Err(e) = kv_events_tx.send(event).await { tracing::trace!( @@ -354,6 +398,8 @@ pub(crate) struct EndpointKvMetrics { #[pyo3(get, set)] pub worker_id: i64, #[pyo3(get, set)] + pub dp_rank: Option, + #[pyo3(get, set)] pub request_active_slots: u64, #[pyo3(get, set)] pub request_total_slots: u64, @@ -407,8 +453,9 @@ impl KvMetricsAggregator { let endpoint_kv_metrics = endpoints .endpoints .iter() - .map(|(worker_id, x)| EndpointKvMetrics { - worker_id: *worker_id, + .map(|(worker_dp, x)| EndpointKvMetrics { + worker_id: worker_dp.worker_id, + dp_rank: worker_dp.dp_rank, request_active_slots: x.data.request_active_slots, request_total_slots: x.data.request_total_slots, kv_active_blocks: x.data.kv_active_blocks, @@ -430,7 +477,7 @@ impl KvMetricsAggregator { #[pyclass] pub(crate) struct KvRecorder { - inner: Arc, + inner: Arc>, } #[pymethods] @@ -481,8 +528,9 @@ impl KvRecorder { // Spawn a task to forward events to the recorder tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::indexer::RouterEvent = - serde_json::from_slice(&event.payload).unwrap(); + let event: llm_rs::kv_router::protocols::RouterEvent< + llm_rs::kv_router::protocols::WorkerDp, + > = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("KvRecorder received kv event: {:?}", event); if let Err(e) = event_tx.send(event).await { tracing::trace!( diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index 424496fe41..73ffe2f19d 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -356,7 +356,7 @@ class WorkerMetricsPublisher: Create a `WorkerMetricsPublisher` object """ - def create_service(self, component: Component) -> None: + def create_endpoint(self, component: Component, dp_rank: int) -> None: """ Similar to Component.create_service, but only service created through this method will interact with KV router of the same component. diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 535a428984..0ab6e84bd5 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -14,6 +14,7 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; +use protocols::WorkerDp; pub mod indexer; pub mod metrics_aggregator; @@ -25,9 +26,11 @@ pub mod scoring; use crate::{ kv_router::{ - indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, + indexer::{KvIndexer, KvIndexerInterface}, metrics_aggregator::KvMetricsAggregator, - protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, + protocols::{ + LocalBlockHash, RouterEvent, RouterRequest, RouterResponse, WorkerSelectionResult, + }, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, }, @@ -51,7 +54,7 @@ pub trait WorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result; + ) -> Result, KvSchedulerError>; } /// KV Router configuration parameters @@ -102,7 +105,7 @@ impl KvRouterConfig { /// A KvRouter only decides which worker you should use. 
It doesn't send you there. /// TODO: Rename this to indicate it only selects a worker, it does not route. pub struct KvRouter { - indexer: KvIndexer, + indexer: KvIndexer, scheduler: KvScheduler, block_size: usize, } @@ -137,7 +140,7 @@ impl KvRouter { tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: RouterEvent = match serde_json::from_slice(&event.payload) { + let event: RouterEvent = match serde_json::from_slice(&event.payload) { Ok(event) => event, Err(e) => { tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); @@ -160,7 +163,7 @@ impl KvRouter { } // [TODO] indexer needs to take 'lora_id' as parameter - pub async fn schedule(&self, token_ids: &Vec, _lora_id: u64) -> Result { + pub async fn schedule(&self, token_ids: &Vec, _lora_id: u64) -> Result { // Extracting part of the code in KvRouter::generate() for only // the decision making part, routing is done by the caller let isl_tokens = token_ids.len(); @@ -175,7 +178,7 @@ impl KvRouter { /// Give these tokens, find the worker with the best match in it's KV cache. /// Returned overlap amount is in number of blocks. - async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> { + async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(WorkerDp, u32)> { let isl_tokens = tokens.len(); let block_size = self.block_size; @@ -202,15 +205,17 @@ impl KvRouter { } #[async_trait] -impl AsyncEngine, ManyOut>, Error> for KvRouter { +impl AsyncEngine, ManyOut>>, Error> + for KvRouter +{ async fn generate( &self, request: SingleIn, - ) -> Result>> { + ) -> Result>>> { let (request, ctx) = request.into_parts(); - let (worker_id, _) = self.find_best_match(&request.tokens).await?; + let (best_match, _) = self.find_best_match(&request.tokens).await?; - let response = RouterResponse { worker_id }; + let response = RouterResponse { worker: best_match }; let response = Annotated::from_data(response); let stream = stream::iter(vec![response]); Ok(ResponseStream::new(Box::pin(stream), ctx.context())) @@ -247,8 +252,11 @@ impl AsyncEngine, ManyOut>; +type SharedRadixBlock = Rc>>; pub fn compute_hash(data: &[u8]) -> u64 { xxh3::xxh3_64_with_seed(data, XXH3_SEED) @@ -133,43 +130,18 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec Self { - Self { worker_id, event } - } -} - /// A block in the Radix Tree. #[derive(Debug)] -struct RadixBlock { +struct RadixBlock { /// A map of child blocks, keyed by their local block hash. - children: HashMap, + children: HashMap>, /// A set of worker IDs associated with this block. - workers: HashSet, + workers: HashSet, /// A buffer of times that this block was last traversed recent_uses: VecDeque, } -impl RadixBlock { +impl RadixBlock { /// Create a new `RadixBlock`. /// /// ### Returns @@ -184,10 +156,10 @@ impl RadixBlock { } } -pub struct RadixTree { +pub struct RadixTree { /// This is the root of the radix/prefix tree /// This will only contain root blocks - root: SharedRadixBlock, + root: SharedRadixBlock, /// This is a global lookup table for all blocks which will let you jump into /// the radix tree at any point @@ -197,18 +169,18 @@ pub struct RadixTree { /// Transitioning to a radix tree only would require a change in the messaging structure /// as the entire prefix would need to be sent. Alternatively, we could use block_depth /// integers to indicate how many blocks to skip and use a radix/prefix tree at each level. 
- lookup: HashMap>, + lookup: HashMap>>, /// The time buffer the radix tree should check when considering frequence of block accesses expiration_duration: Option, } -impl Default for RadixTree { +impl Default for RadixTree { fn default() -> Self { Self::new() } } -impl RadixTree { +impl RadixTree { /// Create a new `RadixTree`. /// /// ### Returns @@ -236,7 +208,11 @@ impl RadixTree { /// ### Returns /// /// An `OverlapScores` representing the match scores. - pub fn find_matches(&self, sequence: Vec, early_exit: bool) -> OverlapScores { + pub fn find_matches( + &self, + sequence: Vec, + early_exit: bool, + ) -> OverlapScores { let mut scores = OverlapScores::new(); let mut current = self.root.clone(); let now = Instant::now(); @@ -280,12 +256,12 @@ impl RadixTree { /// ### Arguments /// /// * `event` - The `RouterEvent` to apply. - pub fn apply_event(&mut self, event: RouterEvent) { - let (worker_id, event) = (event.worker_id, event.event); + pub fn apply_event(&mut self, event: RouterEvent) { + let (worker_id, event) = (event.worker, event.event); let (id, op) = (event.event_id, event.data); - tracing::trace!(id, "Store operation: {:?}", op); + tracing::trace!(worker_id = ?worker_id, id=?id, "Store operation: {:?}", op); - let worker_lookup = self.lookup.entry(worker_id).or_default(); + let worker_lookup = self.lookup.entry(worker_id.clone()).or_default(); match op { KvCacheEventData::Stored(op) => { @@ -301,8 +277,8 @@ impl RadixTree { Some(current) => current.clone(), None => { tracing::warn!( - worker_id = worker_id.to_string(), - id, + worker_id = ?worker_id, + id = ?id, parent_hash = ?op.parent_hash, "Failed to find parent block; skipping store operation" ); @@ -331,7 +307,7 @@ impl RadixTree { }; // add our worker_id to the block - block.borrow_mut().workers.insert(worker_id); + block.borrow_mut().workers.insert(worker_id.clone()); // add the block to the worker_id lookup table worker_lookup.insert(block_id.block_hash, block.clone()); @@ -355,8 +331,8 @@ impl RadixTree { Some(entry) => entry.clone(), None => { tracing::warn!( - worker_id = worker_id.to_string(), - id, + worker_id = ?worker_id, + id = ?id, "Failed to find block to remove; skipping remove operation" ); continue; @@ -379,7 +355,7 @@ impl RadixTree { } } - pub fn remove_worker(&mut self, worker: WorkerId) { + pub fn remove_worker(&mut self, worker: T) { if let Some((_, blocks)) = self.lookup.remove_entry(&worker) { blocks.iter().for_each(|(_, block)| { block.borrow_mut().workers.remove(&worker); @@ -387,7 +363,7 @@ impl RadixTree { } } - pub fn clear_all_blocks(&mut self, worker: WorkerId) { + pub fn clear_all_blocks(&mut self, worker: T) { // Check if the worker has any blocks to clear if let Some(blocks) = self.lookup.get(&worker) { let blocks_to_clear: Vec<_> = blocks.values().collect(); @@ -407,20 +383,20 @@ impl RadixTree { /// Scores representing the overlap of workers. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OverlapScores { +pub struct OverlapScores { // map of worker_id to score - pub scores: HashMap, + pub scores: HashMap, // List of frequencies that the blocks have been accessed. Entries with value 0 are omitted. pub frequencies: Vec, } -impl Default for OverlapScores { +impl Default for OverlapScores { fn default() -> Self { Self::new() } } -impl OverlapScores { +impl OverlapScores { /// Create a new `OverlapScores`. /// /// ### Returns @@ -437,10 +413,10 @@ impl OverlapScores { /// /// ### Arguments /// - /// * `workers` - A reference to a `HashSet` of `WorkerId`s. 
- pub fn update_scores(&mut self, workers: &HashSet) { + /// * `workers` - A reference to a `HashSet` of worker IDs. + pub fn update_scores(&mut self, workers: &HashSet) { for worker in workers { - let score = self.scores.entry(*worker).or_insert(0); + let score = self.scores.entry(worker.clone()).or_insert(0); *score += 1; } } @@ -457,17 +433,17 @@ impl OverlapScores { } /// A request to find matches in the Radix Tree. -pub struct MatchRequest { +pub struct MatchRequest { /// A vector of `LocalBlockHash` representing the sequence to match. sequence: Vec, /// A boolean indicating whether to exit early if a single match is found. early_exit: bool, /// A channel sender to send the `OverlapScores` response. - resp: oneshot::Sender, + resp: oneshot::Sender>, } #[async_trait] -pub trait KvIndexerInterface { +pub trait KvIndexerInterface { /// Find matches for a given sequence of `LocalBlockHash`es. /// /// ### Arguments @@ -480,7 +456,7 @@ pub trait KvIndexerInterface { async fn find_matches( &self, sequence: Vec, - ) -> Result; + ) -> Result, KvRouterError>; /// Find matches for a given sequence of tokens. /// @@ -494,43 +470,43 @@ pub trait KvIndexerInterface { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result; + ) -> Result, KvRouterError>; /// Apply a `RouterEvent` to the KV store. /// /// ### Arguments /// /// * `event` - The `RouterEvent` to apply. - async fn apply_event(&mut self, event: RouterEvent); + async fn apply_event(&mut self, event: RouterEvent); /// Remove a worker's entries from the trie. /// /// ### Arguments /// /// * `worker` - The worker to remove from the trie. - async fn remove_worker(&mut self, worker: WorkerId); + async fn remove_worker(&mut self, worker: T); /// Shutdown the KV Indexer. fn shutdown(&mut self); } /// The KV Indexer, managing the KV store and handling events and match requests. -pub struct KvIndexer { +pub struct KvIndexer { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// A sender for `RouterEvent`s. - event_tx: mpsc::Sender, + event_tx: mpsc::Sender>, /// A sender for `MatchRequest`s. - match_tx: mpsc::Sender, + match_tx: mpsc::Sender>, /// A sender for remove worker requests. - remove_worker_tx: mpsc::Sender, + remove_worker_tx: mpsc::Sender, /// A handle to the background task managing the KV store. task: OnceLock>, /// The size of the KV block this indexer can handle. kv_block_size: usize, } -impl KvIndexer { +impl KvIndexer { /// Create a new `KvIndexer`. /// /// ### Arguments @@ -546,9 +522,9 @@ impl KvIndexer { expiration_duration: Option, kv_block_size: usize, ) -> Self { - let (event_tx, event_rx) = mpsc::channel::(2048); - let (match_tx, match_rx) = mpsc::channel::(128); - let (remove_worker_tx, remove_worker_rx) = mpsc::channel::(16); + let (event_tx, event_rx) = mpsc::channel::>(2048); + let (match_tx, match_rx) = mpsc::channel::>(128); + let (remove_worker_tx, remove_worker_rx) = mpsc::channel::(16); let cancel_clone = token.clone(); let task = std::thread::spawn(move || { // create a new tokio runtime which will only perform work on a single thread @@ -624,17 +600,17 @@ impl KvIndexer { /// ### Returns /// /// A `mpsc::Sender` for `RouterEvent`s. 
- pub fn event_sender(&self) -> mpsc::Sender { + pub fn event_sender(&self) -> mpsc::Sender> { self.event_tx.clone() } } #[async_trait] -impl KvIndexerInterface for KvIndexer { +impl KvIndexerInterface for KvIndexer { async fn find_matches( &self, sequence: Vec, - ) -> Result { + ) -> Result, KvRouterError> { let (resp_tx, resp_rx) = oneshot::channel(); let req = MatchRequest { sequence, @@ -658,7 +634,7 @@ impl KvIndexerInterface for KvIndexer { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result { + ) -> Result, KvRouterError> { tracing::debug!( "Finding matches for request tokens: {:?} / len: {}", tokens, @@ -669,11 +645,11 @@ impl KvIndexerInterface for KvIndexer { self.find_matches(sequence).await } - async fn apply_event(&mut self, event: RouterEvent) { + async fn apply_event(&mut self, event: RouterEvent) { self.event_tx.send(event).await.unwrap(); } - async fn remove_worker(&mut self, worker: WorkerId) { + async fn remove_worker(&mut self, worker: T) { self.remove_worker_tx.send(worker).await.unwrap(); } @@ -686,28 +662,28 @@ impl KvIndexerInterface for KvIndexer { } #[derive(Debug, Clone)] -pub struct ShardedMatchRequest { +pub struct ShardedMatchRequest { sequence: Vec, early_exit: bool, - resp: mpsc::Sender, + resp: mpsc::Sender>, } /// The KV Indexer, managing the KV store and handling events and match requests. -pub struct KvIndexerSharded { +pub struct KvIndexerSharded { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// The size of the KV block this indexer can handle. kv_block_size: usize, - worker_assignments: HashMap, + worker_assignments: HashMap, worker_counts: Vec, - event_tx: Vec>, - request_broadcast_tx: broadcast::Sender, - remove_worker_tx: Vec>, + event_tx: Vec>>, + request_broadcast_tx: broadcast::Sender>, + remove_worker_tx: Vec>, tasks: Vec>, } -impl KvIndexerSharded { +impl KvIndexerSharded { /// Create a new `KvIndexerSharded`. 
/// /// ### Arguments @@ -725,19 +701,18 @@ impl KvIndexerSharded { expiration_duration: Option, kv_block_size: usize, ) -> Self { - let worker_assignments: HashMap = HashMap::new(); + let worker_assignments: HashMap = HashMap::new(); let worker_counts: Vec = vec![0; num_shards]; let mut event_tx = Vec::new(); let mut remove_worker_tx = Vec::new(); let mut tasks = Vec::new(); - let (request_broadcast_tx, _) = broadcast::channel::(1048576); + let (request_broadcast_tx, _) = broadcast::channel::>(1048576); for _ in 0..num_shards { - let (shard_event_tx, mut shard_event_rx) = mpsc::channel::(2048); - let (shard_remove_worker_tx, mut shard_remove_worker_rx) = - mpsc::channel::(16); + let (shard_event_tx, mut shard_event_rx) = mpsc::channel::>(2048); + let (shard_remove_worker_tx, mut shard_remove_worker_rx) = mpsc::channel::(16); let mut shard_broadcast_rx = request_broadcast_tx.subscribe(); let cancel = token.clone(); @@ -812,11 +787,11 @@ impl KvIndexerSharded { } #[async_trait] -impl KvIndexerInterface for KvIndexerSharded { +impl KvIndexerInterface for KvIndexerSharded { async fn find_matches( &self, sequence: Vec, - ) -> Result { + ) -> Result, KvRouterError> { 'match_loop: loop { let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len()); self.request_broadcast_tx @@ -863,14 +838,14 @@ impl KvIndexerInterface for KvIndexerSharded { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result { + ) -> Result, KvRouterError> { let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size); self.find_matches(sequence).await } - async fn apply_event(&mut self, event: RouterEvent) { + async fn apply_event(&mut self, event: RouterEvent) { #[allow(clippy::map_entry)] - if !self.worker_assignments.contains_key(&event.worker_id) { + if !self.worker_assignments.contains_key(&event.worker) { // Get the shard with the smallest amount of workers. 
let selected_shard = self .worker_counts @@ -881,17 +856,17 @@ impl KvIndexerInterface for KvIndexerSharded { .0; self.worker_assignments - .insert(event.worker_id, selected_shard); + .insert(event.worker.clone(), selected_shard); self.worker_counts[selected_shard] += 1; } - self.event_tx[self.worker_assignments[&event.worker_id]] + self.event_tx[self.worker_assignments[&event.worker]] .send(event) .await .unwrap(); } - async fn remove_worker(&mut self, worker: WorkerId) { + async fn remove_worker(&mut self, worker: T) { if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) { self.worker_counts[shard] -= 1; self.remove_worker_tx[shard].send(worker).await.unwrap(); @@ -909,13 +884,15 @@ impl KvIndexerInterface for KvIndexerSharded { #[cfg(test)] mod tests { - use super::*; use rstest::rstest; use rstest_reuse::{self, *}; use tokio::time; use tokio_util::sync::CancellationToken; + // Use u64 as a simple WorkerIdTrait implementation for tests + type TestWorkerId = u64; + fn setup() { dynamo_runtime::logging::init(); } @@ -941,13 +918,13 @@ mod tests { } fn create_store_event( - worker_id: WorkerId, + worker_id: TestWorkerId, event_id: u64, hashes: Vec, parent: Option, - ) -> RouterEvent { + ) -> RouterEvent { RouterEvent { - worker_id, + worker: worker_id, event: KvCacheEvent { event_id, data: add_blocks(hashes, parent), @@ -955,9 +932,13 @@ mod tests { } } - fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec) -> RouterEvent { + fn create_remove_event( + worker_id: TestWorkerId, + event_id: u64, + hashes: Vec, + ) -> RouterEvent { RouterEvent { - worker_id, + worker: worker_id, event: KvCacheEvent { event_id, data: KvCacheEventData::Removed(KvCacheRemoveData { @@ -1338,7 +1319,7 @@ mod tests { token: &CancellationToken, num_shards: usize, kv_block_size: usize, - ) -> Box { + ) -> Box> { if num_shards == 1 { Box::new(KvIndexer::new(token.clone(), kv_block_size)) } else { @@ -1423,7 +1404,7 @@ mod tests { const ONE_MILLIS: Duration = Duration::from_millis(1); setup(); - let mut kv_indexer: Box; + let mut kv_indexer: Box>; let token = CancellationToken::new(); let expiration = Duration::from_millis(50); @@ -1534,7 +1515,7 @@ mod tests { }; let router_event = RouterEvent::new(worker_id, kv_cache_event); - assert_eq!(router_event.worker_id, worker_id); + assert_eq!(router_event.worker, worker_id); assert_eq!(router_event.event.event_id, 1); if let KvCacheEventData::Stored(store_op) = &router_event.event.data { assert_eq!(store_op.blocks.len(), 1); @@ -1551,7 +1532,7 @@ mod tests { #[test] fn test_radix_tree_default() { setup(); - let radix_tree: RadixTree = Default::default(); + let radix_tree: RadixTree = Default::default(); assert!(radix_tree.root.borrow().children.is_empty()); assert!(radix_tree.root.borrow().workers.is_empty()); assert!(radix_tree.lookup.is_empty()); @@ -1560,7 +1541,7 @@ mod tests { #[test] fn test_overlap_scores_default() { setup(); - let overlap_scores: OverlapScores = Default::default(); + let overlap_scores: OverlapScores = Default::default(); assert!(overlap_scores.scores.is_empty()); } } diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 156d1dfb02..fc2e451314 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -18,8 +18,7 @@ use std::sync::Once; pub use crate::kv_router::protocols::ForwardPassMetrics; use crate::kv_router::KV_METRICS_ENDPOINT; -use crate::kv_router::scheduler::Endpoint; -use 
crate::kv_router::ProcessedEndpoints; +use crate::kv_router::scoring::{Endpoint, ProcessedEndpoints}; use dynamo_runtime::component::Component; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; use tokio::sync::watch; @@ -134,11 +133,16 @@ pub async fn collect_endpoints_task( .collect(); tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len()); - let processed = ProcessedEndpoints::new(endpoints); + // Only create and send ProcessedEndpoints if we have valid endpoints + if !endpoints.is_empty() { + let processed = ProcessedEndpoints::new(endpoints); - if watch_tx.send(processed).is_err() { - tracing::trace!("failed to send processed endpoints; shutting down"); - break; + if watch_tx.send(processed).is_err() { + tracing::trace!("failed to send processed endpoints; shutting down"); + break; + } + } else { + tracing::trace!("No valid endpoints found, skipping ProcessedEndpoints creation"); } } } diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 8131a54f72..730be4929d 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -14,7 +14,40 @@ // limitations under the License. use crate::tokens::Token; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use std::cmp::Eq; +use std::fmt::Debug; +use std::hash::Hash; + +pub type WorkerId = i64; +pub type DpRank = u32; + +#[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize, Default)] +pub struct WorkerDp { + pub worker_id: WorkerId, + pub dp_rank: Option, +} + +impl std::fmt::Display for WorkerDp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.dp_rank { + Some(dp_rank) => write!(f, "{}_{}", self.worker_id, dp_rank), + None => write!(f, "{}", self.worker_id), + } + } +} + +// Cannot add DeserializedOwned otherwise compiler will complain +pub trait WorkerGeneral: + Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize +{ +} + +impl WorkerGeneral for T where + T: Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize + DeserializeOwned +{ +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterRequest { @@ -22,14 +55,14 @@ pub struct RouterRequest { } #[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct RouterResponse { - pub worker_id: i64, +pub struct RouterResponse { + pub worker: T, } #[derive(Debug)] -pub struct WorkerSelectionResult { +pub struct WorkerSelectionResult { /// The worker id of the selected worker - pub worker_id: i64, + pub worker: T, /// The total number of blocks required to prefill the request pub required_blocks: u64, @@ -58,14 +91,14 @@ pub struct ForwardPassMetrics { /// A [`LocalBlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional /// lora_id of a block. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct LocalBlockHash(pub u64); /// A sequence aware hash of a block where the hash is computed from the tokens_ids, extra_token_ids /// and the optional lora_id of a block, PLUS the hash of the parent block. /// /// In this case, the hashing function is external and unknown. 
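To make the new composite routing identifier concrete, here is a minimal, self-contained sketch of how `WorkerDp` renders: with a `dp_rank` it prints `worker_id_dp_rank`, and without one it prints just the worker id. The struct below is a trimmed copy of the one added in `protocols.rs` (derives reduced to what the example needs); the values are made up.

```rust
// Trimmed copy of the WorkerDp identifier added in protocols.rs,
// kept only to demonstrate the Display formatting.
#[derive(Debug, Clone, Copy)]
struct WorkerDp {
    worker_id: i64,
    dp_rank: Option<u32>,
}

impl std::fmt::Display for WorkerDp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.dp_rank {
            Some(dp_rank) => write!(f, "{}_{}", self.worker_id, dp_rank),
            None => write!(f, "{}", self.worker_id),
        }
    }
}

fn main() {
    let with_dp = WorkerDp { worker_id: 42, dp_rank: Some(3) };
    let without_dp = WorkerDp { worker_id: 42, dp_rank: None };
    assert_eq!(with_dp.to_string(), "42_3"); // worker 42, dp rank 3
    assert_eq!(without_dp.to_string(), "42"); // dp rank omitted
}
```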
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ExternalSequenceBlockHash(pub u64); // Implement From trait for convenient conversion @@ -137,6 +170,38 @@ pub struct KvCacheRemoveData { pub block_hashes: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KVHitRateEvent { + pub worker: T, + pub isl_blocks: usize, + pub overlap_blocks: usize, +} + +/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`]. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouterEvent { + /// The ID of the worker emitting the event. + pub worker: T, + /// The cache event associated with the worker. + pub event: KvCacheEvent, +} + +impl RouterEvent { + /// Create a new `RouterEvent`. + /// + /// ### Arguments + /// + /// * `worker_id` - The ID of the worker emitting the event. + /// * `event` - The cache event. + /// + /// ### Returns + /// + /// A new `RouterEvent`. + pub fn new(worker: T, event: KvCacheEvent) -> Self { + Self { worker, event } + } +} + impl Serialize for LocalBlockHash { fn serialize(&self, serializer: S) -> Result where diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index d4bf56e0d8..f0b69dc970 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -14,9 +14,7 @@ // limitations under the License. use crate::kv_router::{ - indexer::{compute_block_hash_for_seq, RouterEvent}, - protocols::*, - KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, + indexer::compute_block_hash_for_seq, protocols::*, KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, }; use async_trait::async_trait; use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider}; @@ -45,6 +43,13 @@ use zeromq::{Socket, SocketRecv, SubSocket}; // KV Event Publishers ----------------------------------------------------- // ------------------------------------------------------------------------- +/// Represents a single cache event with an ID and associated data. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct KvCacheEventWithDp { + pub kv_cache_event: KvCacheEvent, + pub dp_rank: Option, +} + /// Configure the source of KV events. /// Currently, only ZMQ is supported. pub enum KvEventSourceConfig { @@ -65,7 +70,7 @@ impl KvEventSource { kv_block_size: usize, source_config: KvEventSourceConfig, cancellation_token: CancellationToken, - tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, ) -> Result { match source_config { KvEventSourceConfig::Zmq { endpoint, topic } => { @@ -97,7 +102,6 @@ impl KvEventSource { /// A publisher of KV events. pub struct KvEventPublisher { - /// The size of the KV block. kv_block_size: usize, /// The source of KV events. /// Can be `None` if all events provided through [`KvEventPublisher::publish`]. @@ -105,19 +109,19 @@ pub struct KvEventPublisher { /// The cancellation token. cancellation_token: CancellationToken, /// The channel to send events to. 
- tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, } impl KvEventPublisher { pub fn new( component: Component, - worker_id: i64, + worker_id: WorkerId, kv_block_size: usize, source_config: Option, ) -> Result { let cancellation_token = CancellationToken::new(); - let (tx, rx) = mpsc::unbounded_channel::(); + let (tx, rx) = mpsc::unbounded_channel::(); // Create our event source (if any) let mut source = None; @@ -150,7 +154,10 @@ impl KvEventPublisher { }) } - pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError> { + pub fn publish( + &self, + event: KvCacheEventWithDp, + ) -> Result<(), mpsc::error::SendError> { tracing::trace!("Publish event: {:?}", event); self.tx.send(event) } @@ -178,30 +185,36 @@ impl Drop for KvEventPublisher { async fn start_event_processor( publisher: P, - worker_id: i64, + worker_id: WorkerId, cancellation_token: CancellationToken, - mut rx: mpsc::UnboundedReceiver, + mut rx: mpsc::UnboundedReceiver, ) { + tracing::debug!("KV Event processor starting for worker_id: {}", worker_id); + loop { tokio::select! { _ = cancellation_token.cancelled() => { - tracing::info!("KV Event source received cancellation signal"); + tracing::debug!("KV Event processor received cancellation signal for worker_id: {}", worker_id); break; } - event = rx.recv() => { - let Some(event) = event else { - tracing::debug!("Event processor channel closed."); + maybe_data = rx.recv() => { + let Some(data) = maybe_data else { + tracing::debug!("KV Event processor channel closed for worker_id: {}", worker_id); break; }; // Encapsulate in a router event and publish. - let router_event = RouterEvent::new(worker_id, event); + let event = data.kv_cache_event; + let dp_rank = data.dp_rank.unwrap_or(0); + + let router_event = RouterEvent::new((worker_id, dp_rank), event); if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await { - tracing::error!("Failed to publish event: {}", e); + tracing::error!("Failed to publish event for worker_id: {}, dp_rank: {}, error: {}", worker_id, dp_rank, e); } } } } + tracing::debug!("KV Event processor exiting for worker_id: {}", worker_id); } // Error handling configuration for ZMQ operations @@ -221,12 +234,12 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 { async fn start_zmq_listener( zmq_endpoint: String, zmq_topic: String, - tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, cancellation_token: CancellationToken, kv_block_size: usize, ) { tracing::debug!( - "KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')", + "ZMQ listener starting - connecting to endpoint: {}, topic: '{}'", zmq_endpoint, zmq_topic ); @@ -237,15 +250,25 @@ async fn start_zmq_listener( // Subscribe to the requested topic (empty string == all topics) if let Err(e) = socket.subscribe(&zmq_topic).await { - tracing::error!("Failed to subscribe on ZMQ socket: {}", e); + tracing::error!( + "Failed to subscribe on ZMQ socket for {}: {}", + zmq_endpoint, + e + ); return; } if let Err(e) = socket.connect(&zmq_endpoint).await { - tracing::error!("Failed to connect ZMQ SUB socket: {}", e); + tracing::error!( + "Failed to connect ZMQ SUB socket to {}: {}", + zmq_endpoint, + e + ); return; } + tracing::debug!("ZMQ listener successfully connected to {}", zmq_endpoint); + let mut consecutive_errors = 0u32; loop { @@ -254,7 +277,7 @@ async fn start_zmq_listener( // Check for cancellation _ = cancellation_token.cancelled() => { - tracing::info!("ZMQ listener received cancellation signal"); + tracing::debug!("ZMQ listener received 
cancellation signal for {}", zmq_endpoint); break; } @@ -268,6 +291,7 @@ async fn start_zmq_listener( tracing::error!( error=%e, consecutive_errors=%consecutive_errors, + endpoint=%zmq_endpoint, "Too many consecutive ZMQ errors, terminating listener" ); break; @@ -280,6 +304,7 @@ async fn start_zmq_listener( error=%e, consecutive_errors=%consecutive_errors, backoff_ms=%backoff_ms, + endpoint=%zmq_endpoint, "Error reading from ZMQ socket, applying exponential backoff" ); @@ -293,7 +318,7 @@ async fn start_zmq_listener( let mut frames: Vec> = msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect(); if frames.len() != 3 { - tracing::warn!(expected=3, actual=%frames.len(), "Received unexpected ZMQ frame count"); + tracing::warn!(expected=3, actual=%frames.len(), endpoint=%zmq_endpoint, "Received unexpected ZMQ frame count"); continue; } @@ -302,7 +327,7 @@ async fn start_zmq_listener( let seq_bytes = frames.pop().unwrap(); if seq_bytes.len() != 8 { - tracing::warn!(expected=8, actual=%seq_bytes.len(), "Invalid sequence number byte length"); + tracing::warn!(expected=8, actual=%seq_bytes.len(), endpoint=%zmq_endpoint, "Invalid sequence number byte length"); continue; } @@ -312,22 +337,25 @@ async fn start_zmq_listener( let batch_result = rmps::from_slice::(&payload); let Ok(batch) = batch_result else { let e = batch_result.unwrap_err(); - tracing::warn!(error=%e, "Failed to decode KVEventBatch msgpack"); + tracing::warn!(error=%e, endpoint=%zmq_endpoint, "Failed to decode KVEventBatch msgpack"); continue; }; + tracing::trace!("ZMQ listener decoded batch with {} events, dp_rank: {:?} from {}", batch.events.len(), batch.data_parallel_rank, zmq_endpoint); + // For each of our events, convert them to [`KvCacheEvent`] and send to the event_processor. 
+ let dp_rank = batch.data_parallel_rank; for raw_event in batch.events.into_iter() { - let event = convert_event(raw_event, seq, kv_block_size, &warning_count); - if tx.send(event).is_err() { - tracing::warn!("Failed to send message to channel - receiver dropped"); + let kv_cache_event = convert_event(raw_event, seq, kv_block_size, &warning_count); + if tx.send(KvCacheEventWithDp { kv_cache_event, dp_rank }).is_err() { + tracing::warn!("Failed to send message to channel - receiver dropped for {}", zmq_endpoint); return; } } } } - tracing::debug!("ZMQ listener exiting"); } + tracing::debug!("ZMQ listener exiting for {}", zmq_endpoint); } /// Convert a raw event coming from the ZMQ channel into the internal @@ -438,6 +466,8 @@ pub fn create_stored_blocks( struct KvEventBatch { ts: f64, events: Vec, + #[serde(alias = "dp_rank")] + data_parallel_rank: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -479,13 +509,18 @@ impl WorkerMetricsPublisher { self.tx.send(metrics) } - pub async fn create_endpoint(&self, component: Component) -> Result<()> { + pub async fn create_endpoint(&self, component: Component, suffix: Option<&str>) -> Result<()> { let mut metrics_rx = self.rx.clone(); let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; + let endpoint_name = match suffix { + Some(s) => format!("{}-{}", KV_METRICS_ENDPOINT, s), + None => KV_METRICS_ENDPOINT.to_string(), + }; + component - .endpoint(KV_METRICS_ENDPOINT) + .endpoint(&endpoint_name) .endpoint_builder() .stats_handler(move |_| { let metrics = metrics_rx.borrow_and_update().clone(); @@ -705,15 +740,20 @@ mod tests_startup_helpers { async fn test_start_event_processor() { let (component, published) = MockComponent::new(); - let event = KvCacheEvent { + let kv_cache_event = KvCacheEvent { event_id: 1, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)], }), }; + let event = KvCacheEventWithDp { + kv_cache_event, + dp_rank: None, + }; + let token = CancellationToken::new(); - let (tx, rx) = mpsc::unbounded_channel::(); + let (tx, rx) = mpsc::unbounded_channel::(); tx.send(event).unwrap(); drop(tx); @@ -737,7 +777,7 @@ mod tests_startup_helpers { #[tokio::test] async fn test_start_zmq_listener_pushes_to_channel() { // Prepare channel that listener should fill - let (tx, mut rx) = mpsc::unbounded_channel::(); + let (tx, mut rx) = mpsc::unbounded_channel::(); // ZMQ TCP endpoint using localhost with fixed port let endpoint = "tcp://127.0.0.1:15555"; @@ -770,7 +810,11 @@ mod tests_startup_helpers { lora_id: None, }]; - let batch = KvEventBatch { ts: 0.0, events }; + let batch = KvEventBatch { + ts: 0.0, + events, + data_parallel_rank: None, + }; let payload = Bytes::from(rmps::to_vec(&batch).unwrap()); @@ -795,7 +839,7 @@ mod tests_startup_helpers { let KvCacheEventData::Stored(KvCacheStoreData { parent_hash, blocks, - }) = event.data + }) = event.kv_cache_event.data else { panic!("expected KvCacheStoreData"); }; diff --git a/lib/llm/src/kv_router/recorder.rs b/lib/llm/src/kv_router/recorder.rs index 17c66c7925..40cdbffcbe 100644 --- a/lib/llm/src/kv_router/recorder.rs +++ b/lib/llm/src/kv_router/recorder.rs @@ -13,23 +13,24 @@ // See the License for the specific language governing permissions and // limitations under the License. 
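The `#[serde(alias = "dp_rank")]` attribute added to `KvEventBatch` means the listener accepts either key name from the engine. A small sketch of just that behavior follows; it uses `serde_json` purely for readability (the real batches arrive as msgpack via `rmps`), and `BatchHeader` is a made-up stand-in carrying only the relevant field.

```rust
// Sketch of the serde alias added to KvEventBatch: both the long and the
// short key deserialize into the same field. Assumes serde + serde_json.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct BatchHeader {
    #[serde(alias = "dp_rank")]
    data_parallel_rank: Option<u32>,
}

fn main() {
    let long: BatchHeader =
        serde_json::from_str(r#"{"data_parallel_rank": 2}"#).unwrap();
    let short: BatchHeader =
        serde_json::from_str(r#"{"dp_rank": 2}"#).unwrap();
    assert_eq!(long.data_parallel_rank, Some(2));
    assert_eq!(short.data_parallel_rank, Some(2));
}
```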
-use crate::kv_router::indexer::RouterEvent; +use crate::kv_router::protocols::*; use crate::recorder::Recorder; -// Type alias for backward compatibility -pub type KvRecorder = Recorder; +// Type alias for backward compatibility, now generic +pub type KvRecorder = Recorder>; #[cfg(test)] mod tests { use super::*; use crate::kv_router::indexer::KvIndexer; - use crate::kv_router::indexer::WorkerId; - use crate::kv_router::protocols::*; use std::time::Duration; use tempfile::tempdir; use tokio::fs; use tokio_util::sync::CancellationToken; + // Use i64 for tests + type TestWorkerId = i64; + fn make_blocks(hashes: Vec) -> Vec { hashes .iter() @@ -51,11 +52,11 @@ mod tests { } fn create_store_event( - worker_id: WorkerId, + worker_id: TestWorkerId, event_id: u64, hashes: Vec, parent: Option, - ) -> RouterEvent { + ) -> RouterEvent { RouterEvent::new( worker_id, KvCacheEvent { @@ -65,7 +66,11 @@ mod tests { ) } - fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec) -> RouterEvent { + fn create_remove_event( + worker_id: TestWorkerId, + event_id: u64, + hashes: Vec, + ) -> RouterEvent { RouterEvent::new( worker_id, KvCacheEvent { @@ -88,7 +93,7 @@ mod tests { // Part 1: Record events to a file let token = CancellationToken::new(); - let recorder = KvRecorder::new(token.clone(), &file_path, None, None, None) + let recorder = KvRecorder::::new(token.clone(), &file_path, None, None, None) .await .unwrap(); let event_tx = recorder.event_sender(); @@ -128,13 +133,19 @@ mod tests { // Part 2: Now create a KvIndexer and load the events from the file let indexer_token = CancellationToken::new(); let kv_block_size = 32; // Default block size for testing - let indexer = KvIndexer::new(indexer_token.clone(), kv_block_size); + let indexer = KvIndexer::::new(indexer_token.clone(), kv_block_size); let indexer_event_tx = indexer.event_sender(); // Use the send_events method to load events from file to indexer - let count = KvRecorder::send_events(&file_path, &indexer_event_tx, false, None, None) - .await - .unwrap(); + let count = KvRecorder::::send_events( + &file_path, + &indexer_event_tx, + false, + None, + None, + ) + .await + .unwrap(); assert_eq!(count, 2, "Expected to send 2 events from file to indexer"); } } diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index edf85d3198..f9afafc555 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -16,11 +16,9 @@ use dynamo_runtime::component::Namespace; use dynamo_runtime::traits::events::EventPublisher; use rand::Rng; -use serde::{Deserialize, Serialize}; use std::borrow::BorrowMut; use std::collections::HashMap; -use super::protocols::WorkerSelectionResult; use super::WorkerSelector; use crate::kv_router::indexer::OverlapScores; pub use crate::kv_router::protocols::ForwardPassMetrics; @@ -28,12 +26,7 @@ use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::KvRouterConfig; use crate::kv_router::KV_HIT_RATE_SUBJECT; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct KVHitRateEvent { - pub worker_id: i64, - pub isl_blocks: usize, - pub overlap_blocks: usize, -} +use super::protocols::{KVHitRateEvent, WorkerDp, WorkerSelectionResult}; #[derive(Debug, thiserror::Error)] pub enum KvSchedulerError { @@ -47,39 +40,15 @@ pub enum KvSchedulerError { SubscriberShutdown, } -/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' -/// is cleaned (not optional) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Endpoint { - pub 
name: String, - pub subject: String, - pub data: ForwardPassMetrics, -} - -impl Endpoint { - pub fn worker_id(&self) -> i64 { - i64::from_str_radix( - self.subject - .split("-") - .last() - .expect("invalid subject") - .to_string() - .as_str(), - 16, - ) - .expect("invalid worker id") - } -} - pub struct SchedulingRequest { pub isl_tokens: usize, - pub overlap: OverlapScores, - resp_tx: tokio::sync::oneshot::Sender, + pub overlap: OverlapScores, + resp_tx: tokio::sync::oneshot::Sender, } impl SchedulingRequest { - pub fn respond(self, worker_id: i64) { - if self.resp_tx.send(worker_id).is_err() { + pub fn respond(self, identifier: WorkerDp) { + if self.resp_tx.send(identifier).is_err() { tracing::trace!("failed to send response to requestor"); } } @@ -100,7 +69,8 @@ impl KvScheduler { let mut endpoints_rx = endpoints_rx; let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone(); - let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (event_tx, event_rx) = + tokio::sync::mpsc::unbounded_channel::>(); tokio::spawn(async move { let mut event_rx = event_rx; while let Some(event) = event_rx.recv().await { @@ -178,9 +148,9 @@ impl KvScheduler { pub async fn schedule( &self, - overlap: OverlapScores, + overlap: OverlapScores, isl_tokens: usize, - ) -> Result { + ) -> Result { let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { isl_tokens, @@ -201,12 +171,12 @@ impl KvScheduler { // This becomes the driver function that handles the selection result pub fn process_worker_selection( workers: &mut ProcessedEndpoints, - selection: WorkerSelectionResult, - event_tx: &tokio::sync::mpsc::UnboundedSender, -) -> i64 { + selection: WorkerSelectionResult, + event_tx: &tokio::sync::mpsc::UnboundedSender>, +) -> WorkerDp { let worker = workers .endpoints - .get_mut(&selection.worker_id) + .get_mut(&selection.worker) .expect("worker not found"); // Update worker state predictively @@ -220,14 +190,14 @@ pub fn process_worker_selection( // Emit event if let Err(e) = event_tx.send(KVHitRateEvent { - worker_id: selection.worker_id, + worker: selection.worker, isl_blocks: selection.required_blocks as usize, overlap_blocks: selection.overlap_blocks, }) { tracing::warn!("Failed to send KV hit rate event: {:?}", e); } - selection.worker_id + selection.worker } // Default implementation matching the Python _cost_function @@ -250,7 +220,7 @@ impl WorkerSelector for DefaultWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result { + ) -> Result, KvSchedulerError> { assert!(request.isl_tokens > 0); if workers.endpoints.is_empty() { @@ -261,13 +231,11 @@ impl WorkerSelector for DefaultWorkerSelector { let mut max_waiting = 0.0; // Calculate worker scores and find max waiting requests - for (worker_id, ep) in workers.endpoints.iter() { - // Calculate score similar to Python version - if let Some(score) = request.overlap.scores.get(worker_id) { + for (worker_dp, ep) in workers.endpoints.iter() { + if let Some(score) = request.overlap.scores.get(worker_dp) { let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; - worker_scores.insert(worker_id, score); + worker_scores.insert(worker_dp, score); } - // Track max waiting requests max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64); } @@ -278,13 +246,11 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate logits for each worker let mut best_logit = f64::NEG_INFINITY; - let mut best_workers = 
Vec::new(); - - for (worker_id, ep) in workers.endpoints.iter() { - let worker_id = *worker_id; + let mut best_worker_dps = Vec::new(); + for (worker_dp, ep) in workers.endpoints.iter() { // Get score or default to 0.0 - let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0); + let score = worker_scores.get(worker_dp).copied().unwrap_or(0.0); // Calculate normalized metrics let gpu_cache_usage = ep.data.gpu_cache_usage_perc as f64; @@ -300,7 +266,7 @@ impl WorkerSelector for DefaultWorkerSelector { - self.kv_router_config.waiting_requests_weight * normalized_waiting; tracing::trace!( - "Formula for {worker_id}: {logit:.3} = {:.1} * {score:.3} - {:.1} * {gpu_cache_usage:.3} - {:.1} * {normalized_waiting:.3}", + "Formula for {worker_dp:?}: {logit:.3} = {:.3} * {score:.3} - {:.3} * {gpu_cache_usage:.3} - {:.3} * {normalized_waiting:.3}", self.kv_router_config.overlap_score_weight, self.kv_router_config.gpu_cache_usage_weight, self.kv_router_config.waiting_requests_weight, @@ -310,40 +276,45 @@ impl WorkerSelector for DefaultWorkerSelector { match logit.partial_cmp(&best_logit) { Some(std::cmp::Ordering::Greater) => { best_logit = logit; - best_workers.clear(); - best_workers.push(worker_id); + best_worker_dps.clear(); + best_worker_dps.push(worker_dp); } Some(std::cmp::Ordering::Equal) => { - best_workers.push(worker_id); + best_worker_dps.push(worker_dp); } _ => {} } } // Return early if no valid workers found - if best_workers.is_empty() { + if best_worker_dps.is_empty() { return Err(KvSchedulerError::NoEndpoints); } else if best_logit == 0.0 { tracing::debug!("best worker logit is 0"); } - let worker_id = if best_workers.len() == 1 { - best_workers[0] + let best_worker_dp = if best_worker_dps.len() == 1 { + best_worker_dps[0] } else { // Randomly select from best workers let mut rng = rand::rng(); - best_workers[rng.random_range(0..best_workers.len())] + best_worker_dps[rng.random_range(0..best_worker_dps.len())] }; // Lower to trace level eventually. Nice to see KV routing working for now. 
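For readers following the selection math, the logit computed by `DefaultWorkerSelector` reduces to a weighted trade-off between overlap score, GPU cache pressure, and queue depth. Below is a standalone sketch of just that arithmetic; the names mirror the trace message above, but this is an illustration of the formula, not the library API.

```rust
// Illustrative sketch of the selection logit described above:
// logit = w_overlap * score - w_cache * gpu_cache_usage - w_wait * normalized_waiting
fn selection_logit(
    overlap_score: f64,      // overlap blocks * block_size / isl_tokens
    gpu_cache_usage: f64,    // fraction of KV cache in use, 0.0..=1.0
    normalized_waiting: f64, // waiting requests / max waiting across workers
    overlap_score_weight: f64,
    gpu_cache_usage_weight: f64,
    waiting_requests_weight: f64,
) -> f64 {
    overlap_score_weight * overlap_score
        - gpu_cache_usage_weight * gpu_cache_usage
        - waiting_requests_weight * normalized_waiting
}

fn main() {
    // Higher overlap and a less loaded worker yield a higher logit.
    let a = selection_logit(0.8, 0.2, 0.1, 1.0, 1.0, 1.0);
    let b = selection_logit(0.1, 0.7, 0.9, 1.0, 1.0, 1.0);
    assert!(a > b);
}
```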
- tracing::debug!("Selected worker: {worker_id}, logit: {best_logit:.3}"); + tracing::debug!("Selected worker: {best_worker_dp:?}, logit: {best_logit:.3}"); // Log selection metrics let total_blocks = std::cmp::max(request.isl_tokens / block_size, 1) as u64; - let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize; + let overlap_blocks = request + .overlap + .scores + .get(best_worker_dp) + .copied() + .unwrap_or(0) as usize; Ok(WorkerSelectionResult { - worker_id, + worker: *best_worker_dp, required_blocks: total_blocks, overlap_blocks, }) diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index c663c22b5a..c62f486400 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -18,11 +18,46 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::kv_router::scheduler::Endpoint; +use crate::kv_router::protocols::{DpRank, ForwardPassMetrics, WorkerDp}; + +/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' +/// is cleaned (not optional) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Endpoint { + pub name: String, + // contains dp + pub subject: String, + // one set of metrics for each dp worker + pub data: ForwardPassMetrics, +} + +impl Endpoint { + pub fn worker_id(&self) -> i64 { + i64::from_str_radix( + self.subject + .split("-") + .last() + .expect("invalid subject") + .to_string() + .as_str(), + 16, + ) + .expect("invalid worker id") + } + + pub fn dp_rank(&self) -> Option { + let parts: Vec<&str> = self.subject.split("-").collect(); + if parts.len() < 3 { + return None; + } + let second_to_last = parts[parts.len() - 2]; + second_to_last.parse::().ok() + } +} #[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct ProcessedEndpoints { - pub endpoints: HashMap, + pub endpoints: HashMap, pub load_avg: f64, pub load_std: f64, } @@ -32,8 +67,12 @@ impl ProcessedEndpoints { // compute some basic statistics let load_values: Vec = endpoints .iter() - .map(|x| x.data.kv_active_blocks as f64) + .map(|endpoint| endpoint.data.kv_active_blocks as f64) .collect(); + if load_values.is_empty() { + // TODO we hit this panic while vLLM is starting the ranks up. Need to avoid this + panic!("No endpoints to process!") + }; let load_avg = load_values.iter().copied().sum::() / load_values.len() as f64; let variance = load_values .iter() @@ -42,7 +81,19 @@ impl ProcessedEndpoints { / load_values.len() as f64; let load_std = variance.sqrt(); - let endpoints = endpoints.into_iter().map(|e| (e.worker_id(), e)).collect(); + // pass in (worker_id, dp_rank) tuple + let endpoints = endpoints + .into_iter() + .map(|e| { + ( + WorkerDp { + worker_id: e.worker_id(), + dp_rank: e.dp_rank(), + }, + e, + ) + }) + .collect(); ProcessedEndpoints { endpoints, diff --git a/lib/llm/src/kv_router/worker.rs b/lib/llm/src/kv_router/worker.rs deleted file mode 100644 index fc44624f85..0000000000 --- a/lib/llm/src/kv_router/worker.rs +++ /dev/null @@ -1,105 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -pub use crate::kv_router::protocols::ForwardPassMetrics; - -use anyhow::Result; -use derive_builder::Builder; -use dynamo_runtime::pipeline::network::{ - ingress::push_endpoint::PushEndpoint, - PushWorkHandler, -}; - -use dynamo_runtime::transports::nats::{self, ServiceExt}; - -use tokio::sync::watch; -use tokio_util::sync::CancellationToken; -use tracing as log; - -#[derive(Builder)] -pub struct KvRoutedIngress { - #[builder(setter(into))] - pub service_name: String, - - #[builder(setter(into))] - pub worker_id: String, - - pub nats: nats::Client, - pub service_handler: Arc, - pub metrics_rx: watch::Receiver>, - pub cancellation_token: CancellationToken, -} - -/// version of crate -pub const VERSION: &str = env!("CARGO_PKG_VERSION"); - -impl KvRoutedIngress { - pub fn builder() -> KvRoutedIngressBuilder { - KvRoutedIngressBuilder::default() - } - - pub async fn start(self) -> Result<()> { - let worker_id = self.worker_id; - - log::trace!( - worker_id, - "Starting nats service: {}:{}", - self.service_name, - VERSION - ); - - let mut metrics_rx = self.metrics_rx; - let worker_id_clone = worker_id.clone(); - - let service = self - .nats - .client() - .service_builder() - .description("A handy min max service") - .stats_handler(move |name, stats| { - log::debug!( - worker_id = worker_id_clone.as_str(), - "[IN worker?] Stats for service {}: {:?}", - name, - stats - ); - let metrics = metrics_rx.borrow_and_update().clone(); - serde_json::to_value(&*metrics).unwrap() - }) - .start(self.service_name.as_str(), VERSION) - .await - .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?; - - let group = service.group(self.service_name.as_str()); - - log::trace!(worker_id, "Starting endpoint: {}", worker_id); - - // creates an endpoint for the service - let service_endpoint = group - .endpoint(worker_id.clone()) - .await - .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?; - - let push_endpoint = PushEndpoint::builder() - .service_handler(self.service_handler) - .cancellation_token(self.cancellation_token) - .build() - .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?; - - push_endpoint.start(service_endpoint).await - } -} diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs index 6b3be76069..151b9dbfcc 100644 --- a/lib/llm/src/protocols/common/preprocessor.rs +++ b/lib/llm/src/protocols/common/preprocessor.rs @@ -51,6 +51,10 @@ pub struct PreprocessedRequest { /// Estimated number of prefix hit tokens (only used in kv aware routing) #[builder(default)] pub estimated_prefix_hit_num_blocks: Option, + + // The dp_rank to route to + #[builder(default)] + pub dp_rank: Option, } impl PreprocessedRequest {
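Finally, for reference, the subject parsing that `scoring.rs` now performs to build a `WorkerDp` key can be exercised on its own. The sketch below is hedged: the subject shape `<service>-<dp_rank>-<worker_id_hex>` is inferred from the parsing code, and these helpers return `Option` instead of panicking the way the real accessors do.

```rust
// Standalone sketch of the subject parsing added in scoring.rs: the worker id
// is the last dash-separated segment (hex) and the dp rank, when present, is
// the second-to-last segment. Subjects without a rank segment yield None.
fn worker_id(subject: &str) -> Option<i64> {
    i64::from_str_radix(subject.split('-').last()?, 16).ok()
}

fn dp_rank(subject: &str) -> Option<u32> {
    let parts: Vec<&str> = subject.split('-').collect();
    if parts.len() < 3 {
        return None;
    }
    parts[parts.len() - 2].parse::<u32>().ok()
}

fn main() {
    // Hypothetical subjects, shaped like "<service>-<dp_rank>-<worker_id_hex>".
    assert_eq!(worker_id("generate-2-1a2b"), Some(0x1a2b));
    assert_eq!(dp_rank("generate-2-1a2b"), Some(2));
    assert_eq!(dp_rank("generate-1a2b"), None); // no dp rank segment
}
```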