From 86c79ba42240de4ef5c746e2f95f9258db254d87 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Thu, 29 May 2025 20:23:53 -0700 Subject: [PATCH 01/22] some prelim cleanups --- lib/llm/src/kv_router.rs | 6 +++-- lib/llm/src/kv_router/indexer.rs | 28 ----------------------- lib/llm/src/kv_router/protocols.rs | 36 ++++++++++++++++++++++++++---- lib/llm/src/kv_router/publisher.rs | 4 +--- lib/llm/src/kv_router/recorder.rs | 4 +--- 5 files changed, 38 insertions(+), 40 deletions(-) diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 67d3ccfcac..b86c4cf00e 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -25,9 +25,11 @@ pub mod scoring; use crate::{ kv_router::{ - indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, + indexer::{KvIndexer, KvIndexerInterface}, metrics_aggregator::KvMetricsAggregator, - protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, + protocols::{ + LocalBlockHash, RouterEvent, RouterRequest, RouterResponse, WorkerSelectionResult, + }, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, }, diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs index 14b9380fc3..c8eb9a0f88 100644 --- a/lib/llm/src/kv_router/indexer.rs +++ b/lib/llm/src/kv_router/indexer.rs @@ -77,9 +77,6 @@ pub enum KvRouterError { IndexerDroppedRequest, } -/// Identifier of a LLM worker which emits events to the router. -pub type WorkerId = i64; - /// A shared reference to a [`RadixBlock`]. type SharedRadixBlock = Rc>; @@ -133,31 +130,6 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec Self { - Self { worker_id, event } - } -} - /// A block in the Radix Tree. #[derive(Debug)] struct RadixBlock { diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 70a7711521..40badc417b 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -21,15 +21,18 @@ pub struct RouterRequest { pub tokens: Vec, } +/// Identifier of a LLM worker which emits events to the router. +pub type WorkerId = i64; + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterResponse { - pub worker_id: i64, + pub worker_id: WorkerId, } #[derive(Debug)] pub struct WorkerSelectionResult { /// The worker id of the selected worker - pub worker_id: i64, + pub worker_id: WorkerId, /// The total number of blocks required to prefill the request pub required_blocks: u64, @@ -58,14 +61,14 @@ pub struct ForwardPassMetrics { /// A [`LocalBlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional /// lora_id of a block. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct LocalBlockHash(pub u64); /// A sequence aware hash of a block where the hash is computed from the tokens_ids, extra_token_ids /// and the optional lora_id of a block, PLUS the hash of the parent block. /// /// In this case, the hashing function is external and unknown. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ExternalSequenceBlockHash(pub u64); // Implement From trait for convenient conversion @@ -138,6 +141,31 @@ pub struct KvCacheRemoveData { pub block_hashes: Vec, } +/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`]. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouterEvent { + /// The ID of the worker emitting the event. + pub worker_id: WorkerId, + /// The cache event associated with the worker. + pub event: KvCacheEvent, +} + +impl RouterEvent { + /// Create a new `RouterEvent`. + /// + /// ### Arguments + /// + /// * `worker_id` - The ID of the worker emitting the event. + /// * `event` - The cache event. + /// + /// ### Returns + /// + /// A new `RouterEvent`. + pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self { + Self { worker_id, event } + } +} + impl Serialize for LocalBlockHash { fn serialize(&self, serializer: S) -> Result where diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 8beac80ff2..7a5cac0b02 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -14,9 +14,7 @@ // limitations under the License. use crate::kv_router::{ - indexer::{compute_block_hash_for_seq, RouterEvent}, - protocols::*, - KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, + indexer::compute_block_hash_for_seq, protocols::*, KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, }; use async_trait::async_trait; use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider, RuntimeProvider}; diff --git a/lib/llm/src/kv_router/recorder.rs b/lib/llm/src/kv_router/recorder.rs index 17c66c7925..8da8f6cbcf 100644 --- a/lib/llm/src/kv_router/recorder.rs +++ b/lib/llm/src/kv_router/recorder.rs @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::kv_router::indexer::RouterEvent; +use crate::kv_router::protocols::*; use crate::recorder::Recorder; // Type alias for backward compatibility @@ -23,8 +23,6 @@ pub type KvRecorder = Recorder; mod tests { use super::*; use crate::kv_router::indexer::KvIndexer; - use crate::kv_router::indexer::WorkerId; - use crate::kv_router::protocols::*; use std::time::Duration; use tempfile::tempdir; use tokio::fs; From 6bee243ae7de47b8fe83778692d0db90e24e7ff4 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Thu, 29 May 2025 23:08:11 -0700 Subject: [PATCH 02/22] router can route to dp ranks --- components/metrics/src/bin/mock_worker.rs | 2 +- components/metrics/src/lib.rs | 15 +- components/metrics/src/main.rs | 12 +- components/router/src/main.rs | 4 +- lib/llm/src/kv_router.rs | 47 +++--- lib/llm/src/kv_router/indexer.rs | 153 +++++++++++--------- lib/llm/src/kv_router/metrics_aggregator.rs | 5 +- lib/llm/src/kv_router/protocols.rs | 52 +++++-- lib/llm/src/kv_router/recorder.rs | 33 +++-- lib/llm/src/kv_router/scheduler.rs | 149 +++++++++---------- lib/llm/src/kv_router/scoring.rs | 32 +++- lib/llm/src/kv_router/worker.rs | 105 -------------- 12 files changed, 288 insertions(+), 321 deletions(-) delete mode 100644 lib/llm/src/kv_router/worker.rs diff --git a/components/metrics/src/bin/mock_worker.rs b/components/metrics/src/bin/mock_worker.rs index 6278de73ce..c80f88e328 100644 --- a/components/metrics/src/bin/mock_worker.rs +++ b/components/metrics/src/bin/mock_worker.rs @@ -14,7 +14,7 @@ // limitations under the License. 
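Illustrative sketch of the direction this commit takes (assumptions: crate paths and field names as they stand after this commit, and serde_json available as a dependency; none of this code is from the repository): once the event and indexer types are generic over the worker identifier, a data-parallel-aware router can key them by a (WorkerId, DpRank) tuple and round-trip events through JSON, which is what the router does with incoming payloads.

use dynamo_llm::kv_router::protocols::{
    DpRank, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData,
    RouterEvent, WorkerId,
};

fn example_event() -> RouterEvent<(WorkerId, DpRank)> {
    let event = KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: vec![ExternalSequenceBlockHash(0xDEAD_BEEF)],
        }),
    };
    // Worker 42, data-parallel rank 3: a plain tuple already satisfies the
    // Hash + Eq + Serialize bounds the generic indexer asks for.
    RouterEvent::new((42, 3), event)
}

fn main() {
    let event = example_event();
    // The router deserializes transport payloads into RouterEvent<(WorkerId, DpRank)>,
    // so the tuple key must survive a serde_json round trip.
    let bytes = serde_json::to_vec(&event).expect("serialize");
    let back: RouterEvent<(WorkerId, DpRank)> =
        serde_json::from_slice(&bytes).expect("deserialize");
    // Field name as of this commit; a later commit in the series renames it.
    assert_eq!(back.worker_id, (42, 3));
}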
use dynamo_llm::kv_router::{ - protocols::ForwardPassMetrics, scheduler::KVHitRateEvent, KV_HIT_RATE_SUBJECT, + protocols::ForwardPassMetrics, protocols::KVHitRateEvent, KV_HIT_RATE_SUBJECT, }; use dynamo_runtime::{ component::{service::EndpointStats, Namespace}, diff --git a/components/metrics/src/lib.rs b/components/metrics/src/lib.rs index b928938490..68c026e081 100644 --- a/components/metrics/src/lib.rs +++ b/components/metrics/src/lib.rs @@ -84,8 +84,7 @@ use std::net::SocketAddr; use std::time::Duration as StdDuration; use dynamo_llm::kv_router::protocols::ForwardPassMetrics; -use dynamo_llm::kv_router::scheduler::Endpoint; -use dynamo_llm::kv_router::scoring::ProcessedEndpoints; +use dynamo_llm::kv_router::scoring::{Endpoint, ProcessedEndpoints}; use dynamo_runtime::{ distributed::Component, error, service::EndpointInfo, utils::Duration, Result, @@ -455,31 +454,31 @@ impl PrometheusMetrics { &self.kv_blocks_active, config, &worker_id, - metrics.kv_active_blocks as f64, + metrics[0].kv_active_blocks as f64, ); self.set_worker_gauge( &self.kv_blocks_total, config, &worker_id, - metrics.kv_total_blocks as f64, + metrics[0].kv_total_blocks as f64, ); self.set_worker_gauge( &self.requests_active, config, &worker_id, - metrics.request_active_slots as f64, + metrics[0].request_active_slots as f64, ); self.set_worker_gauge( &self.requests_total, config, &worker_id, - metrics.request_total_slots as f64, + metrics[0].request_total_slots as f64, ); self.set_worker_gauge( &self.kv_hit_rate_percent, config, &worker_id, - metrics.gpu_prefix_cache_hit_rate as f64, + metrics[0].gpu_prefix_cache_hit_rate as f64, ); } @@ -602,7 +601,7 @@ pub fn postprocess_metrics( e.id().ok().map(|id| Endpoint { name: format!("worker-{id}"), subject: e.subject.clone(), - data: m.clone(), + data: vec![m.clone()], }) }) .collect(); diff --git a/components/metrics/src/main.rs b/components/metrics/src/main.rs index fa8186d07a..f9a1fac09e 100644 --- a/components/metrics/src/main.rs +++ b/components/metrics/src/main.rs @@ -27,7 +27,7 @@ //! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events //! 
- Overlap Blocks: Cumulative count of blocks that were already in the KV cache use clap::Parser; -use dynamo_llm::kv_router::scheduler::KVHitRateEvent; +use dynamo_llm::kv_router::protocols::{KVHitRateEvent, WorkerId, DpRank}; use dynamo_llm::kv_router::KV_HIT_RATE_SUBJECT; use dynamo_runtime::{ error, logging, @@ -180,14 +180,15 @@ async fn app(runtime: Runtime) -> Result<()> { tracing::debug!("Successfully subscribed to KV hit rate events"); while let Some(msg) = subscriber.next().await { - match serde_json::from_slice::(&msg.payload) { + match serde_json::from_slice::>(&msg.payload) { Ok(event) => { // TODO: Lower to debug let cache_hit_pct = (event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0; tracing::debug!( - "Received KV hit rate event: worker_id={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%", - event.worker_id, + "Received KV hit rate event: worker_id={}, dp_rank={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%", + event.worker_id.0, + event.worker_id.1, event.isl_blocks, event.overlap_blocks, cache_hit_pct @@ -197,7 +198,8 @@ async fn app(runtime: Runtime) -> Result<()> { let mut metrics = metrics_collector_clone.lock().await; metrics.update_kv_hit_rate( &config_clone, - event.worker_id, + // TODO: this will not take care of dp ranks + event.worker_id.0, event.isl_blocks, event.overlap_blocks, ); diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 3546a9bb30..48023616d6 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use clap::Parser; use dynamo_llm::kv_router::{ - protocols::WorkerSelectionResult, + protocols::{WorkerSelectionResult, WorkerId, DpRank}, scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, KvRouter, WorkerSelector, @@ -89,7 +89,7 @@ impl WorkerSelector for CustomWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result { + ) -> Result, KvSchedulerError> { // customize logic here // F12 into [DefaultWorkerSelector] to see the original logic self.0.select_worker(workers, request, block_size) diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index b86c4cf00e..1744879ea6 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -28,7 +28,8 @@ use crate::{ indexer::{KvIndexer, KvIndexerInterface}, metrics_aggregator::KvMetricsAggregator, protocols::{ - LocalBlockHash, RouterEvent, RouterRequest, RouterResponse, WorkerSelectionResult, + DpRank, LocalBlockHash, RouterEvent, RouterRequest, RouterResponse, WorkerId, + WorkerSelectionResult, }, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, @@ -53,13 +54,13 @@ pub trait WorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result; + ) -> Result, KvSchedulerError>; } /// A KvRouter only decides which worker you should use. It doesn't send you there. /// TODO: Rename this to indicate it only selects a worker, it does not route. 
pub struct KvRouter { - indexer: KvIndexer, + indexer: KvIndexer<(WorkerId, DpRank)>, scheduler: KvScheduler, block_size: usize, } @@ -94,15 +95,16 @@ impl KvRouter { tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: RouterEvent = match serde_json::from_slice(&event.payload) { - Ok(event) => event, - Err(e) => { - tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); - // Choosing warn and continue to process other events from other workers - // A bad event likely signals a problem with a worker, but potentially other workers are still healthy - continue; - } - }; + let event: RouterEvent<(WorkerId, DpRank)> = + match serde_json::from_slice(&event.payload) { + Ok(event) => event, + Err(e) => { + tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); + // Choosing warn and continue to process other events from other workers + // A bad event likely signals a problem with a worker, but potentially other workers are still healthy + continue; + } + }; if let Err(e) = kv_events_tx.send(event).await { tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e); } @@ -117,7 +119,11 @@ impl KvRouter { } // [TODO] indexer needs to take 'lora_id' as parameter - pub async fn schedule(&self, token_ids: &Vec, _lora_id: u64) -> Result { + pub async fn schedule( + &self, + token_ids: &Vec, + _lora_id: u64, + ) -> Result<(WorkerId, DpRank)> { // Extracting part of the code in KvRouter::generate() for only // the decision making part, routing is done by the caller let isl_tokens = token_ids.len(); @@ -132,7 +138,7 @@ impl KvRouter { /// Give these tokens, find the worker with the best match in it's KV cache. /// Returned overlap amount is in number of blocks. - async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(i64, u32)> { + async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<((WorkerId, DpRank), u32)> { let isl_tokens = tokens.len(); let block_size = self.block_size; @@ -159,11 +165,17 @@ impl KvRouter { } #[async_trait] -impl AsyncEngine, ManyOut>, Error> for KvRouter { +impl + AsyncEngine< + SingleIn, + ManyOut>>, + Error, + > for KvRouter +{ async fn generate( &self, request: SingleIn, - ) -> Result>> { + ) -> Result>>> { let (request, ctx) = request.into_parts(); let (worker_id, _) = self.find_best_match(&request.tokens).await?; @@ -205,7 +217,8 @@ impl AsyncEngine, ManyOut>, Er let (mut backend_input, context) = request.into_parts(); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); let updated_request = context.map(|_| backend_input); - self.inner.direct(updated_request, instance_id).await + // TODO: this does not do dp routing + self.inner.direct(updated_request, instance_id.0).await } } } diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs index c8eb9a0f88..89fb5f75e9 100644 --- a/lib/llm/src/kv_router/indexer.rs +++ b/lib/llm/src/kv_router/indexer.rs @@ -78,7 +78,7 @@ pub enum KvRouterError { } /// A shared reference to a [`RadixBlock`]. -type SharedRadixBlock = Rc>; +type SharedRadixBlock = Rc>>; pub fn compute_hash(data: &[u8]) -> u64 { xxh3::xxh3_64_with_seed(data, XXH3_SEED) @@ -132,16 +132,16 @@ pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec { /// A map of child blocks, keyed by their local block hash. - children: HashMap, + children: HashMap>, /// A set of worker IDs associated with this block. 
- workers: HashSet, + workers: HashSet, /// A buffer of times that this block was last traversed recent_uses: VecDeque, } -impl RadixBlock { +impl RadixBlock { /// Create a new `RadixBlock`. /// /// ### Returns @@ -156,10 +156,10 @@ impl RadixBlock { } } -pub struct RadixTree { +pub struct RadixTree { /// This is the root of the radix/prefix tree /// This will only contain root blocks - root: SharedRadixBlock, + root: SharedRadixBlock, /// This is a global lookup table for all blocks which will let you jump into /// the radix tree at any point @@ -169,18 +169,18 @@ pub struct RadixTree { /// Transitioning to a radix tree only would require a change in the messaging structure /// as the entire prefix would need to be sent. Alternatively, we could use block_depth /// integers to indicate how many blocks to skip and use a radix/prefix tree at each level. - lookup: HashMap>, + lookup: HashMap>>, /// The time buffer the radix tree should check when considering frequence of block accesses expiration_duration: Option, } -impl Default for RadixTree { +impl Default for RadixTree { fn default() -> Self { Self::new() } } -impl RadixTree { +impl RadixTree { /// Create a new `RadixTree`. /// /// ### Returns @@ -208,7 +208,11 @@ impl RadixTree { /// ### Returns /// /// An `OverlapScores` representing the match scores. - pub fn find_matches(&self, sequence: Vec, early_exit: bool) -> OverlapScores { + pub fn find_matches( + &self, + sequence: Vec, + early_exit: bool, + ) -> OverlapScores { let mut scores = OverlapScores::new(); let mut current = self.root.clone(); let now = Instant::now(); @@ -252,12 +256,12 @@ impl RadixTree { /// ### Arguments /// /// * `event` - The `RouterEvent` to apply. - pub fn apply_event(&mut self, event: RouterEvent) { + pub fn apply_event(&mut self, event: RouterEvent) { let (worker_id, event) = (event.worker_id, event.event); let (id, op) = (event.event_id, event.data); tracing::trace!(id, "Store operation: {:?}", op); - let worker_lookup = self.lookup.entry(worker_id).or_default(); + let worker_lookup = self.lookup.entry(worker_id.clone()).or_default(); match op { KvCacheEventData::Stored(op) => { @@ -273,7 +277,7 @@ impl RadixTree { Some(current) => current.clone(), None => { tracing::warn!( - worker_id = worker_id.to_string(), + worker_id = ?worker_id, id, parent_hash = ?op.parent_hash, "Failed to find parent block; skipping store operation" @@ -303,7 +307,7 @@ impl RadixTree { }; // add our worker_id to the block - block.borrow_mut().workers.insert(worker_id); + block.borrow_mut().workers.insert(worker_id.clone()); // add the block to the worker_id lookup table worker_lookup.insert(block_id.block_hash, block.clone()); @@ -327,7 +331,7 @@ impl RadixTree { Some(entry) => entry.clone(), None => { tracing::warn!( - worker_id = worker_id.to_string(), + worker_id = ?worker_id, id, "Failed to find block to remove; skipping remove operation" ); @@ -348,7 +352,7 @@ impl RadixTree { } } - pub fn remove_worker(&mut self, worker: WorkerId) { + pub fn remove_worker(&mut self, worker: T) { if let Some((_, blocks)) = self.lookup.remove_entry(&worker) { blocks.iter().for_each(|(_, block)| { block.borrow_mut().workers.remove(&worker); @@ -359,20 +363,20 @@ impl RadixTree { /// Scores representing the overlap of workers. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OverlapScores { +pub struct OverlapScores { // map of worker_id to score - pub scores: HashMap, + pub scores: HashMap, // List of frequencies that the blocks have been accessed. Entries with value 0 are omitted. 
pub frequencies: Vec, } -impl Default for OverlapScores { +impl Default for OverlapScores { fn default() -> Self { Self::new() } } -impl OverlapScores { +impl OverlapScores { /// Create a new `OverlapScores`. /// /// ### Returns @@ -389,10 +393,10 @@ impl OverlapScores { /// /// ### Arguments /// - /// * `workers` - A reference to a `HashSet` of `WorkerId`s. - pub fn update_scores(&mut self, workers: &HashSet) { + /// * `workers` - A reference to a `HashSet` of worker IDs. + pub fn update_scores(&mut self, workers: &HashSet) { for worker in workers { - let score = self.scores.entry(*worker).or_insert(0); + let score = self.scores.entry(worker.clone()).or_insert(0); *score += 1; } } @@ -409,17 +413,17 @@ impl OverlapScores { } /// A request to find matches in the Radix Tree. -pub struct MatchRequest { +pub struct MatchRequest { /// A vector of `LocalBlockHash` representing the sequence to match. sequence: Vec, /// A boolean indicating whether to exit early if a single match is found. early_exit: bool, /// A channel sender to send the `OverlapScores` response. - resp: oneshot::Sender, + resp: oneshot::Sender>, } #[async_trait] -pub trait KvIndexerInterface { +pub trait KvIndexerInterface { /// Find matches for a given sequence of `LocalBlockHash`es. /// /// ### Arguments @@ -432,7 +436,7 @@ pub trait KvIndexerInterface { async fn find_matches( &self, sequence: Vec, - ) -> Result; + ) -> Result, KvRouterError>; /// Find matches for a given sequence of tokens. /// @@ -446,43 +450,43 @@ pub trait KvIndexerInterface { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result; + ) -> Result, KvRouterError>; /// Apply a `RouterEvent` to the KV store. /// /// ### Arguments /// /// * `event` - The `RouterEvent` to apply. - async fn apply_event(&mut self, event: RouterEvent); + async fn apply_event(&mut self, event: RouterEvent); /// Remove a worker's entries from the trie. /// /// ### Arguments /// /// * `worker` - The worker to remove from the trie. - async fn remove_worker(&mut self, worker: WorkerId); + async fn remove_worker(&mut self, worker: T); /// Shutdown the KV Indexer. fn shutdown(&mut self); } /// The KV Indexer, managing the KV store and handling events and match requests. -pub struct KvIndexer { +pub struct KvIndexer { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// A sender for `RouterEvent`s. - event_tx: mpsc::Sender, + event_tx: mpsc::Sender>, /// A sender for `MatchRequest`s. - match_tx: mpsc::Sender, + match_tx: mpsc::Sender>, /// A sender for remove worker requests. - remove_worker_tx: mpsc::Sender, + remove_worker_tx: mpsc::Sender, /// A handle to the background task managing the KV store. task: OnceLock>, /// The size of the KV block this indexer can handle. kv_block_size: usize, } -impl KvIndexer { +impl KvIndexer { /// Create a new `KvIndexer`. 
/// /// ### Arguments @@ -498,9 +502,9 @@ impl KvIndexer { expiration_duration: Option, kv_block_size: usize, ) -> Self { - let (event_tx, event_rx) = mpsc::channel::(2048); - let (match_tx, match_rx) = mpsc::channel::(128); - let (remove_worker_tx, remove_worker_rx) = mpsc::channel::(16); + let (event_tx, event_rx) = mpsc::channel::>(2048); + let (match_tx, match_rx) = mpsc::channel::>(128); + let (remove_worker_tx, remove_worker_rx) = mpsc::channel::(16); let cancel_clone = token.clone(); let task = std::thread::spawn(move || { // create a new tokio runtime which will only perform work on a single thread @@ -576,17 +580,17 @@ impl KvIndexer { /// ### Returns /// /// A `mpsc::Sender` for `RouterEvent`s. - pub fn event_sender(&self) -> mpsc::Sender { + pub fn event_sender(&self) -> mpsc::Sender> { self.event_tx.clone() } } #[async_trait] -impl KvIndexerInterface for KvIndexer { +impl KvIndexerInterface for KvIndexer { async fn find_matches( &self, sequence: Vec, - ) -> Result { + ) -> Result, KvRouterError> { let (resp_tx, resp_rx) = oneshot::channel(); let req = MatchRequest { sequence, @@ -610,7 +614,7 @@ impl KvIndexerInterface for KvIndexer { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result { + ) -> Result, KvRouterError> { tracing::debug!( "Finding matches for request tokens: {:?} / len: {}", tokens, @@ -621,11 +625,11 @@ impl KvIndexerInterface for KvIndexer { self.find_matches(sequence).await } - async fn apply_event(&mut self, event: RouterEvent) { + async fn apply_event(&mut self, event: RouterEvent) { self.event_tx.send(event).await.unwrap(); } - async fn remove_worker(&mut self, worker: WorkerId) { + async fn remove_worker(&mut self, worker: T) { self.remove_worker_tx.send(worker).await.unwrap(); } @@ -638,28 +642,28 @@ impl KvIndexerInterface for KvIndexer { } #[derive(Debug, Clone)] -pub struct ShardedMatchRequest { +pub struct ShardedMatchRequest { sequence: Vec, early_exit: bool, - resp: mpsc::Sender, + resp: mpsc::Sender>, } /// The KV Indexer, managing the KV store and handling events and match requests. -pub struct KvIndexerSharded { +pub struct KvIndexerSharded { /// A `CancellationToken` for managing shutdown. cancel: CancellationToken, /// The size of the KV block this indexer can handle. kv_block_size: usize, - worker_assignments: HashMap, + worker_assignments: HashMap, worker_counts: Vec, - event_tx: Vec>, - request_broadcast_tx: broadcast::Sender, - remove_worker_tx: Vec>, + event_tx: Vec>>, + request_broadcast_tx: broadcast::Sender>, + remove_worker_tx: Vec>, tasks: Vec>, } -impl KvIndexerSharded { +impl KvIndexerSharded { /// Create a new `KvIndexerSharded`. 
/// /// ### Arguments @@ -677,19 +681,18 @@ impl KvIndexerSharded { expiration_duration: Option, kv_block_size: usize, ) -> Self { - let worker_assignments: HashMap = HashMap::new(); + let worker_assignments: HashMap = HashMap::new(); let worker_counts: Vec = vec![0; num_shards]; let mut event_tx = Vec::new(); let mut remove_worker_tx = Vec::new(); let mut tasks = Vec::new(); - let (request_broadcast_tx, _) = broadcast::channel::(1048576); + let (request_broadcast_tx, _) = broadcast::channel::>(1048576); for _ in 0..num_shards { - let (shard_event_tx, mut shard_event_rx) = mpsc::channel::(2048); - let (shard_remove_worker_tx, mut shard_remove_worker_rx) = - mpsc::channel::(16); + let (shard_event_tx, mut shard_event_rx) = mpsc::channel::>(2048); + let (shard_remove_worker_tx, mut shard_remove_worker_rx) = mpsc::channel::(16); let mut shard_broadcast_rx = request_broadcast_tx.subscribe(); let cancel = token.clone(); @@ -764,11 +767,11 @@ impl KvIndexerSharded { } #[async_trait] -impl KvIndexerInterface for KvIndexerSharded { +impl KvIndexerInterface for KvIndexerSharded { async fn find_matches( &self, sequence: Vec, - ) -> Result { + ) -> Result, KvRouterError> { 'match_loop: loop { let (match_tx, mut match_rx) = mpsc::channel(self.event_tx.len()); self.request_broadcast_tx @@ -815,12 +818,12 @@ impl KvIndexerInterface for KvIndexerSharded { async fn find_matches_for_request( &self, tokens: &[u32], - ) -> Result { + ) -> Result, KvRouterError> { let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size); self.find_matches(sequence).await } - async fn apply_event(&mut self, event: RouterEvent) { + async fn apply_event(&mut self, event: RouterEvent) { #[allow(clippy::map_entry)] if !self.worker_assignments.contains_key(&event.worker_id) { // Get the shard with the smallest amount of workers. 
@@ -833,7 +836,7 @@ impl KvIndexerInterface for KvIndexerSharded { .0; self.worker_assignments - .insert(event.worker_id, selected_shard); + .insert(event.worker_id.clone(), selected_shard); self.worker_counts[selected_shard] += 1; } @@ -843,7 +846,7 @@ impl KvIndexerInterface for KvIndexerSharded { .unwrap(); } - async fn remove_worker(&mut self, worker: WorkerId) { + async fn remove_worker(&mut self, worker: T) { if let Some((_, shard)) = self.worker_assignments.remove_entry(&worker) { self.worker_counts[shard] -= 1; self.remove_worker_tx[shard].send(worker).await.unwrap(); @@ -861,13 +864,15 @@ impl KvIndexerInterface for KvIndexerSharded { #[cfg(test)] mod tests { - use super::*; use rstest::rstest; use rstest_reuse::{self, *}; use tokio::time; use tokio_util::sync::CancellationToken; + // Use u64 as a simple WorkerIdTrait implementation for tests + type TestWorkerId = u64; + fn setup() { dynamo_runtime::logging::init(); } @@ -893,11 +898,11 @@ mod tests { } fn create_store_event( - worker_id: WorkerId, + worker_id: TestWorkerId, event_id: u64, hashes: Vec, parent: Option, - ) -> RouterEvent { + ) -> RouterEvent { RouterEvent { worker_id, event: KvCacheEvent { @@ -907,7 +912,11 @@ mod tests { } } - fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec) -> RouterEvent { + fn create_remove_event( + worker_id: TestWorkerId, + event_id: u64, + hashes: Vec, + ) -> RouterEvent { RouterEvent { worker_id, event: KvCacheEvent { @@ -1208,7 +1217,7 @@ mod tests { token: &CancellationToken, num_shards: usize, kv_block_size: usize, - ) -> Box { + ) -> Box> { if num_shards == 1 { Box::new(KvIndexer::new(token.clone(), kv_block_size)) } else { @@ -1293,7 +1302,7 @@ mod tests { const ONE_MILLIS: Duration = Duration::from_millis(1); setup(); - let mut kv_indexer: Box; + let mut kv_indexer: Box>; let token = CancellationToken::new(); let expiration = Duration::from_millis(50); @@ -1421,7 +1430,7 @@ mod tests { #[test] fn test_radix_tree_default() { setup(); - let radix_tree: RadixTree = Default::default(); + let radix_tree: RadixTree = Default::default(); assert!(radix_tree.root.borrow().children.is_empty()); assert!(radix_tree.root.borrow().workers.is_empty()); assert!(radix_tree.lookup.is_empty()); @@ -1430,7 +1439,7 @@ mod tests { #[test] fn test_overlap_scores_default() { setup(); - let overlap_scores: OverlapScores = Default::default(); + let overlap_scores: OverlapScores = Default::default(); assert!(overlap_scores.scores.is_empty()); } } diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 156d1dfb02..6d824cefdf 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -18,8 +18,7 @@ use std::sync::Once; pub use crate::kv_router::protocols::ForwardPassMetrics; use crate::kv_router::KV_METRICS_ENDPOINT; -use crate::kv_router::scheduler::Endpoint; -use crate::kv_router::ProcessedEndpoints; +use crate::kv_router::scoring::{Endpoint, ProcessedEndpoints}; use dynamo_runtime::component::Component; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; use tokio::sync::watch; @@ -119,7 +118,7 @@ pub async fn collect_endpoints_task( .into_iter() .filter(|s| s.data.is_some()) .filter_map(|s| - match s.data.unwrap().decode::() { + match s.data.unwrap().decode::>() { Ok(data) => Some(Endpoint { name: s.name, subject: s.subject, diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 40badc417b..237657e58b 100644 --- 
a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -15,24 +15,47 @@ use crate::tokens::Token; use serde::{Deserialize, Serialize}; +use std::cmp::Eq; +use std::fmt::Debug; +use std::hash::Hash; + +pub type WorkerId = i64; + +pub type DpRank = u32; + +pub trait WorkerGeneral: + Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize +{ +} + +impl WorkerGeneral for T where + T: Hash + + Eq + + Debug + + Clone + + Send + + Sync + + Default + + 'static + + Serialize + + for<'de> Deserialize<'de> +{ +} #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterRequest { pub tokens: Vec, } -/// Identifier of a LLM worker which emits events to the router. -pub type WorkerId = i64; - #[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct RouterResponse { - pub worker_id: WorkerId, +pub struct RouterResponse { + pub worker_id: T, } #[derive(Debug)] -pub struct WorkerSelectionResult { +pub struct WorkerSelectionResult { /// The worker id of the selected worker - pub worker_id: WorkerId, + pub worker_id: T, /// The total number of blocks required to prefill the request pub required_blocks: u64, @@ -141,16 +164,23 @@ pub struct KvCacheRemoveData { pub block_hashes: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KVHitRateEvent { + pub worker_id: T, + pub isl_blocks: usize, + pub overlap_blocks: usize, +} + /// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`]. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RouterEvent { +pub struct RouterEvent { /// The ID of the worker emitting the event. - pub worker_id: WorkerId, + pub worker_id: T, /// The cache event associated with the worker. pub event: KvCacheEvent, } -impl RouterEvent { +impl RouterEvent { /// Create a new `RouterEvent`. /// /// ### Arguments @@ -161,7 +191,7 @@ impl RouterEvent { /// ### Returns /// /// A new `RouterEvent`. 
- pub fn new(worker_id: WorkerId, event: KvCacheEvent) -> Self { + pub fn new(worker_id: T, event: KvCacheEvent) -> Self { Self { worker_id, event } } } diff --git a/lib/llm/src/kv_router/recorder.rs b/lib/llm/src/kv_router/recorder.rs index 8da8f6cbcf..40cdbffcbe 100644 --- a/lib/llm/src/kv_router/recorder.rs +++ b/lib/llm/src/kv_router/recorder.rs @@ -16,8 +16,8 @@ use crate::kv_router::protocols::*; use crate::recorder::Recorder; -// Type alias for backward compatibility -pub type KvRecorder = Recorder; +// Type alias for backward compatibility, now generic +pub type KvRecorder = Recorder>; #[cfg(test)] mod tests { @@ -28,6 +28,9 @@ mod tests { use tokio::fs; use tokio_util::sync::CancellationToken; + // Use i64 for tests + type TestWorkerId = i64; + fn make_blocks(hashes: Vec) -> Vec { hashes .iter() @@ -49,11 +52,11 @@ mod tests { } fn create_store_event( - worker_id: WorkerId, + worker_id: TestWorkerId, event_id: u64, hashes: Vec, parent: Option, - ) -> RouterEvent { + ) -> RouterEvent { RouterEvent::new( worker_id, KvCacheEvent { @@ -63,7 +66,11 @@ mod tests { ) } - fn create_remove_event(worker_id: WorkerId, event_id: u64, hashes: Vec) -> RouterEvent { + fn create_remove_event( + worker_id: TestWorkerId, + event_id: u64, + hashes: Vec, + ) -> RouterEvent { RouterEvent::new( worker_id, KvCacheEvent { @@ -86,7 +93,7 @@ mod tests { // Part 1: Record events to a file let token = CancellationToken::new(); - let recorder = KvRecorder::new(token.clone(), &file_path, None, None, None) + let recorder = KvRecorder::::new(token.clone(), &file_path, None, None, None) .await .unwrap(); let event_tx = recorder.event_sender(); @@ -126,13 +133,19 @@ mod tests { // Part 2: Now create a KvIndexer and load the events from the file let indexer_token = CancellationToken::new(); let kv_block_size = 32; // Default block size for testing - let indexer = KvIndexer::new(indexer_token.clone(), kv_block_size); + let indexer = KvIndexer::::new(indexer_token.clone(), kv_block_size); let indexer_event_tx = indexer.event_sender(); // Use the send_events method to load events from file to indexer - let count = KvRecorder::send_events(&file_path, &indexer_event_tx, false, None, None) - .await - .unwrap(); + let count = KvRecorder::::send_events( + &file_path, + &indexer_event_tx, + false, + None, + None, + ) + .await + .unwrap(); assert_eq!(count, 2, "Expected to send 2 events from file to indexer"); } } diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index eee53368f0..3c1aec464a 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -16,7 +16,6 @@ use dynamo_runtime::component::Namespace; use dynamo_runtime::traits::events::EventPublisher; use rand::Rng; -use serde::{Deserialize, Serialize}; use std::borrow::BorrowMut; use std::collections::HashMap; @@ -25,16 +24,9 @@ pub use crate::kv_router::protocols::ForwardPassMetrics; use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::KV_HIT_RATE_SUBJECT; -use super::protocols::WorkerSelectionResult; +use super::protocols::{DpRank, KVHitRateEvent, WorkerId, WorkerSelectionResult}; use super::WorkerSelector; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct KVHitRateEvent { - pub worker_id: i64, - pub isl_blocks: usize, - pub overlap_blocks: usize, -} - #[derive(Debug, thiserror::Error)] pub enum KvSchedulerError { #[error("no endpoints aviailable to route work")] @@ -47,39 +39,15 @@ pub enum KvSchedulerError { SubscriberShutdown, } -/// [gluo FIXME] exactly the same as 
EndpointInfo except that 'data' -/// is cleaned (not optional) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Endpoint { - pub name: String, - pub subject: String, - pub data: ForwardPassMetrics, -} - -impl Endpoint { - pub fn worker_id(&self) -> i64 { - i64::from_str_radix( - self.subject - .split("-") - .last() - .expect("invalid subject") - .to_string() - .as_str(), - 16, - ) - .expect("invalid worker id") - } -} - pub struct SchedulingRequest { pub isl_tokens: usize, - pub overlap: OverlapScores, - resp_tx: tokio::sync::oneshot::Sender, + pub overlap: OverlapScores<(WorkerId, DpRank)>, + resp_tx: tokio::sync::oneshot::Sender<(WorkerId, DpRank)>, } impl SchedulingRequest { - pub fn respond(self, worker_id: i64) { - if self.resp_tx.send(worker_id).is_err() { + pub fn respond(self, identifier: (WorkerId, DpRank)) { + if self.resp_tx.send(identifier).is_err() { tracing::trace!("failed to send response to requestor"); } } @@ -100,7 +68,8 @@ impl KvScheduler { let mut endpoints_rx = endpoints_rx; let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone(); - let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (event_tx, event_rx) = + tokio::sync::mpsc::unbounded_channel::>(); tokio::spawn(async move { let mut event_rx = event_rx; while let Some(event) = event_rx.recv().await { @@ -178,9 +147,9 @@ impl KvScheduler { pub async fn schedule( &self, - overlap: OverlapScores, + overlap: OverlapScores<(WorkerId, DpRank)>, isl_tokens: usize, - ) -> Result { + ) -> Result<(WorkerId, DpRank), KvSchedulerError> { let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { isl_tokens, @@ -201,20 +170,22 @@ impl KvScheduler { // This becomes the driver function that handles the selection result pub fn process_worker_selection( workers: &mut ProcessedEndpoints, - selection: WorkerSelectionResult, - event_tx: &tokio::sync::mpsc::UnboundedSender, -) -> i64 { + selection: WorkerSelectionResult<(WorkerId, DpRank)>, + event_tx: &tokio::sync::mpsc::UnboundedSender>, +) -> (WorkerId, DpRank) { let worker = workers .endpoints - .get_mut(&selection.worker_id) + .get_mut(&selection.worker_id.0) .expect("worker not found"); + let dp_rank = selection.worker_id.1 as usize; + // Update worker state predictively // Will be overwritten on next polling of metrics - worker.data.num_requests_waiting += 1; + worker.data[dp_rank].num_requests_waiting += 1; // Assumes radix attention so KV load is only incremented by uncached blocks // overlap_blocks can be bigger than required_blocks. I don't know if that's a bug or not. 
- worker.data.kv_active_blocks += selection + worker.data[dp_rank].kv_active_blocks += selection .required_blocks .saturating_sub(selection.overlap_blocks as u64); @@ -227,7 +198,7 @@ pub fn process_worker_selection( tracing::warn!("Failed to send KV hit rate event: {:?}", e); } - selection.worker_id + (selection.worker_id.0, selection.worker_id.1) } // Default implementation matching the Python _cost_function @@ -240,7 +211,7 @@ impl WorkerSelector for DefaultWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result { + ) -> Result, KvSchedulerError> { assert!(request.isl_tokens > 0); if workers.endpoints.is_empty() { @@ -252,14 +223,18 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate worker scores and find max waiting requests for (worker_id, ep) in workers.endpoints.iter() { - // Calculate score similar to Python version - if let Some(score) = request.overlap.scores.get(worker_id) { - let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; - worker_scores.insert(worker_id, score); + for dp_rank_maybe in ep.data.iter().map(|metrics| metrics.data_parallel_rank) { + let dp_rank = dp_rank_maybe.unwrap(); + if let Some(score) = request.overlap.scores.get(&(*worker_id, dp_rank)) { + let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; + worker_scores.insert((*worker_id, dp_rank), score); + } + // Track max waiting requests + max_waiting = f64::max( + max_waiting, + ep.data[dp_rank as usize].num_requests_waiting as f64, + ); } - - // Track max waiting requests - max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64); } // make immutable @@ -272,36 +247,42 @@ impl WorkerSelector for DefaultWorkerSelector { for (worker_id, ep) in workers.endpoints.iter() { let worker_id = *worker_id; + for fwd_pass_metrics in ep.data.iter() { + let dp_rank = fwd_pass_metrics.data_parallel_rank.unwrap(); + + // Get score or default to 0.0 + let score = worker_scores + .get(&(worker_id, dp_rank)) + .copied() + .unwrap_or(0.0); + + // Calculate normalized metrics + let gpu_cache_usage = ep.data[dp_rank as usize].gpu_cache_usage_perc as f64; + let normalized_waiting = if max_waiting > 0.0 { + ep.data[dp_rank as usize].num_requests_waiting as f64 / max_waiting + } else { + 0.0 + }; - // Get score or default to 0.0 - let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0); - - // Calculate normalized metrics - let gpu_cache_usage = ep.data.gpu_cache_usage_perc as f64; - let normalized_waiting = if max_waiting > 0.0 { - ep.data.num_requests_waiting as f64 / max_waiting - } else { - 0.0 - }; - - // Calculate logit using same formula as Python - let logit = 2.0 * score - gpu_cache_usage - normalized_waiting; - - tracing::trace!( - "Formula for {worker_id}: {logit:.3} = 2.0 * {score:.3} - {gpu_cache_usage:.3} - {normalized_waiting:.3}", - ); - - // Track best workers - match logit.partial_cmp(&best_logit) { - Some(std::cmp::Ordering::Greater) => { - best_logit = logit; - best_workers.clear(); - best_workers.push(worker_id); - } - Some(std::cmp::Ordering::Equal) => { - best_workers.push(worker_id); + // Calculate logit using same formula as Python + let logit = 2.0 * score - gpu_cache_usage - normalized_waiting; + + tracing::trace!( + "Formula for {worker_id}: {logit:.3} = 2.0 * {score:.3} - {gpu_cache_usage:.3} - {normalized_waiting:.3}", + ); + + // Track best workers + match logit.partial_cmp(&best_logit) { + Some(std::cmp::Ordering::Greater) => { + best_logit = logit; + 
best_workers.clear(); + best_workers.push((worker_id, dp_rank)); + } + Some(std::cmp::Ordering::Equal) => { + best_workers.push((worker_id, dp_rank)); + } + _ => {} } - _ => {} } } @@ -321,7 +302,7 @@ impl WorkerSelector for DefaultWorkerSelector { }; // Lower to trace level eventually. Nice to see KV routing working for now. - tracing::debug!("Selected worker: {worker_id}, logit: {best_logit:.3}"); + tracing::debug!("Selected worker: {worker_id:?}, logit: {best_logit:.3}"); // Log selection metrics let total_blocks = std::cmp::max(request.isl_tokens / block_size, 1) as u64; diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index c663c22b5a..625bb80e85 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -18,11 +18,36 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::kv_router::scheduler::Endpoint; +use crate::kv_router::protocols::{ForwardPassMetrics, WorkerId}; + +/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' +/// is cleaned (not optional) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Endpoint { + pub name: String, + pub subject: String, + // one set of metrics for each dp worker + pub data: Vec, +} + +impl Endpoint { + pub fn worker_id(&self) -> i64 { + i64::from_str_radix( + self.subject + .split("-") + .last() + .expect("invalid subject") + .to_string() + .as_str(), + 16, + ) + .expect("invalid worker id") + } +} #[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct ProcessedEndpoints { - pub endpoints: HashMap, + pub endpoints: HashMap, pub load_avg: f64, pub load_std: f64, } @@ -32,7 +57,8 @@ impl ProcessedEndpoints { // compute some basic statistics let load_values: Vec = endpoints .iter() - .map(|x| x.data.kv_active_blocks as f64) + .flat_map(|endpoint| endpoint.data.iter()) + .map(|metrics| metrics.kv_active_blocks as f64) .collect(); let load_avg = load_values.iter().copied().sum::() / load_values.len() as f64; let variance = load_values diff --git a/lib/llm/src/kv_router/worker.rs b/lib/llm/src/kv_router/worker.rs deleted file mode 100644 index fc44624f85..0000000000 --- a/lib/llm/src/kv_router/worker.rs +++ /dev/null @@ -1,105 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use std::sync::Arc; - -pub use crate::kv_router::protocols::ForwardPassMetrics; - -use anyhow::Result; -use derive_builder::Builder; -use dynamo_runtime::pipeline::network::{ - ingress::push_endpoint::PushEndpoint, - PushWorkHandler, -}; - -use dynamo_runtime::transports::nats::{self, ServiceExt}; - -use tokio::sync::watch; -use tokio_util::sync::CancellationToken; -use tracing as log; - -#[derive(Builder)] -pub struct KvRoutedIngress { - #[builder(setter(into))] - pub service_name: String, - - #[builder(setter(into))] - pub worker_id: String, - - pub nats: nats::Client, - pub service_handler: Arc, - pub metrics_rx: watch::Receiver>, - pub cancellation_token: CancellationToken, -} - -/// version of crate -pub const VERSION: &str = env!("CARGO_PKG_VERSION"); - -impl KvRoutedIngress { - pub fn builder() -> KvRoutedIngressBuilder { - KvRoutedIngressBuilder::default() - } - - pub async fn start(self) -> Result<()> { - let worker_id = self.worker_id; - - log::trace!( - worker_id, - "Starting nats service: {}:{}", - self.service_name, - VERSION - ); - - let mut metrics_rx = self.metrics_rx; - let worker_id_clone = worker_id.clone(); - - let service = self - .nats - .client() - .service_builder() - .description("A handy min max service") - .stats_handler(move |name, stats| { - log::debug!( - worker_id = worker_id_clone.as_str(), - "[IN worker?] Stats for service {}: {:?}", - name, - stats - ); - let metrics = metrics_rx.borrow_and_update().clone(); - serde_json::to_value(&*metrics).unwrap() - }) - .start(self.service_name.as_str(), VERSION) - .await - .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?; - - let group = service.group(self.service_name.as_str()); - - log::trace!(worker_id, "Starting endpoint: {}", worker_id); - - // creates an endpoint for the service - let service_endpoint = group - .endpoint(worker_id.clone()) - .await - .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?; - - let push_endpoint = PushEndpoint::builder() - .service_handler(self.service_handler) - .cancellation_token(self.cancellation_token) - .build() - .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?; - - push_endpoint.start(service_endpoint).await - } -} From dab052ca2614fe989d86fb25656236b280f0d180 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Thu, 29 May 2025 23:41:52 -0700 Subject: [PATCH 03/22] make the bunny hoppy --- lib/llm/src/kv_router/protocols.rs | 12 ++---------- lib/llm/src/kv_router/scheduler.rs | 4 ++-- lib/llm/src/kv_router/scoring.rs | 3 +++ 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 237657e58b..5e8d3bacab 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -14,6 +14,7 @@ // limitations under the License. 
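Illustrative sketch for the bound change in the protocols.rs hunk below. DeserializeOwned is serde's name for the higher-ranked for<'de> Deserialize<'de> bound, so the blanket impl keeps the same meaning and only reads more compactly. The helper name here is hypothetical and not from the repository; it just shows the same bound in use.

use dynamo_llm::kv_router::protocols::{RouterEvent, WorkerGeneral};
use serde::de::DeserializeOwned;

/// Decode a router event from a raw transport payload, generic over whichever
/// worker identifier the indexer is keyed by. This mirrors the decoding step
/// the KvRouter event loop performs on incoming payloads.
fn decode_router_event<T>(payload: &[u8]) -> Result<RouterEvent<T>, serde_json::Error>
where
    T: WorkerGeneral + DeserializeOwned,
{
    serde_json::from_slice(payload)
}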
use crate::tokens::Token; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::cmp::Eq; use std::fmt::Debug; @@ -29,16 +30,7 @@ pub trait WorkerGeneral: } impl WorkerGeneral for T where - T: Hash - + Eq - + Debug - + Clone - + Send - + Sync - + Default - + 'static - + Serialize - + for<'de> Deserialize<'de> + T: Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize + DeserializeOwned { } diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 3c1aec464a..593623811e 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -224,7 +224,7 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate worker scores and find max waiting requests for (worker_id, ep) in workers.endpoints.iter() { for dp_rank_maybe in ep.data.iter().map(|metrics| metrics.data_parallel_rank) { - let dp_rank = dp_rank_maybe.unwrap(); + let dp_rank = dp_rank_maybe.unwrap_or(0); if let Some(score) = request.overlap.scores.get(&(*worker_id, dp_rank)) { let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; worker_scores.insert((*worker_id, dp_rank), score); @@ -248,7 +248,7 @@ impl WorkerSelector for DefaultWorkerSelector { for (worker_id, ep) in workers.endpoints.iter() { let worker_id = *worker_id; for fwd_pass_metrics in ep.data.iter() { - let dp_rank = fwd_pass_metrics.data_parallel_rank.unwrap(); + let dp_rank = fwd_pass_metrics.data_parallel_rank.unwrap_or(0); // Get score or default to 0.0 let score = worker_scores diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index 625bb80e85..13f9f90c9f 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -60,6 +60,9 @@ impl ProcessedEndpoints { .flat_map(|endpoint| endpoint.data.iter()) .map(|metrics| metrics.kv_active_blocks as f64) .collect(); + if load_values.len() == 0 { + panic!("No endpoints to process!") + }; let load_avg = load_values.iter().copied().sum::() / load_values.len() as f64; let variance = load_values .iter() From 34e5c5bac0ea8a28228d9815778ef75c5a522d26 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Fri, 30 May 2025 15:25:45 -0700 Subject: [PATCH 04/22] new struct combining worker_id with dp_rank, dirty commit, breaks bindings --- components/metrics/src/bin/mock_worker.rs | 2 +- components/metrics/src/main.rs | 11 ++-- components/router/src/main.rs | 4 +- lib/bindings/python/rust/llm/kv.rs | 22 ++++--- lib/llm/src/kv_router.rs | 33 +++++------ lib/llm/src/kv_router/indexer.rs | 19 +++--- lib/llm/src/kv_router/protocols.rs | 22 ++++--- lib/llm/src/kv_router/publisher.rs | 61 ++++++++++++++------ lib/llm/src/kv_router/scheduler.rs | 70 +++++++++++++---------- 9 files changed, 148 insertions(+), 96 deletions(-) diff --git a/components/metrics/src/bin/mock_worker.rs b/components/metrics/src/bin/mock_worker.rs index c80f88e328..10dd4c946d 100644 --- a/components/metrics/src/bin/mock_worker.rs +++ b/components/metrics/src/bin/mock_worker.rs @@ -89,7 +89,7 @@ async fn mock_event_publisher(namespace: Namespace) { let overlap_blocks = rand::rng().random_range(0..=isl_blocks); let event = KVHitRateEvent { - worker_id, + worker_id_general: worker_id, isl_blocks, overlap_blocks, }; diff --git a/components/metrics/src/main.rs b/components/metrics/src/main.rs index f9a1fac09e..ce78cc94c2 100644 --- a/components/metrics/src/main.rs +++ b/components/metrics/src/main.rs @@ -27,7 +27,7 @@ //! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events //! 
- Overlap Blocks: Cumulative count of blocks that were already in the KV cache use clap::Parser; -use dynamo_llm::kv_router::protocols::{KVHitRateEvent, WorkerId, DpRank}; +use dynamo_llm::kv_router::protocols::{WorkerWithDpRank, KVHitRateEvent, WorkerId}; use dynamo_llm::kv_router::KV_HIT_RATE_SUBJECT; use dynamo_runtime::{ error, logging, @@ -180,15 +180,16 @@ async fn app(runtime: Runtime) -> Result<()> { tracing::debug!("Successfully subscribed to KV hit rate events"); while let Some(msg) = subscriber.next().await { - match serde_json::from_slice::>(&msg.payload) { + match serde_json::from_slice::>(&msg.payload) + { Ok(event) => { // TODO: Lower to debug let cache_hit_pct = (event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0; tracing::debug!( "Received KV hit rate event: worker_id={}, dp_rank={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%", - event.worker_id.0, - event.worker_id.1, + event.worker_id_general.0, + event.worker_id_general.1, event.isl_blocks, event.overlap_blocks, cache_hit_pct @@ -199,7 +200,7 @@ async fn app(runtime: Runtime) -> Result<()> { metrics.update_kv_hit_rate( &config_clone, // TODO: this will not take care of dp ranks - event.worker_id.0, + event.worker_id_general.0, event.isl_blocks, event.overlap_blocks, ); diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 48023616d6..50f26ecccb 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use clap::Parser; use dynamo_llm::kv_router::{ - protocols::{WorkerSelectionResult, WorkerId, DpRank}, + protocols::{WorkerWithDpRank, WorkerSelectionResult}, scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, KvRouter, WorkerSelector, @@ -89,7 +89,7 @@ impl WorkerSelector for CustomWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result, KvSchedulerError> { + ) -> Result, KvSchedulerError> { // customize logic here // F12 into [DefaultWorkerSelector] to see the original logic self.0.select_worker(workers, request, block_size) diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 2d7b3d92b5..f860080fe9 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -22,7 +22,7 @@ use rs::traits::events::EventSubscriber; use tracing; use llm_rs::kv_router::protocols::*; -use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig}; +use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig, KvCacheEventWithDp}; #[pyclass] pub(crate) struct KvRouter { @@ -243,8 +243,11 @@ impl KvEventPublisher { ), }), }; + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, dp_rank: None, + }; - self.inner.publish(event).map_err(to_pyerr) + self.inner.publish(event_with_dp).map_err(to_pyerr) } fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec) -> PyResult<()> { @@ -256,21 +259,24 @@ impl KvEventPublisher { event_id, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), }; + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, dp_rank: None, + }; - self.inner.publish(event).map_err(to_pyerr) + self.inner.publish(event_with_dp).map_err(to_pyerr) } } #[pyclass] #[derive(Clone)] pub(crate) struct OverlapScores { - inner: llm_rs::kv_router::indexer::OverlapScores, + inner: llm_rs::kv_router::indexer::OverlapScores<(WorkerId, DpRank)>, } #[pymethods] impl OverlapScores { 
#[getter] - fn scores(&self) -> HashMap { + fn scores(&self) -> HashMap<(WorkerId, DpRank), u32> { self.inner.scores.clone() } @@ -282,7 +288,7 @@ impl OverlapScores { #[pyclass] pub(crate) struct KvIndexer { - inner: Arc, + inner: Arc>, } #[pymethods] @@ -291,7 +297,7 @@ impl KvIndexer { fn new(component: Component, kv_block_size: usize) -> PyResult { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { - let inner: Arc = + let inner: Arc> = llm_rs::kv_router::indexer::KvIndexer::new( component.inner.drt().runtime().child_token(), kv_block_size, @@ -310,7 +316,7 @@ impl KvIndexer { // should have been made to a trait and implemented here? i.e. AsyncEngine style tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::indexer::RouterEvent = + let event: llm_rs::kv_router::protocols::RouterEvent<(WorkerId, DpRank)> = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("received kv event: {:?}", event); if let Err(e) = kv_events_tx.send(event).await { diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 1744879ea6..72ce4956eb 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -14,6 +14,7 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; +use protocols::{WorkerId, WorkerWithDpRank}; pub mod indexer; pub mod metrics_aggregator; @@ -28,8 +29,7 @@ use crate::{ indexer::{KvIndexer, KvIndexerInterface}, metrics_aggregator::KvMetricsAggregator, protocols::{ - DpRank, LocalBlockHash, RouterEvent, RouterRequest, RouterResponse, WorkerId, - WorkerSelectionResult, + LocalBlockHash, RouterEvent, RouterRequest, RouterResponse, WorkerSelectionResult, }, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, @@ -54,13 +54,13 @@ pub trait WorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result, KvSchedulerError>; + ) -> Result, KvSchedulerError>; } /// A KvRouter only decides which worker you should use. It doesn't send you there. /// TODO: Rename this to indicate it only selects a worker, it does not route. pub struct KvRouter { - indexer: KvIndexer<(WorkerId, DpRank)>, + indexer: KvIndexer, scheduler: KvScheduler, block_size: usize, } @@ -95,7 +95,7 @@ impl KvRouter { tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: RouterEvent<(WorkerId, DpRank)> = + let event: RouterEvent = match serde_json::from_slice(&event.payload) { Ok(event) => event, Err(e) => { @@ -119,11 +119,7 @@ impl KvRouter { } // [TODO] indexer needs to take 'lora_id' as parameter - pub async fn schedule( - &self, - token_ids: &Vec, - _lora_id: u64, - ) -> Result<(WorkerId, DpRank)> { + pub async fn schedule(&self, token_ids: &Vec, _lora_id: u64) -> Result { // Extracting part of the code in KvRouter::generate() for only // the decision making part, routing is done by the caller let isl_tokens = token_ids.len(); @@ -138,7 +134,7 @@ impl KvRouter { /// Give these tokens, find the worker with the best match in it's KV cache. /// Returned overlap amount is in number of blocks. 
- async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<((WorkerId, DpRank), u32)> { + async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(WorkerWithDpRank, u32)> { let isl_tokens = tokens.len(); let block_size = self.block_size; @@ -168,18 +164,21 @@ impl KvRouter { impl AsyncEngine< SingleIn, - ManyOut>>, + ManyOut>>, Error, > for KvRouter { async fn generate( &self, request: SingleIn, - ) -> Result>>> { + ) -> Result>>> { let (request, ctx) = request.into_parts(); - let (worker_id, _) = self.find_best_match(&request.tokens).await?; + let (best_match, _) = self.find_best_match(&request.tokens).await?; - let response = RouterResponse { worker_id }; + // NOTE: this ignores dp routing + let response = RouterResponse { + worker_id_general: best_match.worker_id, + }; let response = Annotated::from_data(response); let stream = stream::iter(vec![response]); Ok(ResponseStream::new(Box::pin(stream), ctx.context())) @@ -218,7 +217,9 @@ impl AsyncEngine, ManyOut>, Er backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); let updated_request = context.map(|_| backend_input); // TODO: this does not do dp routing - self.inner.direct(updated_request, instance_id.0).await + self.inner + .direct(updated_request, instance_id.worker_id) + .await } } } diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs index ad325d8199..f0d0bd7b1f 100644 --- a/lib/llm/src/kv_router/indexer.rs +++ b/lib/llm/src/kv_router/indexer.rs @@ -257,7 +257,7 @@ impl RadixTree { /// /// * `event` - The `RouterEvent` to apply. pub fn apply_event(&mut self, event: RouterEvent) { - let (worker_id, event) = (event.worker_id, event.event); + let (worker_id, event) = (event.worker_id_general, event.event); let (id, op) = (event.event_id, event.data); tracing::trace!(id, "Store operation: {:?}", op); @@ -363,7 +363,7 @@ impl RadixTree { } } - pub fn clear_all_blocks(&mut self, worker: WorkerId) { + pub fn clear_all_blocks(&mut self, worker: T) { // Check if the worker has any blocks to clear if let Some(blocks) = self.lookup.get(&worker) { let blocks_to_clear: Vec<_> = blocks.values().collect(); @@ -845,7 +845,10 @@ impl KvIndexerInterface for KvIndexerSharded { async fn apply_event(&mut self, event: RouterEvent) { #[allow(clippy::map_entry)] - if !self.worker_assignments.contains_key(&event.worker_id) { + if !self + .worker_assignments + .contains_key(&event.worker_id_general) + { // Get the shard with the smallest amount of workers. 
let selected_shard = self .worker_counts @@ -856,11 +859,11 @@ impl KvIndexerInterface for KvIndexerSharded { .0; self.worker_assignments - .insert(event.worker_id.clone(), selected_shard); + .insert(event.worker_id_general.clone(), selected_shard); self.worker_counts[selected_shard] += 1; } - self.event_tx[self.worker_assignments[&event.worker_id]] + self.event_tx[self.worker_assignments[&event.worker_id_general]] .send(event) .await .unwrap(); @@ -924,7 +927,7 @@ mod tests { parent: Option, ) -> RouterEvent { RouterEvent { - worker_id, + worker_id_general: worker_id, event: KvCacheEvent { event_id, data: add_blocks(hashes, parent), @@ -938,7 +941,7 @@ mod tests { hashes: Vec, ) -> RouterEvent { RouterEvent { - worker_id, + worker_id_general: worker_id, event: KvCacheEvent { event_id, data: KvCacheEventData::Removed(KvCacheRemoveData { @@ -1515,7 +1518,7 @@ mod tests { }; let router_event = RouterEvent::new(worker_id, kv_cache_event); - assert_eq!(router_event.worker_id, worker_id); + assert_eq!(router_event.worker_id_general, worker_id); assert_eq!(router_event.event.event_id, 1); if let KvCacheEventData::Stored(store_op) = &router_event.event.data { assert_eq!(store_op.blocks.len(), 1); diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index cbd9a72080..f59af5f05e 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -21,9 +21,14 @@ use std::fmt::Debug; use std::hash::Hash; pub type WorkerId = i64; - pub type DpRank = u32; +#[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize, Default)] +pub struct WorkerWithDpRank { + pub worker_id: WorkerId, + pub dp_rank: Option, +} + pub trait WorkerGeneral: Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize { @@ -41,13 +46,13 @@ pub struct RouterRequest { #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterResponse { - pub worker_id: T, + pub worker_id_general: T, } #[derive(Debug)] pub struct WorkerSelectionResult { /// The worker id of the selected worker - pub worker_id: T, + pub worker_id_general: T, /// The total number of blocks required to prefill the request pub required_blocks: u64, @@ -157,7 +162,7 @@ pub struct KvCacheRemoveData { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KVHitRateEvent { - pub worker_id: T, + pub worker_id_general: T, pub isl_blocks: usize, pub overlap_blocks: usize, } @@ -166,7 +171,7 @@ pub struct KVHitRateEvent { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RouterEvent { /// The ID of the worker emitting the event. - pub worker_id: T, + pub worker_id_general: T, /// The cache event associated with the worker. pub event: KvCacheEvent, } @@ -182,8 +187,11 @@ impl RouterEvent { /// ### Returns /// /// A new `RouterEvent`. - pub fn new(worker_id: T, event: KvCacheEvent) -> Self { - Self { worker_id, event } + pub fn new(worker_id_general: T, event: KvCacheEvent) -> Self { + Self { + worker_id_general, + event, + } } } diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 1cb7cef902..3029c0308f 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -43,6 +43,13 @@ use zeromq::{Socket, SocketRecv, SubSocket}; // KV Event Publishers ----------------------------------------------------- // ------------------------------------------------------------------------- +/// Represents a single cache event with an ID and associated data. 
+#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct KvCacheEventWithDp { + pub kv_cache_event: KvCacheEvent, + pub dp_rank: Option, +} + /// Configure the source of KV events. /// Currently, only ZMQ is supported. pub enum KvEventSourceConfig { @@ -63,7 +70,7 @@ impl KvEventSource { kv_block_size: usize, source_config: KvEventSourceConfig, cancellation_token: CancellationToken, - tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, ) -> Result { match source_config { KvEventSourceConfig::Zmq { endpoint, topic } => { @@ -95,7 +102,6 @@ impl KvEventSource { /// A publisher of KV events. pub struct KvEventPublisher { - /// The size of the KV block. kv_block_size: usize, /// The source of KV events. /// Can be `None` if all events provided through [`KvEventPublisher::publish`]. @@ -103,19 +109,19 @@ pub struct KvEventPublisher { /// The cancellation token. cancellation_token: CancellationToken, /// The channel to send events to. - tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, } impl KvEventPublisher { pub fn new( component: Component, - worker_id: i64, + worker_id: WorkerId, kv_block_size: usize, source_config: Option, ) -> Result { let cancellation_token = CancellationToken::new(); - let (tx, rx) = mpsc::unbounded_channel::(); + let (tx, rx) = mpsc::unbounded_channel::(); // Create our event source (if any) let mut source = None; @@ -148,7 +154,10 @@ impl KvEventPublisher { }) } - pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError> { + pub fn publish( + &self, + event: KvCacheEventWithDp, + ) -> Result<(), mpsc::error::SendError> { tracing::trace!("Publish event: {:?}", event); self.tx.send(event) } @@ -176,9 +185,9 @@ impl Drop for KvEventPublisher { async fn start_event_processor( publisher: P, - worker_id: i64, + worker_id: WorkerId, cancellation_token: CancellationToken, - mut rx: mpsc::UnboundedReceiver, + mut rx: mpsc::UnboundedReceiver, ) { loop { tokio::select! { @@ -186,14 +195,17 @@ async fn start_event_processor( tracing::info!("KV Event source received cancellation signal"); break; } - event = rx.recv() => { - let Some(event) = event else { + maybe_data = rx.recv() => { + let Some(data) = maybe_data else { tracing::debug!("Event processor channel closed."); break; }; // Encapsulate in a router event and publish. - let router_event = RouterEvent::new(worker_id, event); + let event = data.kv_cache_event; + let dp_rank = data.dp_rank.unwrap_or(0); + + let router_event = RouterEvent::new((worker_id, dp_rank), event); if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await { tracing::error!("Failed to publish event: {}", e); } @@ -219,7 +231,7 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 { async fn start_zmq_listener( zmq_endpoint: String, zmq_topic: String, - tx: mpsc::UnboundedSender, + tx: mpsc::UnboundedSender, cancellation_token: CancellationToken, kv_block_size: usize, ) { @@ -315,9 +327,10 @@ async fn start_zmq_listener( }; // For each of our events, convert them to [`KvCacheEvent`] and send to the event_processor. 
+ let dp_rank = batch.dp_rank; for raw_event in batch.events.into_iter() { - let event = convert_event(raw_event, seq, kv_block_size, &warning_count); - if tx.send(event).is_err() { + let kv_cache_event = convert_event(raw_event, seq, kv_block_size, &warning_count); + if tx.send(KvCacheEventWithDp { kv_cache_event, dp_rank }).is_err() { tracing::warn!("Failed to send message to channel - receiver dropped"); return; } @@ -436,6 +449,7 @@ pub fn create_stored_blocks( struct KvEventBatch { ts: f64, events: Vec, + dp_rank: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -703,15 +717,20 @@ mod tests_startup_helpers { async fn test_start_event_processor() { let (component, published) = MockComponent::new(); - let event = KvCacheEvent { + let kv_cache_event = KvCacheEvent { event_id: 1, data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)], }), }; + let event = KvCacheEventWithDp { + kv_cache_event, + dp_rank: None, + }; + let token = CancellationToken::new(); - let (tx, rx) = mpsc::unbounded_channel::(); + let (tx, rx) = mpsc::unbounded_channel::(); tx.send(event).unwrap(); drop(tx); @@ -735,7 +754,7 @@ mod tests_startup_helpers { #[tokio::test] async fn test_start_zmq_listener_pushes_to_channel() { // Prepare channel that listener should fill - let (tx, mut rx) = mpsc::unbounded_channel::(); + let (tx, mut rx) = mpsc::unbounded_channel::(); // ZMQ TCP endpoint using localhost with fixed port let endpoint = "tcp://127.0.0.1:15555"; @@ -768,7 +787,11 @@ mod tests_startup_helpers { lora_id: None, }]; - let batch = KvEventBatch { ts: 0.0, events }; + let batch = KvEventBatch { + ts: 0.0, + events, + dp_rank: None, + }; let payload = Bytes::from(rmps::to_vec(&batch).unwrap()); @@ -793,7 +816,7 @@ mod tests_startup_helpers { let KvCacheEventData::Stored(KvCacheStoreData { parent_hash, blocks, - }) = event.data + }) = event.kv_cache_event.data else { panic!("expected KvCacheStoreData"); }; diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 593623811e..da8a439be4 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -24,7 +24,7 @@ pub use crate::kv_router::protocols::ForwardPassMetrics; use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::KV_HIT_RATE_SUBJECT; -use super::protocols::{DpRank, KVHitRateEvent, WorkerId, WorkerSelectionResult}; +use super::protocols::{KVHitRateEvent, WorkerSelectionResult, WorkerWithDpRank}; use super::WorkerSelector; #[derive(Debug, thiserror::Error)] @@ -41,12 +41,12 @@ pub enum KvSchedulerError { pub struct SchedulingRequest { pub isl_tokens: usize, - pub overlap: OverlapScores<(WorkerId, DpRank)>, - resp_tx: tokio::sync::oneshot::Sender<(WorkerId, DpRank)>, + pub overlap: OverlapScores, + resp_tx: tokio::sync::oneshot::Sender, } impl SchedulingRequest { - pub fn respond(self, identifier: (WorkerId, DpRank)) { + pub fn respond(self, identifier: WorkerWithDpRank) { if self.resp_tx.send(identifier).is_err() { tracing::trace!("failed to send response to requestor"); } @@ -69,7 +69,7 @@ impl KvScheduler { let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone(); let (event_tx, event_rx) = - tokio::sync::mpsc::unbounded_channel::>(); + tokio::sync::mpsc::unbounded_channel::>(); tokio::spawn(async move { let mut event_rx = event_rx; while let Some(event) = event_rx.recv().await { @@ -147,9 +147,9 @@ impl KvScheduler { pub async fn schedule( &self, - overlap: 
OverlapScores<(WorkerId, DpRank)>, + overlap: OverlapScores, isl_tokens: usize, - ) -> Result<(WorkerId, DpRank), KvSchedulerError> { + ) -> Result { let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { isl_tokens, @@ -170,15 +170,15 @@ impl KvScheduler { // This becomes the driver function that handles the selection result pub fn process_worker_selection( workers: &mut ProcessedEndpoints, - selection: WorkerSelectionResult<(WorkerId, DpRank)>, - event_tx: &tokio::sync::mpsc::UnboundedSender>, -) -> (WorkerId, DpRank) { + selection: WorkerSelectionResult, + event_tx: &tokio::sync::mpsc::UnboundedSender>, +) -> WorkerWithDpRank { let worker = workers .endpoints - .get_mut(&selection.worker_id.0) + .get_mut(&selection.worker_id_general.worker_id) .expect("worker not found"); - let dp_rank = selection.worker_id.1 as usize; + let dp_rank = selection.worker_id_general.dp_rank.unwrap_or(0) as usize; // Update worker state predictively // Will be overwritten on next polling of metrics @@ -191,14 +191,14 @@ pub fn process_worker_selection( // Emit event if let Err(e) = event_tx.send(KVHitRateEvent { - worker_id: selection.worker_id, + worker_id_general: selection.worker_id_general, isl_blocks: selection.required_blocks as usize, overlap_blocks: selection.overlap_blocks, }) { tracing::warn!("Failed to send KV hit rate event: {:?}", e); } - (selection.worker_id.0, selection.worker_id.1) + selection.worker_id_general } // Default implementation matching the Python _cost_function @@ -211,7 +211,7 @@ impl WorkerSelector for DefaultWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result, KvSchedulerError> { + ) -> Result, KvSchedulerError> { assert!(request.isl_tokens > 0); if workers.endpoints.is_empty() { @@ -223,16 +223,19 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate worker scores and find max waiting requests for (worker_id, ep) in workers.endpoints.iter() { - for dp_rank_maybe in ep.data.iter().map(|metrics| metrics.data_parallel_rank) { - let dp_rank = dp_rank_maybe.unwrap_or(0); - if let Some(score) = request.overlap.scores.get(&(*worker_id, dp_rank)) { + for dp_rank in ep.data.iter().map(|metrics| metrics.data_parallel_rank) { + let worker_with_dp_rank = WorkerWithDpRank { + worker_id: *worker_id, + dp_rank: dp_rank, + }; + if let Some(score) = request.overlap.scores.get(&worker_with_dp_rank) { let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; - worker_scores.insert((*worker_id, dp_rank), score); + worker_scores.insert(worker_with_dp_rank, score); } // Track max waiting requests max_waiting = f64::max( max_waiting, - ep.data[dp_rank as usize].num_requests_waiting as f64, + ep.data[dp_rank.unwrap_or(0) as usize].num_requests_waiting as f64, ); } } @@ -248,18 +251,20 @@ impl WorkerSelector for DefaultWorkerSelector { for (worker_id, ep) in workers.endpoints.iter() { let worker_id = *worker_id; for fwd_pass_metrics in ep.data.iter() { - let dp_rank = fwd_pass_metrics.data_parallel_rank.unwrap_or(0); + let dp_rank = fwd_pass_metrics.data_parallel_rank; + let worker_with_dp_rank = WorkerWithDpRank { worker_id, dp_rank }; // Get score or default to 0.0 let score = worker_scores - .get(&(worker_id, dp_rank)) + .get(&worker_with_dp_rank) .copied() .unwrap_or(0.0); // Calculate normalized metrics - let gpu_cache_usage = ep.data[dp_rank as usize].gpu_cache_usage_perc as f64; + let gpu_cache_usage = + ep.data[dp_rank.unwrap_or(0) as usize].gpu_cache_usage_perc as f64; let 
normalized_waiting = if max_waiting > 0.0 { - ep.data[dp_rank as usize].num_requests_waiting as f64 / max_waiting + ep.data[dp_rank.unwrap_or(0) as usize].num_requests_waiting as f64 / max_waiting } else { 0.0 }; @@ -276,10 +281,10 @@ impl WorkerSelector for DefaultWorkerSelector { Some(std::cmp::Ordering::Greater) => { best_logit = logit; best_workers.clear(); - best_workers.push((worker_id, dp_rank)); + best_workers.push(worker_with_dp_rank); } Some(std::cmp::Ordering::Equal) => { - best_workers.push((worker_id, dp_rank)); + best_workers.push(worker_with_dp_rank); } _ => {} } @@ -293,7 +298,7 @@ impl WorkerSelector for DefaultWorkerSelector { tracing::debug!("best worker logit is 0"); } - let worker_id = if best_workers.len() == 1 { + let best_worker_and_dp = if best_workers.len() == 1 { best_workers[0] } else { // Randomly select from best workers @@ -302,14 +307,19 @@ impl WorkerSelector for DefaultWorkerSelector { }; // Lower to trace level eventually. Nice to see KV routing working for now. - tracing::debug!("Selected worker: {worker_id:?}, logit: {best_logit:.3}"); + tracing::debug!("Selected worker: {best_worker_and_dp:?}, logit: {best_logit:.3}"); // Log selection metrics let total_blocks = std::cmp::max(request.isl_tokens / block_size, 1) as u64; - let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize; + let overlap_blocks = request + .overlap + .scores + .get(&best_worker_and_dp) + .copied() + .unwrap_or(0) as usize; Ok(WorkerSelectionResult { - worker_id, + worker_id_general: best_worker_and_dp, required_blocks: total_blocks, overlap_blocks, }) From 2cef74c6fe608128f68504b22a69b475477a1dfe Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Fri, 30 May 2025 16:11:15 -0700 Subject: [PATCH 05/22] binding works --- components/metrics/src/main.rs | 11 ++--- components/router/src/main.rs | 2 +- lib/bindings/c/src/lib.rs | 13 +++++- lib/bindings/python/rust/llm/kv.rs | 73 ++++++++++++++++++++---------- lib/llm/src/kv_router.rs | 7 ++- lib/llm/src/kv_router/scheduler.rs | 2 +- lib/llm/src/kv_router/scoring.rs | 2 +- 7 files changed, 72 insertions(+), 38 deletions(-) diff --git a/components/metrics/src/main.rs b/components/metrics/src/main.rs index ce78cc94c2..ae88684666 100644 --- a/components/metrics/src/main.rs +++ b/components/metrics/src/main.rs @@ -27,7 +27,7 @@ //! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events //! 
- Overlap Blocks: Cumulative count of blocks that were already in the KV cache use clap::Parser; -use dynamo_llm::kv_router::protocols::{WorkerWithDpRank, KVHitRateEvent, WorkerId}; +use dynamo_llm::kv_router::protocols::{KVHitRateEvent, WorkerWithDpRank}; use dynamo_llm::kv_router::KV_HIT_RATE_SUBJECT; use dynamo_runtime::{ error, logging, @@ -180,16 +180,15 @@ async fn app(runtime: Runtime) -> Result<()> { tracing::debug!("Successfully subscribed to KV hit rate events"); while let Some(msg) = subscriber.next().await { - match serde_json::from_slice::>(&msg.payload) - { + match serde_json::from_slice::>(&msg.payload) { Ok(event) => { // TODO: Lower to debug let cache_hit_pct = (event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0; tracing::debug!( "Received KV hit rate event: worker_id={}, dp_rank={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%", - event.worker_id_general.0, - event.worker_id_general.1, + event.worker_id_general.worker_id, + event.worker_id_general.dp_rank.unwrap_or(0), event.isl_blocks, event.overlap_blocks, cache_hit_pct @@ -200,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> { metrics.update_kv_hit_rate( &config_clone, // TODO: this will not take care of dp ranks - event.worker_id_general.0, + event.worker_id_general.worker_id, event.isl_blocks, event.overlap_blocks, ); diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 50f26ecccb..7caeb56b0d 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use clap::Parser; use dynamo_llm::kv_router::{ - protocols::{WorkerWithDpRank, WorkerSelectionResult}, + protocols::{WorkerSelectionResult, WorkerWithDpRank}, scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, KvRouter, WorkerSelector, diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs index 1c50f4aa8e..b27dd399d3 100644 --- a/lib/bindings/c/src/lib.rs +++ b/lib/bindings/c/src/lib.rs @@ -14,6 +14,7 @@ // limitations under the License. 
use async_once_cell::OnceCell as AsyncOnceCell; +use dynamo_llm::kv_router::publisher::KvCacheEventWithDp; use libc::c_char; use once_cell::sync::OnceCell; use std::ffi::CStr; @@ -284,7 +285,11 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( }; let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size()); - match publisher.publish(event) { + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank: None, + }; + match publisher.publish(event_with_dp) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Error publishing stored kv event {:?}", e); @@ -301,7 +306,11 @@ pub extern "C" fn dynamo_kv_event_publish_removed( ) -> DynamoLlmResult { let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks); - match publisher.publish(event) { + let event_with_dp = KvCacheEventWithDp { + kv_cache_event: event, + dp_rank: None, + }; + match publisher.publish(event_with_dp) { Ok(_) => DynamoLlmResult::OK, Err(e) => { eprintln!("Error publishing removed kv event {:?}", e); diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index f860080fe9..e68ed61977 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -24,6 +24,24 @@ use tracing; use llm_rs::kv_router::protocols::*; use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig, KvCacheEventWithDp}; +#[pyclass] +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct WorkerWithDpRank { + #[pyo3(get, set)] + pub worker_id: i64, + #[pyo3(get, set)] + pub dp_rank: Option, +} + +impl From for WorkerWithDpRank { + fn from(value: llm_rs::kv_router::protocols::WorkerWithDpRank) -> Self { + Self { + worker_id: value.worker_id, + dp_rank: value.dp_rank, + } + } +} + #[pyclass] pub(crate) struct KvRouter { inner: Arc, @@ -57,7 +75,7 @@ impl KvRouter { .schedule(&token_ids, lora_id) .await .map_err(to_pyerr)?; - Ok(worker_id) + Ok(WorkerWithDpRank::from(worker_id)) }) } } @@ -107,7 +125,7 @@ impl WorkerMetricsPublisher { num_requests_waiting: u64, gpu_cache_usage_perc: f32, gpu_prefix_cache_hit_rate: f32, - data_parallel_rank: u32, + data_parallel_rank: DpRank, ) -> PyResult<()> { self.inner .publish( @@ -218,7 +236,7 @@ impl KvEventPublisher { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None))] + #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, dp_rank=None))] fn publish_stored( &mut self, _py: Python, @@ -228,6 +246,7 @@ impl KvEventPublisher { block_hashes: Vec, lora_id: u64, parent_hash: Option, + dp_rank: Option, ) -> PyResult<()> { let event = KvCacheEvent { event_id, @@ -244,13 +263,14 @@ impl KvEventPublisher { }), }; let event_with_dp = KvCacheEventWithDp { - kv_cache_event: event, dp_rank: None, + kv_cache_event: event, dp_rank, }; self.inner.publish(event_with_dp).map_err(to_pyerr) } - fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec) -> PyResult<()> { + #[pyo3(signature = (event_id, block_hashes, dp_rank=None))] + fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec, dp_rank: Option) -> PyResult<()> { let block_hashes: Vec = block_hashes .iter() .map(|&h| ExternalSequenceBlockHash::from(h)) @@ -260,7 +280,7 @@ impl KvEventPublisher { data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), }; let event_with_dp = 
KvCacheEventWithDp { - kv_cache_event: event, dp_rank: None, + kv_cache_event: event, dp_rank, }; self.inner.publish(event_with_dp).map_err(to_pyerr) @@ -270,14 +290,16 @@ impl KvEventPublisher { #[pyclass] #[derive(Clone)] pub(crate) struct OverlapScores { - inner: llm_rs::kv_router::indexer::OverlapScores<(WorkerId, DpRank)>, + inner: llm_rs::kv_router::indexer::OverlapScores, } #[pymethods] impl OverlapScores { #[getter] - fn scores(&self) -> HashMap<(WorkerId, DpRank), u32> { - self.inner.scores.clone() + fn scores(&self) -> HashMap { + self.inner.scores.iter() + .map(|(k, v)| (WorkerWithDpRank::from(*k), *v)) + .collect() } #[getter] @@ -288,7 +310,7 @@ impl OverlapScores { #[pyclass] pub(crate) struct KvIndexer { - inner: Arc>, + inner: Arc>, } #[pymethods] @@ -297,7 +319,7 @@ impl KvIndexer { fn new(component: Component, kv_block_size: usize) -> PyResult { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { - let inner: Arc> = + let inner: Arc> = llm_rs::kv_router::indexer::KvIndexer::new( component.inner.drt().runtime().child_token(), kv_block_size, @@ -316,7 +338,7 @@ impl KvIndexer { // should have been made to a trait and implemented here? i.e. AsyncEngine style tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::protocols::RouterEvent<(WorkerId, DpRank)> = + let event: llm_rs::kv_router::protocols::RouterEvent = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("received kv event: {:?}", event); if let Err(e) = kv_events_tx.send(event).await { @@ -360,6 +382,8 @@ pub(crate) struct EndpointKvMetrics { #[pyo3(get, set)] pub worker_id: i64, #[pyo3(get, set)] + pub dp_rank: Option, + #[pyo3(get, set)] pub request_active_slots: u64, #[pyo3(get, set)] pub request_total_slots: u64, @@ -413,15 +437,18 @@ impl KvMetricsAggregator { let endpoint_kv_metrics = endpoints .endpoints .iter() - .map(|(worker_id, x)| EndpointKvMetrics { - worker_id: *worker_id, - request_active_slots: x.data.request_active_slots, - request_total_slots: x.data.request_total_slots, - kv_active_blocks: x.data.kv_active_blocks, - kv_total_blocks: x.data.kv_total_blocks, - num_requests_waiting: x.data.num_requests_waiting, - gpu_cache_usage_perc: x.data.gpu_cache_usage_perc, - gpu_prefix_cache_hit_rate: x.data.gpu_prefix_cache_hit_rate, + .flat_map(|(worker_id, x)| { + x.data.iter().map(move |data_item| EndpointKvMetrics { + worker_id: *worker_id, + dp_rank: data_item.data_parallel_rank, + request_active_slots: data_item.request_active_slots, + request_total_slots: data_item.request_total_slots, + kv_active_blocks: data_item.kv_active_blocks, + kv_total_blocks: data_item.kv_total_blocks, + num_requests_waiting: data_item.num_requests_waiting, + gpu_cache_usage_perc: data_item.gpu_cache_usage_perc, + gpu_prefix_cache_hit_rate: data_item.gpu_prefix_cache_hit_rate, + }) }) .collect(); pyo3_async_runtimes::tokio::future_into_py(py, async move { @@ -436,7 +463,7 @@ impl KvMetricsAggregator { #[pyclass] pub(crate) struct KvRecorder { - inner: Arc, + inner: Arc>, } #[pymethods] @@ -487,7 +514,7 @@ impl KvRecorder { // Spawn a task to forward events to the recorder tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::indexer::RouterEvent = + let event: llm_rs::kv_router::protocols::RouterEvent = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("KvRecorder received kv event: {:?}", event); if let Err(e) = event_tx.send(event).await { diff --git 
a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 72ce4956eb..998f04a572 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -14,7 +14,7 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; -use protocols::{WorkerId, WorkerWithDpRank}; +use protocols::WorkerWithDpRank; pub mod indexer; pub mod metrics_aggregator; @@ -171,13 +171,12 @@ impl async fn generate( &self, request: SingleIn, - ) -> Result>>> { + ) -> Result>>> { let (request, ctx) = request.into_parts(); let (best_match, _) = self.find_best_match(&request.tokens).await?; - // NOTE: this ignores dp routing let response = RouterResponse { - worker_id_general: best_match.worker_id, + worker_id_general: best_match, }; let response = Annotated::from_data(response); let stream = stream::iter(vec![response]); diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index da8a439be4..2bb8d3ab3c 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -226,7 +226,7 @@ impl WorkerSelector for DefaultWorkerSelector { for dp_rank in ep.data.iter().map(|metrics| metrics.data_parallel_rank) { let worker_with_dp_rank = WorkerWithDpRank { worker_id: *worker_id, - dp_rank: dp_rank, + dp_rank, }; if let Some(score) = request.overlap.scores.get(&worker_with_dp_rank) { let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index 13f9f90c9f..d1dcec3980 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -60,7 +60,7 @@ impl ProcessedEndpoints { .flat_map(|endpoint| endpoint.data.iter()) .map(|metrics| metrics.kv_active_blocks as f64) .collect(); - if load_values.len() == 0 { + if load_values.is_empty() { panic!("No endpoints to process!") }; let load_avg = load_values.iter().copied().sum::() / load_values.len() as f64; From 10d33260ba9460252f04cc294cb231063bcbe8a1 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Fri, 30 May 2025 16:18:46 -0700 Subject: [PATCH 06/22] dummy c binding note --- lib/bindings/c/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs index b27dd399d3..742815230d 100644 --- a/lib/bindings/c/src/lib.rs +++ b/lib/bindings/c/src/lib.rs @@ -285,6 +285,7 @@ pub unsafe extern "C" fn dynamo_kv_event_publish_stored( }; let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_stored_from_parts(kv_params, publisher.kv_block_size()); + // NOTE: dummy dp_rank for now let event_with_dp = KvCacheEventWithDp { kv_cache_event: event, dp_rank: None, @@ -306,6 +307,7 @@ pub extern "C" fn dynamo_kv_event_publish_removed( ) -> DynamoLlmResult { let publisher = KV_PUB.get().unwrap(); let event = kv_event_create_removed_from_parts(event_id, block_ids, num_blocks); + // NOTE: dummy dp_rank for now let event_with_dp = KvCacheEventWithDp { kv_cache_event: event, dp_rank: None, From 4483c68eab0bc773ffc1a99cd62a6ac810441cde Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Fri, 30 May 2025 16:20:35 -0700 Subject: [PATCH 07/22] add_class WorkerWithDpRank --- lib/bindings/python/rust/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 39cc1ea46e..200c56136d 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -61,6 +61,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; 
m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; From 263c12d74c32c2bbfd68db85031012febb19a682 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Fri, 30 May 2025 21:28:29 -0700 Subject: [PATCH 08/22] renames + comments + fmt --- components/metrics/src/bin/mock_worker.rs | 2 +- components/metrics/src/lib.rs | 2 + components/metrics/src/main.rs | 10 ++-- components/router/src/main.rs | 4 +- lib/bindings/python/rust/lib.rs | 2 +- lib/bindings/python/rust/llm/kv.rs | 61 ++++++++++++++--------- lib/llm/src/kv_router.rs | 43 +++++++--------- lib/llm/src/kv_router/indexer.rs | 17 +++---- lib/llm/src/kv_router/protocols.rs | 17 +++---- lib/llm/src/kv_router/scheduler.rs | 36 ++++++------- 10 files changed, 98 insertions(+), 96 deletions(-) diff --git a/components/metrics/src/bin/mock_worker.rs b/components/metrics/src/bin/mock_worker.rs index 10dd4c946d..a2238ea5b1 100644 --- a/components/metrics/src/bin/mock_worker.rs +++ b/components/metrics/src/bin/mock_worker.rs @@ -89,7 +89,7 @@ async fn mock_event_publisher(namespace: Namespace) { let overlap_blocks = rand::rng().random_range(0..=isl_blocks); let event = KVHitRateEvent { - worker_id_general: worker_id, + worker: worker_id, isl_blocks, overlap_blocks, }; diff --git a/components/metrics/src/lib.rs b/components/metrics/src/lib.rs index 68c026e081..4de77617ed 100644 --- a/components/metrics/src/lib.rs +++ b/components/metrics/src/lib.rs @@ -450,6 +450,8 @@ impl PrometheusMetrics { let worker_id = worker_id.to_string(); let metrics = endpoint.data.clone(); + // NOTE: using metrics[0] just to get the first dp_rank for now + // to not change the existing behavior self.set_worker_gauge( &self.kv_blocks_active, config, diff --git a/components/metrics/src/main.rs b/components/metrics/src/main.rs index ae88684666..c6a0996280 100644 --- a/components/metrics/src/main.rs +++ b/components/metrics/src/main.rs @@ -27,7 +27,7 @@ //! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events //! 
- Overlap Blocks: Cumulative count of blocks that were already in the KV cache use clap::Parser; -use dynamo_llm::kv_router::protocols::{KVHitRateEvent, WorkerWithDpRank}; +use dynamo_llm::kv_router::protocols::{KVHitRateEvent, WorkerDp}; use dynamo_llm::kv_router::KV_HIT_RATE_SUBJECT; use dynamo_runtime::{ error, logging, @@ -180,15 +180,15 @@ async fn app(runtime: Runtime) -> Result<()> { tracing::debug!("Successfully subscribed to KV hit rate events"); while let Some(msg) = subscriber.next().await { - match serde_json::from_slice::>(&msg.payload) { + match serde_json::from_slice::>(&msg.payload) { Ok(event) => { // TODO: Lower to debug let cache_hit_pct = (event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0; tracing::debug!( "Received KV hit rate event: worker_id={}, dp_rank={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%", - event.worker_id_general.worker_id, - event.worker_id_general.dp_rank.unwrap_or(0), + event.worker.worker_id, + event.worker.dp_rank.unwrap_or(0), event.isl_blocks, event.overlap_blocks, cache_hit_pct @@ -199,7 +199,7 @@ async fn app(runtime: Runtime) -> Result<()> { metrics.update_kv_hit_rate( &config_clone, // TODO: this will not take care of dp ranks - event.worker_id_general.worker_id, + event.worker.worker_id, event.isl_blocks, event.overlap_blocks, ); diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 7caeb56b0d..e8e7d08c72 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use clap::Parser; use dynamo_llm::kv_router::{ - protocols::{WorkerSelectionResult, WorkerWithDpRank}, + protocols::{WorkerDp, WorkerSelectionResult}, scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, KvRouter, WorkerSelector, @@ -89,7 +89,7 @@ impl WorkerSelector for CustomWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result, KvSchedulerError> { + ) -> Result, KvSchedulerError> { // customize logic here // F12 into [DefaultWorkerSelector] to see the original logic self.0.select_worker(workers, request, block_size) diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 200c56136d..da04c5b5bf 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -61,7 +61,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index e68ed61977..052317098e 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -22,19 +22,19 @@ use rs::traits::events::EventSubscriber; use tracing; use llm_rs::kv_router::protocols::*; -use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig, KvCacheEventWithDp}; +use llm_rs::kv_router::publisher::{create_stored_blocks, KvCacheEventWithDp, KvEventSourceConfig}; #[pyclass] #[derive(Clone, PartialEq, Eq, Hash)] -pub struct WorkerWithDpRank { +pub struct WorkerDp { #[pyo3(get, set)] pub worker_id: i64, #[pyo3(get, set)] pub dp_rank: Option, } -impl From for WorkerWithDpRank { - fn from(value: llm_rs::kv_router::protocols::WorkerWithDpRank) -> Self { +impl From for WorkerDp { + fn from(value: llm_rs::kv_router::protocols::WorkerDp) -> Self { Self { worker_id: value.worker_id, dp_rank: value.dp_rank, @@ -75,7 +75,7 
@@ impl KvRouter { .schedule(&token_ids, lora_id) .await .map_err(to_pyerr)?; - Ok(WorkerWithDpRank::from(worker_id)) + Ok(WorkerDp::from(worker_id)) }) } } @@ -263,14 +263,21 @@ impl KvEventPublisher { }), }; let event_with_dp = KvCacheEventWithDp { - kv_cache_event: event, dp_rank, + kv_cache_event: event, + dp_rank, }; self.inner.publish(event_with_dp).map_err(to_pyerr) } #[pyo3(signature = (event_id, block_hashes, dp_rank=None))] - fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec, dp_rank: Option) -> PyResult<()> { + fn publish_removed( + &self, + _py: Python, + event_id: u64, + block_hashes: Vec, + dp_rank: Option, + ) -> PyResult<()> { let block_hashes: Vec = block_hashes .iter() .map(|&h| ExternalSequenceBlockHash::from(h)) @@ -280,7 +287,8 @@ impl KvEventPublisher { data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }), }; let event_with_dp = KvCacheEventWithDp { - kv_cache_event: event, dp_rank, + kv_cache_event: event, + dp_rank, }; self.inner.publish(event_with_dp).map_err(to_pyerr) @@ -290,15 +298,17 @@ impl KvEventPublisher { #[pyclass] #[derive(Clone)] pub(crate) struct OverlapScores { - inner: llm_rs::kv_router::indexer::OverlapScores, + inner: llm_rs::kv_router::indexer::OverlapScores, } #[pymethods] impl OverlapScores { #[getter] - fn scores(&self) -> HashMap { - self.inner.scores.iter() - .map(|(k, v)| (WorkerWithDpRank::from(*k), *v)) + fn scores(&self) -> HashMap { + self.inner + .scores + .iter() + .map(|(k, v)| (WorkerDp::from(*k), *v)) .collect() } @@ -310,7 +320,7 @@ impl OverlapScores { #[pyclass] pub(crate) struct KvIndexer { - inner: Arc>, + inner: Arc>, } #[pymethods] @@ -319,12 +329,13 @@ impl KvIndexer { fn new(component: Component, kv_block_size: usize) -> PyResult { let runtime = pyo3_async_runtimes::tokio::get_runtime(); runtime.block_on(async { - let inner: Arc> = - llm_rs::kv_router::indexer::KvIndexer::new( - component.inner.drt().runtime().child_token(), - kv_block_size, - ) - .into(); + let inner: Arc< + llm_rs::kv_router::indexer::KvIndexer, + > = llm_rs::kv_router::indexer::KvIndexer::new( + component.inner.drt().runtime().child_token(), + kv_block_size, + ) + .into(); // [gluo TODO] try subscribe_with_type::, // error checking below will be different. let mut kv_events_rx = component @@ -338,8 +349,9 @@ impl KvIndexer { // should have been made to a trait and implemented here? i.e. 
AsyncEngine style tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::protocols::RouterEvent = - serde_json::from_slice(&event.payload).unwrap(); + let event: llm_rs::kv_router::protocols::RouterEvent< + llm_rs::kv_router::protocols::WorkerDp, + > = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("received kv event: {:?}", event); if let Err(e) = kv_events_tx.send(event).await { tracing::trace!( @@ -463,7 +475,7 @@ impl KvMetricsAggregator { #[pyclass] pub(crate) struct KvRecorder { - inner: Arc>, + inner: Arc>, } #[pymethods] @@ -514,8 +526,9 @@ impl KvRecorder { // Spawn a task to forward events to the recorder tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: llm_rs::kv_router::protocols::RouterEvent = - serde_json::from_slice(&event.payload).unwrap(); + let event: llm_rs::kv_router::protocols::RouterEvent< + llm_rs::kv_router::protocols::WorkerDp, + > = serde_json::from_slice(&event.payload).unwrap(); tracing::debug!("KvRecorder received kv event: {:?}", event); if let Err(e) = event_tx.send(event).await { tracing::trace!( diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 998f04a572..3b49cee3ae 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -14,7 +14,7 @@ use dynamo_runtime::{ protocols::annotated::Annotated, }; use futures::stream::{self, StreamExt}; -use protocols::WorkerWithDpRank; +use protocols::WorkerDp; pub mod indexer; pub mod metrics_aggregator; @@ -54,13 +54,13 @@ pub trait WorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result, KvSchedulerError>; + ) -> Result, KvSchedulerError>; } /// A KvRouter only decides which worker you should use. It doesn't send you there. /// TODO: Rename this to indicate it only selects a worker, it does not route. pub struct KvRouter { - indexer: KvIndexer, + indexer: KvIndexer, scheduler: KvScheduler, block_size: usize, } @@ -95,16 +95,15 @@ impl KvRouter { tokio::spawn(async move { while let Some(event) = kv_events_rx.next().await { - let event: RouterEvent = - match serde_json::from_slice(&event.payload) { - Ok(event) => event, - Err(e) => { - tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); - // Choosing warn and continue to process other events from other workers - // A bad event likely signals a problem with a worker, but potentially other workers are still healthy - continue; - } - }; + let event: RouterEvent = match serde_json::from_slice(&event.payload) { + Ok(event) => event, + Err(e) => { + tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); + // Choosing warn and continue to process other events from other workers + // A bad event likely signals a problem with a worker, but potentially other workers are still healthy + continue; + } + }; if let Err(e) = kv_events_tx.send(event).await { tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e); } @@ -119,7 +118,7 @@ impl KvRouter { } // [TODO] indexer needs to take 'lora_id' as parameter - pub async fn schedule(&self, token_ids: &Vec, _lora_id: u64) -> Result { + pub async fn schedule(&self, token_ids: &Vec, _lora_id: u64) -> Result { // Extracting part of the code in KvRouter::generate() for only // the decision making part, routing is done by the caller let isl_tokens = token_ids.len(); @@ -134,7 +133,7 @@ impl KvRouter { /// Give these tokens, find the worker with the best match in it's KV cache. 
/// Returned overlap amount is in number of blocks. - async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(WorkerWithDpRank, u32)> { + async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<(WorkerDp, u32)> { let isl_tokens = tokens.len(); let block_size = self.block_size; @@ -161,23 +160,17 @@ impl KvRouter { } #[async_trait] -impl - AsyncEngine< - SingleIn, - ManyOut>>, - Error, - > for KvRouter +impl AsyncEngine, ManyOut>>, Error> + for KvRouter { async fn generate( &self, request: SingleIn, - ) -> Result>>> { + ) -> Result>>> { let (request, ctx) = request.into_parts(); let (best_match, _) = self.find_best_match(&request.tokens).await?; - let response = RouterResponse { - worker_id_general: best_match, - }; + let response = RouterResponse { worker: best_match }; let response = Annotated::from_data(response); let stream = stream::iter(vec![response]); Ok(ResponseStream::new(Box::pin(stream), ctx.context())) diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs index f0d0bd7b1f..b8698aa10f 100644 --- a/lib/llm/src/kv_router/indexer.rs +++ b/lib/llm/src/kv_router/indexer.rs @@ -257,7 +257,7 @@ impl RadixTree { /// /// * `event` - The `RouterEvent` to apply. pub fn apply_event(&mut self, event: RouterEvent) { - let (worker_id, event) = (event.worker_id_general, event.event); + let (worker_id, event) = (event.worker, event.event); let (id, op) = (event.event_id, event.data); tracing::trace!(id, "Store operation: {:?}", op); @@ -845,10 +845,7 @@ impl KvIndexerInterface for KvIndexerSharded { async fn apply_event(&mut self, event: RouterEvent) { #[allow(clippy::map_entry)] - if !self - .worker_assignments - .contains_key(&event.worker_id_general) - { + if !self.worker_assignments.contains_key(&event.worker) { // Get the shard with the smallest amount of workers. 
let selected_shard = self .worker_counts @@ -859,11 +856,11 @@ impl KvIndexerInterface for KvIndexerSharded { .0; self.worker_assignments - .insert(event.worker_id_general.clone(), selected_shard); + .insert(event.worker.clone(), selected_shard); self.worker_counts[selected_shard] += 1; } - self.event_tx[self.worker_assignments[&event.worker_id_general]] + self.event_tx[self.worker_assignments[&event.worker]] .send(event) .await .unwrap(); @@ -927,7 +924,7 @@ mod tests { parent: Option, ) -> RouterEvent { RouterEvent { - worker_id_general: worker_id, + worker: worker_id, event: KvCacheEvent { event_id, data: add_blocks(hashes, parent), @@ -941,7 +938,7 @@ mod tests { hashes: Vec, ) -> RouterEvent { RouterEvent { - worker_id_general: worker_id, + worker: worker_id, event: KvCacheEvent { event_id, data: KvCacheEventData::Removed(KvCacheRemoveData { @@ -1518,7 +1515,7 @@ mod tests { }; let router_event = RouterEvent::new(worker_id, kv_cache_event); - assert_eq!(router_event.worker_id_general, worker_id); + assert_eq!(router_event.worker, worker_id); assert_eq!(router_event.event.event_id, 1); if let KvCacheEventData::Stored(store_op) = &router_event.event.data { assert_eq!(store_op.blocks.len(), 1); diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index f59af5f05e..460637c817 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -24,7 +24,7 @@ pub type WorkerId = i64; pub type DpRank = u32; #[derive(Hash, PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize, Default)] -pub struct WorkerWithDpRank { +pub struct WorkerDp { pub worker_id: WorkerId, pub dp_rank: Option, } @@ -46,13 +46,13 @@ pub struct RouterRequest { #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct RouterResponse { - pub worker_id_general: T, + pub worker: T, } #[derive(Debug)] pub struct WorkerSelectionResult { /// The worker id of the selected worker - pub worker_id_general: T, + pub worker: T, /// The total number of blocks required to prefill the request pub required_blocks: u64, @@ -162,7 +162,7 @@ pub struct KvCacheRemoveData { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KVHitRateEvent { - pub worker_id_general: T, + pub worker: T, pub isl_blocks: usize, pub overlap_blocks: usize, } @@ -171,7 +171,7 @@ pub struct KVHitRateEvent { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RouterEvent { /// The ID of the worker emitting the event. - pub worker_id_general: T, + pub worker: T, /// The cache event associated with the worker. pub event: KvCacheEvent, } @@ -187,11 +187,8 @@ impl RouterEvent { /// ### Returns /// /// A new `RouterEvent`. 
- pub fn new(worker_id_general: T, event: KvCacheEvent) -> Self { - Self { - worker_id_general, - event, - } + pub fn new(worker: T, event: KvCacheEvent) -> Self { + Self { worker, event } } } diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 2bb8d3ab3c..b4062e9919 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -24,7 +24,7 @@ pub use crate::kv_router::protocols::ForwardPassMetrics; use crate::kv_router::scoring::ProcessedEndpoints; use crate::kv_router::KV_HIT_RATE_SUBJECT; -use super::protocols::{KVHitRateEvent, WorkerSelectionResult, WorkerWithDpRank}; +use super::protocols::{KVHitRateEvent, WorkerDp, WorkerSelectionResult}; use super::WorkerSelector; #[derive(Debug, thiserror::Error)] @@ -41,12 +41,12 @@ pub enum KvSchedulerError { pub struct SchedulingRequest { pub isl_tokens: usize, - pub overlap: OverlapScores, - resp_tx: tokio::sync::oneshot::Sender, + pub overlap: OverlapScores, + resp_tx: tokio::sync::oneshot::Sender, } impl SchedulingRequest { - pub fn respond(self, identifier: WorkerWithDpRank) { + pub fn respond(self, identifier: WorkerDp) { if self.resp_tx.send(identifier).is_err() { tracing::trace!("failed to send response to requestor"); } @@ -69,7 +69,7 @@ impl KvScheduler { let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone(); let (event_tx, event_rx) = - tokio::sync::mpsc::unbounded_channel::>(); + tokio::sync::mpsc::unbounded_channel::>(); tokio::spawn(async move { let mut event_rx = event_rx; while let Some(event) = event_rx.recv().await { @@ -147,9 +147,9 @@ impl KvScheduler { pub async fn schedule( &self, - overlap: OverlapScores, + overlap: OverlapScores, isl_tokens: usize, - ) -> Result { + ) -> Result { let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let request = SchedulingRequest { isl_tokens, @@ -170,15 +170,15 @@ impl KvScheduler { // This becomes the driver function that handles the selection result pub fn process_worker_selection( workers: &mut ProcessedEndpoints, - selection: WorkerSelectionResult, - event_tx: &tokio::sync::mpsc::UnboundedSender>, -) -> WorkerWithDpRank { + selection: WorkerSelectionResult, + event_tx: &tokio::sync::mpsc::UnboundedSender>, +) -> WorkerDp { let worker = workers .endpoints - .get_mut(&selection.worker_id_general.worker_id) + .get_mut(&selection.worker.worker_id) .expect("worker not found"); - let dp_rank = selection.worker_id_general.dp_rank.unwrap_or(0) as usize; + let dp_rank = selection.worker.dp_rank.unwrap_or(0) as usize; // Update worker state predictively // Will be overwritten on next polling of metrics @@ -191,14 +191,14 @@ pub fn process_worker_selection( // Emit event if let Err(e) = event_tx.send(KVHitRateEvent { - worker_id_general: selection.worker_id_general, + worker: selection.worker, isl_blocks: selection.required_blocks as usize, overlap_blocks: selection.overlap_blocks, }) { tracing::warn!("Failed to send KV hit rate event: {:?}", e); } - selection.worker_id_general + selection.worker } // Default implementation matching the Python _cost_function @@ -211,7 +211,7 @@ impl WorkerSelector for DefaultWorkerSelector { workers: &ProcessedEndpoints, request: &SchedulingRequest, block_size: usize, - ) -> Result, KvSchedulerError> { + ) -> Result, KvSchedulerError> { assert!(request.isl_tokens > 0); if workers.endpoints.is_empty() { @@ -224,7 +224,7 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate worker scores and find max waiting requests for (worker_id, ep) in 
workers.endpoints.iter() { for dp_rank in ep.data.iter().map(|metrics| metrics.data_parallel_rank) { - let worker_with_dp_rank = WorkerWithDpRank { + let worker_with_dp_rank = WorkerDp { worker_id: *worker_id, dp_rank, }; @@ -252,7 +252,7 @@ impl WorkerSelector for DefaultWorkerSelector { let worker_id = *worker_id; for fwd_pass_metrics in ep.data.iter() { let dp_rank = fwd_pass_metrics.data_parallel_rank; - let worker_with_dp_rank = WorkerWithDpRank { worker_id, dp_rank }; + let worker_with_dp_rank = WorkerDp { worker_id, dp_rank }; // Get score or default to 0.0 let score = worker_scores @@ -319,7 +319,7 @@ impl WorkerSelector for DefaultWorkerSelector { .unwrap_or(0) as usize; Ok(WorkerSelectionResult { - worker_id_general: best_worker_and_dp, + worker: best_worker_and_dp, required_blocks: total_blocks, overlap_blocks, }) From 65ea6b5b2572bd66061b722bbc3b533339f423b7 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Mon, 2 Jun 2025 17:34:53 -0700 Subject: [PATCH 09/22] allow suffix for dp_rank identification --- components/metrics/src/lib.rs | 12 +- lib/bindings/python/rust/llm/kv.rs | 27 +++-- lib/llm/src/kv_router/metrics_aggregator.rs | 2 +- lib/llm/src/kv_router/protocols.rs | 9 ++ lib/llm/src/kv_router/publisher.rs | 9 +- lib/llm/src/kv_router/scheduler.rs | 115 ++++++++------------ lib/llm/src/kv_router/scoring.rs | 35 ++++-- 7 files changed, 111 insertions(+), 98 deletions(-) diff --git a/components/metrics/src/lib.rs b/components/metrics/src/lib.rs index 4de77617ed..023dd08a6a 100644 --- a/components/metrics/src/lib.rs +++ b/components/metrics/src/lib.rs @@ -456,31 +456,31 @@ impl PrometheusMetrics { &self.kv_blocks_active, config, &worker_id, - metrics[0].kv_active_blocks as f64, + metrics.kv_active_blocks as f64, ); self.set_worker_gauge( &self.kv_blocks_total, config, &worker_id, - metrics[0].kv_total_blocks as f64, + metrics.kv_total_blocks as f64, ); self.set_worker_gauge( &self.requests_active, config, &worker_id, - metrics[0].request_active_slots as f64, + metrics.request_active_slots as f64, ); self.set_worker_gauge( &self.requests_total, config, &worker_id, - metrics[0].request_total_slots as f64, + metrics.request_total_slots as f64, ); self.set_worker_gauge( &self.kv_hit_rate_percent, config, &worker_id, - metrics[0].gpu_prefix_cache_hit_rate as f64, + metrics.gpu_prefix_cache_hit_rate as f64, ); } @@ -603,7 +603,7 @@ pub fn postprocess_metrics( e.id().ok().map(|id| Endpoint { name: format!("worker-{id}"), subject: e.subject.clone(), - data: vec![m.clone()], + data: m.clone(), }) }) .collect(); diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 052317098e..26b3d2776a 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -96,17 +96,18 @@ impl WorkerMetricsPublisher { }) } - #[pyo3(signature = (component))] + #[pyo3(signature = (component, dp_rank = None))] fn create_endpoint<'p>( &self, py: Python<'p>, component: Component, + dp_rank: Option, ) -> PyResult> { let rs_publisher = self.inner.clone(); let rs_component = component.inner.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { rs_publisher - .create_endpoint(rs_component) + .create_endpoint(rs_component, dp_rank.as_ref().map(|v| v.to_string()).as_deref()) .await .map_err(to_pyerr)?; Ok(()) @@ -449,18 +450,16 @@ impl KvMetricsAggregator { let endpoint_kv_metrics = endpoints .endpoints .iter() - .flat_map(|(worker_id, x)| { - x.data.iter().map(move |data_item| EndpointKvMetrics { - worker_id: *worker_id, - dp_rank: 
data_item.data_parallel_rank, - request_active_slots: data_item.request_active_slots, - request_total_slots: data_item.request_total_slots, - kv_active_blocks: data_item.kv_active_blocks, - kv_total_blocks: data_item.kv_total_blocks, - num_requests_waiting: data_item.num_requests_waiting, - gpu_cache_usage_perc: data_item.gpu_cache_usage_perc, - gpu_prefix_cache_hit_rate: data_item.gpu_prefix_cache_hit_rate, - }) + .map(|(worker_dp, x)| EndpointKvMetrics { + worker_id: worker_dp.worker_id, + dp_rank: worker_dp.dp_rank, + request_active_slots: x.data.request_active_slots, + request_total_slots: x.data.request_total_slots, + kv_active_blocks: x.data.kv_active_blocks, + kv_total_blocks: x.data.kv_total_blocks, + num_requests_waiting: x.data.num_requests_waiting, + gpu_cache_usage_perc: x.data.gpu_cache_usage_perc, + gpu_prefix_cache_hit_rate: x.data.gpu_prefix_cache_hit_rate, }) .collect(); pyo3_async_runtimes::tokio::future_into_py(py, async move { diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 6d824cefdf..2fd887a14f 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -118,7 +118,7 @@ pub async fn collect_endpoints_task( .into_iter() .filter(|s| s.data.is_some()) .filter_map(|s| - match s.data.unwrap().decode::>() { + match s.data.unwrap().decode::() { Ok(data) => Some(Endpoint { name: s.name, subject: s.subject, diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 460637c817..4d2cab7bcf 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -29,6 +29,15 @@ pub struct WorkerDp { pub dp_rank: Option, } +impl std::fmt::Display for WorkerDp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.dp_rank { + Some(dp_rank) => write!(f, "{}_{}", self.worker_id, dp_rank), + None => write!(f, "{}", self.worker_id), + } + } +} + pub trait WorkerGeneral: Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize { diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 3029c0308f..5a4c6f1a2d 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -491,13 +491,18 @@ impl WorkerMetricsPublisher { self.tx.send(metrics) } - pub async fn create_endpoint(&self, component: Component) -> Result<()> { + pub async fn create_endpoint(&self, component: Component, suffix: Option<&str>) -> Result<()> { let mut metrics_rx = self.rx.clone(); let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone())); let handler = Ingress::for_engine(handler)?; + let endpoint_name = match suffix { + Some(s) => format!("{}_{}", KV_METRICS_ENDPOINT, s), + None => KV_METRICS_ENDPOINT.to_string(), + }; + component - .endpoint(KV_METRICS_ENDPOINT) + .endpoint(&endpoint_name) .endpoint_builder() .stats_handler(move |_| { let metrics = metrics_rx.borrow_and_update().clone(); diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index b4062e9919..bf93dc197c 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -175,17 +175,15 @@ pub fn process_worker_selection( ) -> WorkerDp { let worker = workers .endpoints - .get_mut(&selection.worker.worker_id) + .get_mut(&selection.worker) .expect("worker not found"); - let dp_rank = selection.worker.dp_rank.unwrap_or(0) as usize; - // Update worker state predictively // Will be overwritten on next polling of metrics - 
worker.data[dp_rank].num_requests_waiting += 1; + worker.data.num_requests_waiting += 1; // Assumes radix attention so KV load is only incremented by uncached blocks // overlap_blocks can be bigger than required_blocks. I don't know if that's a bug or not. - worker.data[dp_rank].kv_active_blocks += selection + worker.data.kv_active_blocks += selection .required_blocks .saturating_sub(selection.overlap_blocks as u64); @@ -222,22 +220,13 @@ impl WorkerSelector for DefaultWorkerSelector { let mut max_waiting = 0.0; // Calculate worker scores and find max waiting requests - for (worker_id, ep) in workers.endpoints.iter() { - for dp_rank in ep.data.iter().map(|metrics| metrics.data_parallel_rank) { - let worker_with_dp_rank = WorkerDp { - worker_id: *worker_id, - dp_rank, - }; - if let Some(score) = request.overlap.scores.get(&worker_with_dp_rank) { - let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; - worker_scores.insert(worker_with_dp_rank, score); - } - // Track max waiting requests - max_waiting = f64::max( - max_waiting, - ep.data[dp_rank.unwrap_or(0) as usize].num_requests_waiting as f64, - ); + for (worker_dp, ep) in workers.endpoints.iter() { + if let Some(score) = request.overlap.scores.get(worker_dp) { + let score = *score as f64 * block_size as f64 / request.isl_tokens as f64; + worker_scores.insert(worker_dp, score); } + // Track max waiting requests + max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64); } // make immutable @@ -246,80 +235,70 @@ impl WorkerSelector for DefaultWorkerSelector { // Calculate logits for each worker let mut best_logit = f64::NEG_INFINITY; - let mut best_workers = Vec::new(); - - for (worker_id, ep) in workers.endpoints.iter() { - let worker_id = *worker_id; - for fwd_pass_metrics in ep.data.iter() { - let dp_rank = fwd_pass_metrics.data_parallel_rank; - let worker_with_dp_rank = WorkerDp { worker_id, dp_rank }; - - // Get score or default to 0.0 - let score = worker_scores - .get(&worker_with_dp_rank) - .copied() - .unwrap_or(0.0); - - // Calculate normalized metrics - let gpu_cache_usage = - ep.data[dp_rank.unwrap_or(0) as usize].gpu_cache_usage_perc as f64; - let normalized_waiting = if max_waiting > 0.0 { - ep.data[dp_rank.unwrap_or(0) as usize].num_requests_waiting as f64 / max_waiting - } else { - 0.0 - }; - - // Calculate logit using same formula as Python - let logit = 2.0 * score - gpu_cache_usage - normalized_waiting; - - tracing::trace!( - "Formula for {worker_id}: {logit:.3} = 2.0 * {score:.3} - {gpu_cache_usage:.3} - {normalized_waiting:.3}", - ); - - // Track best workers - match logit.partial_cmp(&best_logit) { - Some(std::cmp::Ordering::Greater) => { - best_logit = logit; - best_workers.clear(); - best_workers.push(worker_with_dp_rank); - } - Some(std::cmp::Ordering::Equal) => { - best_workers.push(worker_with_dp_rank); - } - _ => {} + let mut best_worker_dps = Vec::new(); + + for (worker_dp, ep) in workers.endpoints.iter() { + // Get score or default to 0.0 + let score = worker_scores.get(worker_dp).copied().unwrap_or(0.0); + + // Calculate normalized metrics + let gpu_cache_usage = ep.data.gpu_cache_usage_perc as f64; + let normalized_waiting = if max_waiting > 0.0 { + ep.data.num_requests_waiting as f64 / max_waiting + } else { + 0.0 + }; + + // Calculate logit using same formula as Python + let logit = 2.0 * score - gpu_cache_usage - normalized_waiting; + + tracing::trace!( + "Formula for {worker_dp:?}: {logit:.3} = 2.0 * {score:.3} - {gpu_cache_usage:.3} - {normalized_waiting:.3}", + ); + 
+ // Track best workers + match logit.partial_cmp(&best_logit) { + Some(std::cmp::Ordering::Greater) => { + best_logit = logit; + best_worker_dps.clear(); + best_worker_dps.push(worker_dp); + } + Some(std::cmp::Ordering::Equal) => { + best_worker_dps.push(worker_dp); } + _ => {} } } // Return early if no valid workers found - if best_workers.is_empty() { + if best_worker_dps.is_empty() { return Err(KvSchedulerError::NoEndpoints); } else if best_logit == 0.0 { tracing::debug!("best worker logit is 0"); } - let best_worker_and_dp = if best_workers.len() == 1 { - best_workers[0] + let best_worker_dp = if best_worker_dps.len() == 1 { + best_worker_dps[0] } else { // Randomly select from best workers let mut rng = rand::rng(); - best_workers[rng.random_range(0..best_workers.len())] + best_worker_dps[rng.random_range(0..best_worker_dps.len())] }; // Lower to trace level eventually. Nice to see KV routing working for now. - tracing::debug!("Selected worker: {best_worker_and_dp:?}, logit: {best_logit:.3}"); + tracing::debug!("Selected worker: {best_worker_dp:?}, logit: {best_logit:.3}"); // Log selection metrics let total_blocks = std::cmp::max(request.isl_tokens / block_size, 1) as u64; let overlap_blocks = request .overlap .scores - .get(&best_worker_and_dp) + .get(best_worker_dp) .copied() .unwrap_or(0) as usize; Ok(WorkerSelectionResult { - worker: best_worker_and_dp, + worker: *best_worker_dp, required_blocks: total_blocks, overlap_blocks, }) diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index d1dcec3980..4feae1665f 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -18,16 +18,17 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use crate::kv_router::protocols::{ForwardPassMetrics, WorkerId}; +use crate::kv_router::protocols::{DpRank, ForwardPassMetrics, WorkerDp}; /// [gluo FIXME] exactly the same as EndpointInfo except that 'data' /// is cleaned (not optional) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Endpoint { pub name: String, + // contains dp pub subject: String, // one set of metrics for each dp worker - pub data: Vec, + pub data: ForwardPassMetrics, } impl Endpoint { @@ -43,11 +44,20 @@ impl Endpoint { ) .expect("invalid worker id") } -} + + pub fn dp_rank(&self) -> Option { + let parts: Vec<&str> = self.subject.split("-").collect(); + if parts.len() < 2 { + return None; + } + let second_to_last = parts[parts.len() - 2]; + second_to_last.parse::().ok() + } +} // TODO: make dp_rank #[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct ProcessedEndpoints { - pub endpoints: HashMap, + pub endpoints: HashMap, pub load_avg: f64, pub load_std: f64, } @@ -57,8 +67,7 @@ impl ProcessedEndpoints { // compute some basic statistics let load_values: Vec = endpoints .iter() - .flat_map(|endpoint| endpoint.data.iter()) - .map(|metrics| metrics.kv_active_blocks as f64) + .map(|endpoint| endpoint.data.kv_active_blocks as f64) .collect(); if load_values.is_empty() { panic!("No endpoints to process!") @@ -71,7 +80,19 @@ impl ProcessedEndpoints { / load_values.len() as f64; let load_std = variance.sqrt(); - let endpoints = endpoints.into_iter().map(|e| (e.worker_id(), e)).collect(); + // pass in (worker_id, dp_rank) tuple + let endpoints = endpoints + .into_iter() + .map(|e| { + ( + WorkerDp { + worker_id: e.worker_id(), + dp_rank: e.dp_rank(), + }, + e, + ) + }) + .collect(); ProcessedEndpoints { endpoints, From a2ef896d7ded39a30a5f5fd1f62be95dbbe518ad Mon Sep 17 00:00:00 2001 
From: Alec Date: Tue, 3 Jun 2025 06:15:32 +0000 Subject: [PATCH 10/22] WIP: fix fn dp_rank, add TODO's --- .../dynamo-run/src/subprocess/vllm_v1_inc.py | 25 +++++++++++++------ lib/bindings/python/src/dynamo/_core.pyi | 2 +- lib/llm/src/kv_router/publisher.rs | 2 +- lib/llm/src/kv_router/scoring.rs | 7 ++++-- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py index 04732e11f5..dca27aa4aa 100644 --- a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py +++ b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py @@ -23,7 +23,7 @@ import uvloop from vllm.config import VllmConfig -from vllm.distributed.kv_events import KVEventsConfig +from vllm.distributed.kv_events import KVEventsConfig, ZmqEventPublisher from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import TokensPrompt from vllm.sampling_params import SamplingParams @@ -68,7 +68,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase): def __init__(self, component: Component, dp_rank: int) -> None: self.inner = WorkerMetricsPublisher() - self.inner.create_endpoint(component) + self.inner.create_endpoint(component, dp_rank=dp_rank) self.dp_rank = dp_rank def record( @@ -246,12 +246,23 @@ async def init(runtime: DistributedRuntime, config: Config): ) logger.info("VllmWorker has been initialized") + base_zmq_endpoint = "tcp://127.0.0.1:5557" + dp_rank_size = vllm_config.parallel_config.data_parallel_size + + # TODO This isn't working still + for dp_rank in range(dp_rank_size): + print(f"DP_RANK in Dynamo: {dp_rank}") + zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( + base_zmq_endpoint, data_parallel_rank=dp_rank + ) + print(f"ZMQ_ENPDPOINT in Dynamo: {zmq_endpoint}") + zmq_config = ZmqKvEventPublisherConfig( + worker_id=endpoint.lease_id(), + kv_block_size=engine_args.block_size, + zmq_endpoint=zmq_endpoint, + ) - zmq_config = ZmqKvEventPublisherConfig( - worker_id=endpoint.lease_id(), kv_block_size=engine_args.block_size - ) - - _ = ZmqKvEventPublisher(component=component, config=zmq_config) + _ = ZmqKvEventPublisher(component=component, config=zmq_config) handler = RequestHandler(component, engine_client, default_sampling_params) diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index 424496fe41..73ffe2f19d 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -356,7 +356,7 @@ class WorkerMetricsPublisher: Create a `WorkerMetricsPublisher` object """ - def create_service(self, component: Component) -> None: + def create_endpoint(self, component: Component, dp_rank: int) -> None: """ Similar to Component.create_service, but only service created through this method will interact with KV router of the same component. 
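For readers following the binding change above, a minimal, hypothetical usage sketch (not part of the patch): it mirrors the `DynamoStatLoggerPublisher` pattern used later in this series, and it assumes a `component` handle from the Dynamo runtime and a `data_parallel_size` value taken from the engine config.

```python
# Hypothetical sketch only: register one KV-metrics endpoint per DP rank,
# using the create_endpoint(component, dp_rank=...) signature introduced above.
from dynamo.llm import WorkerMetricsPublisher


def make_metrics_publishers(component, data_parallel_size):
    # `component` and `data_parallel_size` are assumed to be supplied by the caller.
    publishers = []
    for dp_rank in range(data_parallel_size):
        publisher = WorkerMetricsPublisher()
        # Each rank gets its own endpoint (the Rust side suffixes the endpoint
        # name with the dp_rank), so the metrics aggregator can poll them separately.
        publisher.create_endpoint(component, dp_rank=dp_rank)
        publishers.append(publisher)
    return publishers
```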
diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 5a4c6f1a2d..099d978049 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -497,7 +497,7 @@ impl WorkerMetricsPublisher { let handler = Ingress::for_engine(handler)?; let endpoint_name = match suffix { - Some(s) => format!("{}_{}", KV_METRICS_ENDPOINT, s), + Some(s) => format!("{}-{}", KV_METRICS_ENDPOINT, s), None => KV_METRICS_ENDPOINT.to_string(), }; diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index 4feae1665f..d968cc4a90 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -46,12 +46,14 @@ impl Endpoint { } pub fn dp_rank(&self) -> Option { + tracing::info!("Parsing dp_rank from subject: {}", self.subject); let parts: Vec<&str> = self.subject.split("-").collect(); - if parts.len() < 2 { + if parts.len() < 3 { return None; } let second_to_last = parts[parts.len() - 2]; - second_to_last.parse::().ok() + let result = second_to_last.parse::().ok(); + result } } // TODO: make dp_rank @@ -70,6 +72,7 @@ impl ProcessedEndpoints { .map(|endpoint| endpoint.data.kv_active_blocks as f64) .collect(); if load_values.is_empty() { + // TODO we hit this panic while vLLM is starting the ranks up. Need to avoid this panic!("No endpoints to process!") }; let load_avg = load_values.iter().copied().sum::() / load_values.len() as f64; From e80d66cdb7c2a01389ece616ea0a6ce3d5f26ef3 Mon Sep 17 00:00:00 2001 From: Alec Date: Tue, 3 Jun 2025 18:23:36 +0000 Subject: [PATCH 11/22] refactor: fix bugs, kv publishing working --- .../dynamo-run/src/subprocess/vllm_v1_inc.py | 20 ++++++--- lib/llm/src/kv_router/publisher.rs | 42 ++++++++++++------- lib/llm/src/kv_router/scoring.rs | 5 +-- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py index dca27aa4aa..3cd5318a34 100644 --- a/launch/dynamo-run/src/subprocess/vllm_v1_inc.py +++ b/launch/dynamo-run/src/subprocess/vllm_v1_inc.py @@ -249,20 +249,30 @@ async def init(runtime: DistributedRuntime, config: Config): base_zmq_endpoint = "tcp://127.0.0.1:5557" dp_rank_size = vllm_config.parallel_config.data_parallel_size - # TODO This isn't working still + # Store references to prevent garbage collection + kv_publishers = [] + for dp_rank in range(dp_rank_size): - print(f"DP_RANK in Dynamo: {dp_rank}") zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( base_zmq_endpoint, data_parallel_rank=dp_rank ) - print(f"ZMQ_ENPDPOINT in Dynamo: {zmq_endpoint}") zmq_config = ZmqKvEventPublisherConfig( worker_id=endpoint.lease_id(), kv_block_size=engine_args.block_size, zmq_endpoint=zmq_endpoint, ) - _ = ZmqKvEventPublisher(component=component, config=zmq_config) + try: + publisher = ZmqKvEventPublisher(component=component, config=zmq_config) + kv_publishers.append(publisher) + except Exception as e: + logger.error( + f"Failed to create ZmqKvEventPublisher for dp_rank {dp_rank}: {e}" + ) + + logger.debug( + f"Successfully created {len(kv_publishers)} ZmqKvEventPublishers out of {dp_rank_size} expected" + ) handler = RequestHandler(component, engine_client, default_sampling_params) @@ -324,7 +334,7 @@ def cmd_line_args(): endpoint_str = args.endpoint.replace("dyn://", "", 1) endpoint_parts = endpoint_str.split(".") if len(endpoint_parts) != 3: - logging.error( + logger.error( f"Invalid endpoint format: '{args.endpoint}'. 
Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'." ) sys.exit(1) diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 099d978049..e32fc6800b 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -189,15 +189,17 @@ async fn start_event_processor( cancellation_token: CancellationToken, mut rx: mpsc::UnboundedReceiver, ) { + tracing::debug!("KV Event processor starting for worker_id: {}", worker_id); + loop { tokio::select! { _ = cancellation_token.cancelled() => { - tracing::info!("KV Event source received cancellation signal"); + tracing::debug!("KV Event processor received cancellation signal for worker_id: {}", worker_id); break; } maybe_data = rx.recv() => { let Some(data) = maybe_data else { - tracing::debug!("Event processor channel closed."); + tracing::debug!("KV Event processor channel closed for worker_id: {}", worker_id); break; }; @@ -207,11 +209,12 @@ async fn start_event_processor( let router_event = RouterEvent::new((worker_id, dp_rank), event); if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await { - tracing::error!("Failed to publish event: {}", e); + tracing::error!("Failed to publish event for worker_id: {}, dp_rank: {}, error: {}", worker_id, dp_rank, e); } } } } + tracing::debug!("KV Event processor exiting for worker_id: {}", worker_id); } // Error handling configuration for ZMQ operations @@ -236,7 +239,7 @@ async fn start_zmq_listener( kv_block_size: usize, ) { tracing::debug!( - "KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')", + "ZMQ listener starting - connecting to endpoint: {}, topic: '{}'", zmq_endpoint, zmq_topic ); @@ -247,16 +250,19 @@ async fn start_zmq_listener( // Subscribe to the requested topic (empty string == all topics) if let Err(e) = socket.subscribe(&zmq_topic).await { - tracing::error!("Failed to subscribe on ZMQ socket: {}", e); + tracing::error!("Failed to subscribe on ZMQ socket for {}: {}", zmq_endpoint, e); return; } if let Err(e) = socket.connect(&zmq_endpoint).await { - tracing::error!("Failed to connect ZMQ SUB socket: {}", e); + tracing::error!("Failed to connect ZMQ SUB socket to {}: {}", zmq_endpoint, e); return; } + tracing::debug!("ZMQ listener successfully connected to {}", zmq_endpoint); + let mut consecutive_errors = 0u32; + let mut message_count = 0u64; loop { tokio::select! 
{ @@ -264,7 +270,7 @@ async fn start_zmq_listener( // Check for cancellation _ = cancellation_token.cancelled() => { - tracing::info!("ZMQ listener received cancellation signal"); + tracing::debug!("ZMQ listener received cancellation signal for {}", zmq_endpoint); break; } @@ -278,6 +284,7 @@ async fn start_zmq_listener( tracing::error!( error=%e, consecutive_errors=%consecutive_errors, + endpoint=%zmq_endpoint, "Too many consecutive ZMQ errors, terminating listener" ); break; @@ -290,6 +297,7 @@ async fn start_zmq_listener( error=%e, consecutive_errors=%consecutive_errors, backoff_ms=%backoff_ms, + endpoint=%zmq_endpoint, "Error reading from ZMQ socket, applying exponential backoff" ); @@ -298,12 +306,13 @@ async fn start_zmq_listener( }; // Reset error count on successful message consecutive_errors = 0; + message_count += 1; // We expect multipart frames: [topic, seq, payload] let mut frames: Vec> = msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect(); if frames.len() != 3 { - tracing::warn!(expected=3, actual=%frames.len(), "Received unexpected ZMQ frame count"); + tracing::warn!(expected=3, actual=%frames.len(), endpoint=%zmq_endpoint, "Received unexpected ZMQ frame count"); continue; } @@ -312,7 +321,7 @@ async fn start_zmq_listener( let seq_bytes = frames.pop().unwrap(); if seq_bytes.len() != 8 { - tracing::warn!(expected=8, actual=%seq_bytes.len(), "Invalid sequence number byte length"); + tracing::warn!(expected=8, actual=%seq_bytes.len(), endpoint=%zmq_endpoint, "Invalid sequence number byte length"); continue; } @@ -322,23 +331,25 @@ async fn start_zmq_listener( let batch_result = rmps::from_slice::(&payload); let Ok(batch) = batch_result else { let e = batch_result.unwrap_err(); - tracing::warn!(error=%e, "Failed to decode KVEventBatch msgpack"); + tracing::warn!(error=%e, endpoint=%zmq_endpoint, "Failed to decode KVEventBatch msgpack"); continue; }; + tracing::trace!("ZMQ listener decoded batch with {} events, dp_rank: {:?} from {}", batch.events.len(), batch.data_parallel_rank, zmq_endpoint); + // For each of our events, convert them to [`KvCacheEvent`] and send to the event_processor. 
- let dp_rank = batch.dp_rank; + let dp_rank = batch.data_parallel_rank; for raw_event in batch.events.into_iter() { let kv_cache_event = convert_event(raw_event, seq, kv_block_size, &warning_count); if tx.send(KvCacheEventWithDp { kv_cache_event, dp_rank }).is_err() { - tracing::warn!("Failed to send message to channel - receiver dropped"); + tracing::warn!("Failed to send message to channel - receiver dropped for {}", zmq_endpoint); return; } } } } - tracing::debug!("ZMQ listener exiting"); } + tracing::debug!("ZMQ listener exiting for {}", zmq_endpoint); } /// Convert a raw event coming from the ZMQ channel into the internal @@ -449,7 +460,8 @@ pub fn create_stored_blocks( struct KvEventBatch { ts: f64, events: Vec, - dp_rank: Option, + #[serde(alias = "dp_rank")] + data_parallel_rank: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -795,7 +807,7 @@ mod tests_startup_helpers { let batch = KvEventBatch { ts: 0.0, events, - dp_rank: None, + data_parallel_rank: None, }; let payload = Bytes::from(rmps::to_vec(&batch).unwrap()); diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index d968cc4a90..25d9614ba2 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -52,10 +52,9 @@ impl Endpoint { return None; } let second_to_last = parts[parts.len() - 2]; - let result = second_to_last.parse::().ok(); - result + second_to_last.parse::().ok() } -} // TODO: make dp_rank +} #[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct ProcessedEndpoints { From 7a733bd741054ceaae060cd473a2cf78fc69d04e Mon Sep 17 00:00:00 2001 From: Alec Date: Wed, 4 Jun 2025 01:52:43 +0000 Subject: [PATCH 12/22] fix panicing metric thread issue --- lib/llm/src/kv_router/metrics_aggregator.rs | 15 ++++++++++----- lib/llm/src/kv_router/publisher.rs | 1 - 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 2fd887a14f..fc2e451314 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -133,11 +133,16 @@ pub async fn collect_endpoints_task( .collect(); tracing::trace!("Found {} endpoints for service: {service_subject}", endpoints.len()); - let processed = ProcessedEndpoints::new(endpoints); - - if watch_tx.send(processed).is_err() { - tracing::trace!("failed to send processed endpoints; shutting down"); - break; + // Only create and send ProcessedEndpoints if we have valid endpoints + if !endpoints.is_empty() { + let processed = ProcessedEndpoints::new(endpoints); + + if watch_tx.send(processed).is_err() { + tracing::trace!("failed to send processed endpoints; shutting down"); + break; + } + } else { + tracing::trace!("No valid endpoints found, skipping ProcessedEndpoints creation"); } } } diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index e32fc6800b..2a02d87388 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -262,7 +262,6 @@ async fn start_zmq_listener( tracing::debug!("ZMQ listener successfully connected to {}", zmq_endpoint); let mut consecutive_errors = 0u32; - let mut message_count = 0u64; loop { tokio::select! 
{ From 1bddc8e4bf495229c1eab1459f32755f91915eda Mon Sep 17 00:00:00 2001 From: Alec Date: Wed, 4 Jun 2025 02:11:50 +0000 Subject: [PATCH 13/22] remove verbose log --- lib/llm/src/kv_router/scoring.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/llm/src/kv_router/scoring.rs b/lib/llm/src/kv_router/scoring.rs index 25d9614ba2..c62f486400 100644 --- a/lib/llm/src/kv_router/scoring.rs +++ b/lib/llm/src/kv_router/scoring.rs @@ -46,7 +46,6 @@ impl Endpoint { } pub fn dp_rank(&self) -> Option { - tracing::info!("Parsing dp_rank from subject: {}", self.subject); let parts: Vec<&str> = self.subject.split("-").collect(); if parts.len() < 3 { return None; From ee283cc70ccf8f6ad03d62cf7f5d9146e6e5fcf5 Mon Sep 17 00:00:00 2001 From: Alec Date: Wed, 4 Jun 2025 02:39:19 +0000 Subject: [PATCH 14/22] update v1 worker --- examples/vllm_v1/components/worker.py | 116 +++++++++++++++++++++++--- lib/llm/src/kv_router/publisher.rs | 1 - 2 files changed, 106 insertions(+), 11 deletions(-) diff --git a/examples/vllm_v1/components/worker.py b/examples/vllm_v1/components/worker.py index d5567e1f24..07702ab746 100644 --- a/examples/vllm_v1/components/worker.py +++ b/examples/vllm_v1/components/worker.py @@ -23,19 +23,79 @@ from utils.args import parse_vllm_args from utils.protocol import MyRequestOutput, vLLMGenerateRequest -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args, +from vllm.config import VllmConfig +from vllm.distributed.kv_events import ZmqEventPublisher +from vllm.usage.usage_lib import UsageContext +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.metrics.loggers import StatLoggerBase +from vllm.v1.metrics.stats import IterationStats, SchedulerStats + +from dynamo.llm import ( + WorkerMetricsPublisher, + ZmqKvEventPublisher, + ZmqKvEventPublisherConfig, ) - -from dynamo.sdk import async_on_start, endpoint, service +from dynamo.runtime import Component +from dynamo.sdk import async_on_start, dynamo_context, endpoint, service logger = logging.getLogger(__name__) +class DynamoStatLoggerPublisher(StatLoggerBase): + """Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface.""" + + def __init__(self, component: Component, dp_rank: int) -> None: + self.inner = WorkerMetricsPublisher() + self.inner.create_endpoint(component, dp_rank=dp_rank) + self.dp_rank = dp_rank + + def record( + self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats] + ): + # request_total_slots and kv_total_blocks are properties of model + gpu + # we should only publish them once, not every metric update + # they should be part of some runtime metadata tied to MDC or put in etcd ? + hit_rate = 0 + if scheduler_stats.prefix_cache_stats.queries > 0: + hit_rate = ( + scheduler_stats.prefix_cache_stats.hits + / scheduler_stats.prefix_cache_stats.queries + ) + + # TODO Manage DP Ranks in metrics aggregation. 
+ self.inner.publish( + request_active_slots=scheduler_stats.num_running_reqs, + request_total_slots=0, # TODO - remove from metrics + kv_active_blocks=0, # TODO - need to calculate this + kv_total_blocks=0, # TODO - remove from metrics + num_requests_waiting=scheduler_stats.num_waiting_reqs, # used in current cost function + gpu_cache_usage_perc=scheduler_stats.gpu_cache_usage, # used in current cost function + gpu_prefix_cache_hit_rate=hit_rate, + data_parallel_rank=self.dp_rank, + ) + + def log_engine_initialized(self) -> None: + pass + + +class StatLoggerFactory: + """Factory for creating stat logger publishers. Required by vLLM.""" + + def __init__(self, component: Component) -> None: + self.component = component + + def create_stat_logger(self, dp_rank: int) -> StatLoggerBase: + return DynamoStatLoggerPublisher(self.component, dp_rank) + + def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase: + return self.create_stat_logger(dp_rank=dp_rank) + + class VllmBaseWorker: def __init__(self): class_name = self.__class__.__name__ self.engine_args = parse_vllm_args(class_name, "") + self.kv_publishers = [] signal.signal(signal.SIGTERM, self.shutdown_vllm_engine) signal.signal(signal.SIGINT, self.shutdown_vllm_engine) @@ -43,22 +103,58 @@ def __init__(self): self.set_side_channel_port() async def async_init(self): - self._engine_context = build_async_engine_client_from_engine_args( - self.engine_args + # Taken from build_async_engine_client_from_engine_args() + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = self.engine_args.create_engine_config(usage_context=usage_context) + + # Explicitly pass our custom stat logger for metrics + self.engine_client = AsyncLLM.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + stat_loggers=[StatLoggerFactory(dynamo_context["component"])], + disable_log_requests=self.engine_args.disable_log_requests, + disable_log_stats=self.engine_args.disable_log_stats, ) - if self._engine_context is not None: - self.engine_client = await self._engine_context.__aenter__() - else: - raise RuntimeError("Failed to initialize engine client") logger.info("VllmWorker has been initialized") + base_zmq_endpoint = "tcp://127.0.0.1:5557" + dp_rank_size = vllm_config.parallel_config.data_parallel_size + + # Store references to prevent garbage collection + + for dp_rank in range(dp_rank_size): + zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( + base_zmq_endpoint, data_parallel_rank=dp_rank + ) + zmq_config = ZmqKvEventPublisherConfig( + worker_id=dynamo_context["endpoints"][0].lease_id(), + kv_block_size=self.engine_args.block_size, + zmq_endpoint=zmq_endpoint, + ) + + try: + publisher = ZmqKvEventPublisher( + component=dynamo_context["component"], config=zmq_config + ) + self.kv_publishers.append(publisher) + except Exception as e: + logger.error( + f"Failed to create ZmqKvEventPublisher for dp_rank {dp_rank}: {e}" + ) + + logger.debug( + f"Successfully created {len(self.kv_publishers)} ZmqKvEventPublishers out of {dp_rank_size} expected" + ) + def shutdown_vllm_engine(self, signum, frame): """Shutdown the background loop""" logger.info(f"Received signal {signum}, shutting down") loop = asyncio.get_event_loop() try: self.engine_client.close() + for publisher in self.kv_publishers: + publisher.shutdown() logger.info("VllmWorker shutdown complete") except Exception as e: logger.error(f"Error during shutdown: {e}") diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 
2a02d87388..c3897c19c9 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -305,7 +305,6 @@ async fn start_zmq_listener( }; // Reset error count on successful message consecutive_errors = 0; - message_count += 1; // We expect multipart frames: [topic, seq, payload] let mut frames: Vec> = msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect(); From 183a8fe7f306100a7e7b98b13597346745220bb9 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 3 Jun 2025 22:18:39 -0700 Subject: [PATCH 15/22] put dp_rank in PreprocessedRequest --- examples/vllm_v1/utils/protocol.py | 2 ++ lib/llm/src/kv_router.rs | 1 + lib/llm/src/protocols/common/preprocessor.rs | 4 ++++ 3 files changed, 7 insertions(+) diff --git a/examples/vllm_v1/utils/protocol.py b/examples/vllm_v1/utils/protocol.py index 0d83dda371..4a131c6d7b 100644 --- a/examples/vllm_v1/utils/protocol.py +++ b/examples/vllm_v1/utils/protocol.py @@ -61,6 +61,8 @@ class PreprocessedRequest(BaseModel): eos_token_ids: List[TokenIdType] = Field(default_factory=list) mdc_sum: Optional[str] = None annotations: List[str] = Field(default_factory=list) + estimated_prefix_hit_num_blocks: Optional[int] = None + dp_rank: Optional[int] = None # Hack to override the type of multi_modal_data in TokensPrompt diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 3b49cee3ae..2a0ced6bdd 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -207,6 +207,7 @@ impl AsyncEngine, ManyOut>, Er // Update the request with the estimated prefix hit blocks let (mut backend_input, context) = request.into_parts(); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); + backend_input.dp_rank = instance_id.dp_rank; let updated_request = context.map(|_| backend_input); // TODO: this does not do dp routing self.inner diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs index 6b3be76069..151b9dbfcc 100644 --- a/lib/llm/src/protocols/common/preprocessor.rs +++ b/lib/llm/src/protocols/common/preprocessor.rs @@ -51,6 +51,10 @@ pub struct PreprocessedRequest { /// Estimated number of prefix hit tokens (only used in kv aware routing) #[builder(default)] pub estimated_prefix_hit_num_blocks: Option, + + // The dp_rank to route to + #[builder(default)] + pub dp_rank: Option, } impl PreprocessedRequest { From be7f951131457a80e5cd30c36015e93ae8725484 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 3 Jun 2025 22:24:52 -0700 Subject: [PATCH 16/22] new agg config --- examples/vllm_v1/configs/agg.yaml | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/vllm_v1/configs/agg.yaml b/examples/vllm_v1/configs/agg.yaml index 15b7c378e1..bcfb6c89a3 100644 --- a/examples/vllm_v1/configs/agg.yaml +++ b/examples/vllm_v1/configs/agg.yaml @@ -14,6 +14,10 @@ # limitations under the License. 
Common: model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + data-parallel-size: 2 + router: kv + block-size: 64 + max-model-len: 16384 served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B Frontend: @@ -21,14 +25,21 @@ Frontend: port: 8000 served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B +Router: + min-workers: 2 + common-configs: [model, block-size, router, served_model_name, data-parallel-size] + SimpleLoadBalancer: enable_disagg: false common-configs: [model, served_model_name] VllmDecodeWorker: enforce-eager: true + max-num-batched-tokens: 16384 + enable-prefix-caching: true ServiceArgs: - workers: 1 + workers: 2 # 2 workers resources: - gpu: 1 - common-configs: [model, served_model_name] + gpu: 2 # 2 dp ranks + common-configs: [model, served_model_name, block-size, data-parallel-size, max-model-len] + From e1011d86e6f08634207966f65f12c755e2713443 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 3 Jun 2025 22:44:45 -0700 Subject: [PATCH 17/22] updated comments --- lib/llm/src/kv_router.rs | 1 - lib/llm/src/kv_router/protocols.rs | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 2a0ced6bdd..6a269adeeb 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -209,7 +209,6 @@ impl AsyncEngine, ManyOut>, Er backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.dp_rank = instance_id.dp_rank; let updated_request = context.map(|_| backend_input); - // TODO: this does not do dp routing self.inner .direct(updated_request, instance_id.worker_id) .await diff --git a/lib/llm/src/kv_router/protocols.rs b/lib/llm/src/kv_router/protocols.rs index 4d2cab7bcf..730be4929d 100644 --- a/lib/llm/src/kv_router/protocols.rs +++ b/lib/llm/src/kv_router/protocols.rs @@ -38,6 +38,7 @@ impl std::fmt::Display for WorkerDp { } } +// Cannot add DeserializedOwned otherwise compiler will complain pub trait WorkerGeneral: Hash + Eq + Debug + Clone + Send + Sync + Default + 'static + Serialize { From 5bf4faefb5848f9ccd161a6bbfe464906457c378 Mon Sep 17 00:00:00 2001 From: Alec Date: Wed, 4 Jun 2025 06:00:54 +0000 Subject: [PATCH 18/22] update v1 example --- examples/vllm_v1/components/frontend.py | 13 +- .../components/simple_load_balancer.py | 199 ------------------ examples/vllm_v1/components/worker.py | 87 ++++++-- examples/vllm_v1/configs/agg.yaml | 15 +- examples/vllm_v1/graphs/agg.py | 4 +- 5 files changed, 79 insertions(+), 239 deletions(-) delete mode 100644 examples/vllm_v1/components/simple_load_balancer.py diff --git a/examples/vllm_v1/components/frontend.py b/examples/vllm_v1/components/frontend.py index a0f86e72db..5c58aa08f8 100644 --- a/examples/vllm_v1/components/frontend.py +++ b/examples/vllm_v1/components/frontend.py @@ -17,7 +17,7 @@ import subprocess from pathlib import Path -from components.simple_load_balancer import SimpleLoadBalancer +from components.worker import VllmDecodeWorker from fastapi import FastAPI from pydantic import BaseModel @@ -42,9 +42,8 @@ def get_dynamo_run_binary(): class FrontendConfig(BaseModel): """Configuration for the Frontend service including model and HTTP server settings.""" - served_model_name: str - endpoint: str port: int = 8080 + router_mode: str = "round-robin" # TODO: move these to common for all LLMs once we adopt dynamo-run @@ -58,7 +57,7 @@ class FrontendConfig(BaseModel): app=FastAPI(title="LLM Example"), ) class Frontend: - worker = depends(SimpleLoadBalancer) + worker = depends(VllmDecodeWorker) def __init__(self): 
"""Initialize Frontend service with HTTP server and model configuration.""" @@ -74,20 +73,20 @@ def start_ingress_and_processor(self): f"Starting HTTP server and processor on port {self.frontend_config.port}" ) dynamo_run_binary = get_dynamo_run_binary() - endpoint = f"dyn://{self.frontend_config.endpoint}" logger.info( f"Starting HTTP server and processor on port {self.frontend_config.port}" ) - logger.info(f"Endpoint: {endpoint}") self.process = subprocess.Popen( [ dynamo_run_binary, "in=http", - f"out={endpoint}", + "out=dyn", "--http-port", str(self.frontend_config.port), + "--router-mode", + self.frontend_config.router_mode, ], stdout=None, stderr=None, diff --git a/examples/vllm_v1/components/simple_load_balancer.py b/examples/vllm_v1/components/simple_load_balancer.py deleted file mode 100644 index 9a0d3bfb87..0000000000 --- a/examples/vllm_v1/components/simple_load_balancer.py +++ /dev/null @@ -1,199 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import uuid -from typing import AsyncGenerator, Optional - -from components.worker import VllmDecodeWorker, VllmPrefillWorker -from utils.args import parse_vllm_args -from utils.protocol import MyRequestOutput, PreprocessedRequest, vLLMGenerateRequest -from vllm.inputs import TokensPrompt -from vllm.sampling_params import SamplingParams - -from dynamo.llm import ModelType, register_llm -from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service - -logger = logging.getLogger(__name__) - - -@service( - dynamo={ - "enabled": True, - "namespace": "dynamo", - }, - resources={"cpu": "10", "memory": "20Gi"}, - workers=1, -) -class SimpleLoadBalancer: - prefill_worker = depends(VllmPrefillWorker) - decode_worker = depends(VllmDecodeWorker) - - def __init__(self): - class_name = self.__class__.__name__ - self.engine_args = parse_vllm_args(class_name, "") - model_config = self.engine_args.create_model_config() - self.default_sampling_params = model_config.get_diff_sampling_param() - self.enable_disagg = self.engine_args.enable_disagg - - @async_on_start - async def async_init(self): - runtime = dynamo_context["runtime"] - logger.info("Registering LLM for discovery") - comp_ns, comp_name = SimpleLoadBalancer.dynamo_address() # type: ignore - endpoint_name = "generate" - for served_model_name in self.engine_args.served_model_name: - logger.info( - f"Registering endpoint {endpoint_name} with model {self.engine_args.model} and served_model_name {served_model_name}" - ) - endpoint = ( - runtime.namespace(comp_ns).component(comp_name).endpoint(endpoint_name) - ) - await register_llm( - ModelType.Backend, - endpoint, - self.engine_args.model, - served_model_name, - ) - - comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore - self.decode_worker_client = ( - await runtime.namespace(comp_ns) - .component(comp_name) - .endpoint("generate") - 
.client() - ) - - comp_ns, comp_name = VllmPrefillWorker.dynamo_address() # type: ignore - self.prefill_worker_client = ( - await runtime.namespace(comp_ns) - .component(comp_name) - .endpoint("generate") - .client() - ) - - logger.info("SimpleLoadBalancer has been initialized") - - async def send_request_to_prefill( - self, request: vLLMGenerateRequest - ) -> MyRequestOutput: - logger.debug("Sending request to prefill") - - prefill_request = copy.deepcopy(request) - extra_args = prefill_request.sampling_params.extra_args or {} - extra_args["kv_transfer_params"] = { - "do_remote_decode": True, - } - prefill_request.sampling_params.extra_args = extra_args - prefill_request.sampling_params.max_tokens = 1 - prefill_request.sampling_params.min_tokens = 1 - - logger.debug("Prefill request: %s", prefill_request.model_dump_json()) - - async for prefill_response in await self.prefill_worker_client.round_robin( - prefill_request.model_dump_json() - ): - return MyRequestOutput.model_validate_json(prefill_response.data()) - - async def send_request_to_decode( - self, - request: vLLMGenerateRequest, - prefill_response: Optional[MyRequestOutput] = None, - ) -> AsyncGenerator[MyRequestOutput, None]: - logger.debug("Sending request to decode") - - decode_request = copy.deepcopy(request) - - if prefill_response: - extra_args = decode_request.sampling_params.extra_args or {} - extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params - decode_request.sampling_params.extra_args = extra_args - - logger.debug("Decode request: %s", decode_request.model_dump_json()) - - async for decode_response in await self.decode_worker_client.round_robin( - decode_request.model_dump_json() - ): - yield MyRequestOutput.model_validate_json(decode_response.data()) - - @endpoint() - async def generate(self, request: PreprocessedRequest): - logger.debug( - "Processor received completion request: %s", request.model_dump_json() - ) - - vllm_request = self._create_vllm_request(request) - - logger.debug("VLLM request: %s", vllm_request.model_dump_json()) - - if self.enable_disagg: - prefill_response = await self.send_request_to_prefill(vllm_request) - - logger.debug("Prefill response: %s", prefill_response.model_dump_json()) - else: - prefill_response = None - - gen = self.send_request_to_decode(vllm_request, prefill_response) - async for res in self._stream_response(gen): - yield res - - def _create_vllm_request(self, request: PreprocessedRequest) -> vLLMGenerateRequest: - request_id = str(uuid.uuid4().hex) - - prompt = TokensPrompt(prompt_token_ids=request.token_ids) - - sampling_params = SamplingParams(**self.default_sampling_params) - for key, value in request.sampling_options.model_dump().items(): - if not value: - continue - if hasattr(sampling_params, key): - setattr(sampling_params, key, value) - - max_tokens = request.stop_conditions.max_tokens - if max_tokens: - sampling_params.max_tokens = max_tokens - - return vLLMGenerateRequest( - prompt=prompt, - sampling_params=sampling_params, - request_id=request_id, - ) - - async def _stream_response(self, gen: AsyncGenerator[MyRequestOutput, None]): - num_output_tokens_so_far = 0 - async for res in gen: - logger.debug("Decode response: %s", res.model_dump_json()) - # res is our MyRequestOutput - - # This is the expected way for a request to end. - # The new token ID will be eos, don't forward it. 
- if res.finished: - yield {"finish_reason": "stop", "token_ids": []} - break - - if not res.outputs: - yield {"finish_reason": "error", "token_ids": []} - break - - output = res.outputs[0] - next_total_toks = len(output.token_ids) - out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} - if output.finish_reason: - out["finish_reason"] = output.finish_reason - if output.stop_reason: - out["stop_reason"] = output.stop_reason - yield out - num_output_tokens_so_far = next_total_toks diff --git a/examples/vllm_v1/components/worker.py b/examples/vllm_v1/components/worker.py index 07702ab746..1dfd9d8b76 100644 --- a/examples/vllm_v1/components/worker.py +++ b/examples/vllm_v1/components/worker.py @@ -19,21 +19,26 @@ import os import signal import socket +import uuid from typing import Optional from utils.args import parse_vllm_args -from utils.protocol import MyRequestOutput, vLLMGenerateRequest +from utils.protocol import PreprocessedRequest from vllm.config import VllmConfig from vllm.distributed.kv_events import ZmqEventPublisher +from vllm.inputs import TokensPrompt +from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.metrics.loggers import StatLoggerBase from vllm.v1.metrics.stats import IterationStats, SchedulerStats from dynamo.llm import ( + ModelType, WorkerMetricsPublisher, ZmqKvEventPublisher, ZmqKvEventPublisherConfig, + register_llm, ) from dynamo.runtime import Component from dynamo.sdk import async_on_start, dynamo_context, endpoint, service @@ -91,10 +96,20 @@ def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase: return self.create_stat_logger(dp_rank=dp_rank) +BLOCK_SIZE = 16 + + class VllmBaseWorker: def __init__(self): class_name = self.__class__.__name__ self.engine_args = parse_vllm_args(class_name, "") + if not self.engine_args.block_size: + logger.info(f"block_size not set, default to {BLOCK_SIZE}") + self.engine_args.block_size = BLOCK_SIZE + + model_config = self.engine_args.create_model_config() + self.default_sampling_params = model_config.get_diff_sampling_param() + self.kv_publishers = [] signal.signal(signal.SIGTERM, self.shutdown_vllm_engine) @@ -107,6 +122,15 @@ async def async_init(self): usage_context = UsageContext.OPENAI_API_SERVER vllm_config = self.engine_args.create_engine_config(usage_context=usage_context) + await register_llm( + ModelType.Backend, + dynamo_context["endpoints"][0], + self.engine_args.model, + self.engine_args.served_model_name[0], + context_length=self.engine_args.max_model_len, + kv_cache_block_size=self.engine_args.block_size, + ) + # Explicitly pass our custom stat logger for metrics self.engine_client = AsyncLLM.from_vllm_config( vllm_config=vllm_config, @@ -152,7 +176,7 @@ def shutdown_vllm_engine(self, signum, frame): logger.info(f"Received signal {signum}, shutting down") loop = asyncio.get_event_loop() try: - self.engine_client.close() + self.engine_client.shutdown() for publisher in self.kv_publishers: publisher.shutdown() logger.info("VllmWorker shutdown complete") @@ -162,24 +186,51 @@ def shutdown_vllm_engine(self, signum, frame): loop.stop() @endpoint() - async def generate(self, request: vLLMGenerateRequest): + async def generate(self, request: PreprocessedRequest): + request_id = str(uuid.uuid4().hex) + + prompt = TokensPrompt(prompt_token_ids=request.token_ids) + + sampling_params = SamplingParams(**self.default_sampling_params) + for key, value in 
request.sampling_options.model_dump().items(): + if not value: + continue + if hasattr(sampling_params, key): + setattr(sampling_params, key, value) + + max_tokens = request.stop_conditions.max_tokens + if max_tokens: + sampling_params.max_tokens = max_tokens + gen = self.engine_client.generate( - prompt=request.prompt, - sampling_params=request.sampling_params, - request_id=request.request_id, + prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + data_parallel_rank=request.dp_rank, ) - - async for response in gen: - yield MyRequestOutput( - request_id=response.request_id, - prompt=response.prompt, - prompt_token_ids=response.prompt_token_ids, - prompt_logprobs=response.prompt_logprobs, - outputs=response.outputs, - finished=response.finished, - metrics=response.metrics, - kv_transfer_params=response.kv_transfer_params, - ).model_dump_json() + num_output_tokens_so_far = 0 + async for res in gen: + # res is vllm's RequestOutput + + # This is the expected way for a request to end. + # The new token ID will be eos, don't forward it. + if res.finished: + yield {"finish_reason": "stop", "token_ids": []} + break + + if not res.outputs: + yield {"finish_reason": "error", "token_ids": []} + break + + output = res.outputs[0] + next_total_toks = len(output.token_ids) + out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} + if output.finish_reason: + out["finish_reason"] = output.finish_reason + if output.stop_reason: + out["stop_reason"] = output.stop_reason + yield out + num_output_tokens_so_far = next_total_toks def set_side_channel_port(self, port: Optional[int] = None): """vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors. diff --git a/examples/vllm_v1/configs/agg.yaml b/examples/vllm_v1/configs/agg.yaml index bcfb6c89a3..ac5efd1477 100644 --- a/examples/vllm_v1/configs/agg.yaml +++ b/examples/vllm_v1/configs/agg.yaml @@ -13,25 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. Common: - model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + model: Qwen/Qwen3-0.6B data-parallel-size: 2 router: kv block-size: 64 max-model-len: 16384 - served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + served_model_name: Qwen/Qwen3-0.6B Frontend: - endpoint: dynamo.SimpleLoadBalancer.generate_agg port: 8000 - served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B - -Router: - min-workers: 2 - common-configs: [model, block-size, router, served_model_name, data-parallel-size] - -SimpleLoadBalancer: - enable_disagg: false - common-configs: [model, served_model_name] + router_mode: kv VllmDecodeWorker: enforce-eager: true diff --git a/examples/vllm_v1/graphs/agg.py b/examples/vllm_v1/graphs/agg.py index b7428756b3..95e02efab1 100644 --- a/examples/vllm_v1/graphs/agg.py +++ b/examples/vllm_v1/graphs/agg.py @@ -14,8 +14,6 @@ # limitations under the License. 
from components.frontend import Frontend -from components.simple_load_balancer import SimpleLoadBalancer from components.worker import VllmDecodeWorker -load_balancer = Frontend.link(SimpleLoadBalancer) -load_balancer.link(VllmDecodeWorker) +Frontend.link(VllmDecodeWorker) From d6ded6ca98c5ac539f7cdd0661df57e74d339bbe Mon Sep 17 00:00:00 2001 From: Alec Date: Wed, 4 Jun 2025 07:15:30 +0000 Subject: [PATCH 19/22] final touches for it working with dp --- examples/vllm_v1/components/worker.py | 7 ++++++- examples/vllm_v1/configs/agg.yaml | 5 ++--- examples/vllm_v1/utils/args.py | 3 +++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/vllm_v1/components/worker.py b/examples/vllm_v1/components/worker.py index 1dfd9d8b76..21a2fbffcd 100644 --- a/examples/vllm_v1/components/worker.py +++ b/examples/vllm_v1/components/worker.py @@ -25,7 +25,7 @@ from utils.args import parse_vllm_args from utils.protocol import PreprocessedRequest from vllm.config import VllmConfig -from vllm.distributed.kv_events import ZmqEventPublisher +from vllm.distributed.kv_events import KVEventsConfig, ZmqEventPublisher from vllm.inputs import TokensPrompt from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext @@ -103,10 +103,15 @@ class VllmBaseWorker: def __init__(self): class_name = self.__class__.__name__ self.engine_args = parse_vllm_args(class_name, "") + self.engine_args.kv_events_config = KVEventsConfig( + enable_kv_cache_events=True, publisher="zmq" + ) if not self.engine_args.block_size: logger.info(f"block_size not set, default to {BLOCK_SIZE}") self.engine_args.block_size = BLOCK_SIZE + os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests + model_config = self.engine_args.create_model_config() self.default_sampling_params = model_config.get_diff_sampling_param() diff --git a/examples/vllm_v1/configs/agg.yaml b/examples/vllm_v1/configs/agg.yaml index ac5efd1477..141adf1ba4 100644 --- a/examples/vllm_v1/configs/agg.yaml +++ b/examples/vllm_v1/configs/agg.yaml @@ -15,8 +15,7 @@ Common: model: Qwen/Qwen3-0.6B data-parallel-size: 2 - router: kv - block-size: 64 + block-size: 16 max-model-len: 16384 served_model_name: Qwen/Qwen3-0.6B @@ -29,7 +28,7 @@ VllmDecodeWorker: max-num-batched-tokens: 16384 enable-prefix-caching: true ServiceArgs: - workers: 2 # 2 workers + workers: 1 # 2 workers resources: gpu: 2 # 2 dp ranks common-configs: [model, served_model_name, block-size, data-parallel-size, max-model-len] diff --git a/examples/vllm_v1/utils/args.py b/examples/vllm_v1/utils/args.py index f05976c8b8..6780b72a78 100644 --- a/examples/vllm_v1/utils/args.py +++ b/examples/vllm_v1/utils/args.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ # TODO: rename to avoid ambiguity with vllm package from vllm.engine.arg_utils import AsyncEngineArgs from vllm.utils import FlexibleArgumentParser @@ -23,6 +24,7 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: config = ServiceConfig.get_instance() vllm_args = config.as_args(service_name, prefix=prefix) + parser = FlexibleArgumentParser() parser.add_argument( "--enable-disagg", action="store_true", help="Enable disaggregation" @@ -31,4 +33,5 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: args = parser.parse_args(vllm_args) engine_args = AsyncEngineArgs.from_cli_args(args) engine_args.enable_disagg = args.enable_disagg + return engine_args From 9335efe8e623b102b8beed4270cb36dd2073127c Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 4 Jun 2025 15:40:29 -0700 Subject: [PATCH 20/22] fix cost function trace --- lib/llm/src/kv_router/scheduler.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index aa31114ca2..f9afafc555 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -266,7 +266,10 @@ impl WorkerSelector for DefaultWorkerSelector { - self.kv_router_config.waiting_requests_weight * normalized_waiting; tracing::trace!( - "Formula for {worker_dp:?}: {logit:.3} = 2.0 * {score:.3} - {gpu_cache_usage:.3} - {normalized_waiting:.3}", + "Formula for {worker_dp:?}: {logit:.3} = {:.3} * {score:.3} - {:.3} * {gpu_cache_usage:.3} - {:.3} * {normalized_waiting:.3}", + self.kv_router_config.overlap_score_weight, + self.kv_router_config.gpu_cache_usage_weight, + self.kv_router_config.waiting_requests_weight, ); // Track best workers From 931b8372d7ea501af40c37b85111de3554b2b6bb Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 4 Jun 2025 15:44:28 -0700 Subject: [PATCH 21/22] fmt --- lib/bindings/python/rust/llm/kv.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 26b3d2776a..7f192930e9 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -107,7 +107,10 @@ impl WorkerMetricsPublisher { let rs_component = component.inner.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { rs_publisher - .create_endpoint(rs_component, dp_rank.as_ref().map(|v| v.to_string()).as_deref()) + .create_endpoint( + rs_component, + dp_rank.as_ref().map(|v| v.to_string()).as_deref(), + ) .await .map_err(to_pyerr)?; Ok(()) From eb7bb101f575e751d2f68af71775a3295040b433 Mon Sep 17 00:00:00 2001 From: Alec Date: Thu, 5 Jun 2025 07:09:52 +0000 Subject: [PATCH 22/22] WIP document current work steps --- examples/vllm_v1/README.md | 91 +++++----------- .../vllm_v1/components/headless_worker.py | 100 ++++++++++++++++++ examples/vllm_v1/configs/agg.yaml | 19 ++-- lib/llm/src/kv_router/indexer.rs | 6 +- 4 files changed, 142 insertions(+), 74 deletions(-) create mode 100644 examples/vllm_v1/components/headless_worker.py diff --git a/examples/vllm_v1/README.md b/examples/vllm_v1/README.md index 39d6c0e1db..9434d06a99 100644 --- a/examples/vllm_v1/README.md +++ b/examples/vllm_v1/README.md @@ -17,16 +17,15 @@ limitations under the License. # vLLM Deployment Examples -This directory contains examples for deploying vLLM models in both aggregated and disaggregated configurations. +This directory contains examples for deploying vLLM models aggregated with with DP. ## Prerequisites 1. 
Install vLLM:
 ```bash
-# Note: Currently requires installation from main branch
-# From vLLM 0.8.6 onwards, you can install directly from wheel
 git clone https://github.com/vllm-project/vllm.git
-VLLM_USE_PRECOMPILED=1 uv pip install --editable ./vllm/
+cd vllm && git checkout d459fae0a2c464e28680bc6d564c1de1b295029e
+VLLM_USE_PRECOMPILED=1 uv pip install --editable .
 ```

 2. Start required services:
 ```bash
 docker compose -f deploy/metrics/docker-compose.yml up -d

 ## Running the Server

-### Aggregated Deployment
+### Aggregated Deployment with Multiple Disconnected DP Engines
+
+Serves the leader AsyncLLM engine plus the number of DP ranks you specify:
 ```bash
 cd examples/vllm_v1
 dynamo serve graphs.agg:Frontend -f configs/agg.yaml
 ```

-### Disaggregated Deployment
-```bash
-cd examples/vllm_v1
-dynamo serve graphs.disagg:Frontend -f configs/disagg.yaml
+To run the other DP ranks headless, on the same node or on other nodes, run:
+
+```
+VLLM_LOGGING_LEVEL=DEBUG CUDA_VISIBLE_DEVICES=1 VLLM_USE_V1=1 vllm serve Qwen/Qwen3-0.6B -dp 1 -dpr 1 --data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 --data-parallel-size-local 1 --enforce-eager --headless --kv-events-config '{"enable_kv_cache_events": true, "publisher": "zmq"}' --enable-prefix-caching
 ```

-## Testing the API
+To test, send the curl request below. KV routing will keep sending an identical prompt to the same DP worker, so vary the prompt to see requests routed to different DP workers.

-Send a test request using curl:
-```bash
-curl localhost:8000/v1/completions \
+```
+curl localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
-    "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
-    "prompt": "In the heart of Eldoria...",
-    "stream": false,
+    "model": "Qwen/Qwen3-0.6B",
+    "messages": [
+    {
+        "role": "user",
+        "content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
+    }
+    ],
+    "stream":false,
     "max_tokens": 30
   }'
-```
-
-For more detailed explenations, refer to the main [LLM examples README](../llm/README.md).
-
-
-
-## Deepseek R1
-
-To run DSR1 model please first follow the Ray setup from the [multinode documentation](../../docs/examples/multinode.md). 
- -### Aggregated Deployment - -```bash -cd examples/vllm_v1 -dynamo serve graphs.agg:Frontend -f configs/deepseek_r1/agg.yaml -``` - - -### Disaggregated Deployment + ``` -To create frontend with a single decode worker: -```bash -cd examples/vllm_v1 -dynamo serve graphs.agg:Frontend -f configs/deepseek_r1/disagg.yaml -``` - -To create a single decode worker: -```bash -cd examples/vllm_v1 -dynamo serve components.worker:VllmDecodeWorker -f configs/deepseek_r1/disagg.yaml +TODO: +- Currently if you run more than one instance or worker on the same node this will fail because the ZmqKvPublishers will overlap ports, need to add some port offsetting to manage that. ``` - -To create a single prefill worker: -```bash -cd examples/vllm_v1 -dynamo serve components.worker:VllmPrefillWorker -f configs/deepseek_r1/disagg.yaml + ServiceArgs: + workers: 1 # 2 workers not supported ``` +- It would be best to distill the vLLM serve into a VllmHeadlessWorker using - run_headless(self.engine_args). This is relatively simple, the main difficulty here is if you want to add the ZmqKvEventPublisher to these nodes (which would be easier for multi-node because then you just need to set-up nats and not worry about port stuff) they will have a different lease_id than the leader worker. This is a problem because we don't actually route requests to these dp_ranks directly but in the KV Router and KV Indexer it will see these KVEvents as coming from a seperate "worker". We still need to route the KVEvents through the leader AsyncLLM engine and that engine will take care of routing to the dp ranks. + - To address this we could create a concept of worker groups? IE components whose lease_ids are tied to a single leader worker? -## Testing -Send a test request using curl: -```bash -curl localhost:8000/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "deepseek-ai/DeepSeek-R1", - "prompt": "In the heart of Eldoria...", - "stream": false, - "max_tokens": 30 - }' -``` \ No newline at end of file +For more detailed explenations, refer to the main [LLM examples README](../llm/README.md). diff --git a/examples/vllm_v1/components/headless_worker.py b/examples/vllm_v1/components/headless_worker.py new file mode 100644 index 0000000000..1403cd62d5 --- /dev/null +++ b/examples/vllm_v1/components/headless_worker.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Work In Progress. 
This is not usable currently + +import asyncio +import logging +import os +import signal +import socket +from typing import Optional + +from utils.args import parse_vllm_args +from vllm import run_headless +from vllm.distributed.kv_events import KVEventsConfig + +from dynamo.sdk import service + +logger = logging.getLogger(__name__) + +BLOCK_SIZE = 16 + + +@service( + dynamo={ + "enabled": True, + "namespace": "dynamo", + }, + resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, + workers=1, +) +class VllmHeadlessWorker: + def __init__(self): + class_name = self.__class__.__name__ + self.engine_args = parse_vllm_args(class_name, "") + self.engine_args.kv_events_config = KVEventsConfig( + enable_kv_cache_events=True, publisher="zmq" + ) + if not self.engine_args.block_size: + logger.info(f"block_size not set, default to {BLOCK_SIZE}") + self.engine_args.block_size = BLOCK_SIZE + + os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests + + model_config = self.engine_args.create_model_config() + self.default_sampling_params = model_config.get_diff_sampling_param() + + self.kv_publishers = [] + + signal.signal(signal.SIGTERM, self.shutdown_vllm_engine) + signal.signal(signal.SIGINT, self.shutdown_vllm_engine) + + self.set_side_channel_host_and_port() + + async def async_init(self): + run_headless(self.engine_args) + + def shutdown_vllm_engine(self, signum, frame): + """Shutdown the background loop""" + logger.info(f"Received signal {signum}, shutting down") + loop = asyncio.get_event_loop() + try: + self.engine_client.shutdown() + for publisher in self.kv_publishers: + publisher.shutdown() + logger.info("VllmWorker shutdown complete") + except Exception as e: + logger.error(f"Error during shutdown: {e}") + finally: + loop.stop() + + def set_side_channel_host_and_port( + self, hostname: Optional[str] = None, port: Optional[int] = None + ): + """vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors. + This sets the port number for the side channel. + """ + if hostname is None: + hostname = socket.gethostname() + if port is None: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to a free port provided by the host. + port = s.getsockname()[1] # Get the port number assigned. + logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_HOST to %s", hostname) + os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname + logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_PORT to %s", port) + os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port) diff --git a/examples/vllm_v1/configs/agg.yaml b/examples/vllm_v1/configs/agg.yaml index 141adf1ba4..ddc1d83a37 100644 --- a/examples/vllm_v1/configs/agg.yaml +++ b/examples/vllm_v1/configs/agg.yaml @@ -1,8 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. +# you may not use this file except in compliance with the License.More actions # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 @@ -14,7 +10,7 @@ # limitations under the License. 
Common: model: Qwen/Qwen3-0.6B - data-parallel-size: 2 + block-size: 16 max-model-len: 16384 served_model_name: Qwen/Qwen3-0.6B @@ -27,9 +23,14 @@ VllmDecodeWorker: enforce-eager: true max-num-batched-tokens: 16384 enable-prefix-caching: true + data-parallel-address: 127.0.0.1 + data-parallel-rpc-port: 62300 + data-parallel-size: 2 + data-parallel-size-local: 1 + # api-server-count: 2 + ServiceArgs: workers: 1 # 2 workers resources: - gpu: 2 # 2 dp ranks - common-configs: [model, served_model_name, block-size, data-parallel-size, max-model-len] - + gpu: 1 # 2 dp ranks + common-configs: [model, served_model_name, block-size, max-model-len] diff --git a/lib/llm/src/kv_router/indexer.rs b/lib/llm/src/kv_router/indexer.rs index b8698aa10f..8ffab92184 100644 --- a/lib/llm/src/kv_router/indexer.rs +++ b/lib/llm/src/kv_router/indexer.rs @@ -259,7 +259,7 @@ impl RadixTree { pub fn apply_event(&mut self, event: RouterEvent) { let (worker_id, event) = (event.worker, event.event); let (id, op) = (event.event_id, event.data); - tracing::trace!(id, "Store operation: {:?}", op); + tracing::trace!(worker_id = ?worker_id, id=?id, "Store operation: {:?}", op); let worker_lookup = self.lookup.entry(worker_id.clone()).or_default(); @@ -278,7 +278,7 @@ impl RadixTree { None => { tracing::warn!( worker_id = ?worker_id, - id, + id = ?id, parent_hash = ?op.parent_hash, "Failed to find parent block; skipping store operation" ); @@ -332,7 +332,7 @@ impl RadixTree { None => { tracing::warn!( worker_id = ?worker_id, - id, + id = ?id, "Failed to find block to remove; skipping remove operation" ); continue;
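
The cost-function trace fixed in PATCH 20 above spells out how the KV router scores each (worker, dp_rank) candidate: the logit is the overlap score scaled by `overlap_score_weight`, minus the GPU cache usage and normalized waiting-request terms scaled by their own weights, and the router tracks the best-scoring workers. The sketch below is a simplified Python mirror of that formula for illustration only; the real selection logic lives in `lib/llm/src/kv_router/scheduler.rs`, and the default weights shown here (2.0 / 1.0 / 1.0) are only inferred from the old hard-coded trace string.

```python
# Hypothetical, simplified mirror of the scoring formula printed by the
# DefaultWorkerSelector trace message; not the actual Rust implementation.
from dataclasses import dataclass


@dataclass
class KvRouterConfig:
    overlap_score_weight: float = 2.0      # inferred default, see old trace string
    gpu_cache_usage_weight: float = 1.0
    waiting_requests_weight: float = 1.0


def worker_logit(cfg: KvRouterConfig, score: float, gpu_cache_usage: float,
                 normalized_waiting: float) -> float:
    """Higher is better: reward prefix-cache overlap, penalize busy workers."""
    return (
        cfg.overlap_score_weight * score
        - cfg.gpu_cache_usage_weight * gpu_cache_usage
        - cfg.waiting_requests_weight * normalized_waiting
    )


if __name__ == "__main__":
    cfg = KvRouterConfig()
    # (overlap score, gpu cache usage, normalized waiting) per (worker, dp_rank)
    candidates = {
        ("worker-0", 0): (0.8, 0.6, 0.2),
        ("worker-0", 1): (0.1, 0.3, 0.0),
    }
    best = max(candidates, key=lambda w: worker_logit(cfg, *candidates[w]))
    print("selected (worker, dp_rank):", best)
```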
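
The README changes above note that, because of KV routing, repeated identical requests keep landing on the same worker, so you have to vary the prompt to see traffic reach different DP ranks. A minimal probe script along those lines is sketched below; it assumes the frontend started by `dynamo serve graphs.agg:Frontend` is listening on localhost:8000 and returns OpenAI-compatible chat-completion responses, as the curl example in the README suggests.

```python
# Hypothetical smoke test: send a few *different* prompts so the KV router
# cannot keep reusing the same prefix, then watch the worker / dp-rank logs
# to see requests spread across DP workers.
import requests

URL = "http://localhost:8000/v1/chat/completions"

prompts = [
    "Summarize the plot of a heist movie set on Mars.",
    "Explain KV-cache reuse in one paragraph.",
    "Write a limerick about data parallelism.",
]

for prompt in prompts:
    payload = {
        "model": "Qwen/Qwen3-0.6B",
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
        "max_tokens": 30,
    }
    resp = requests.post(URL, json=payload, timeout=60)
    resp.raise_for_status()
    # Assumes an OpenAI-compatible response body.
    text = resp.json()["choices"][0]["message"]["content"]
    print(f"{prompt[:40]!r} -> {text[:60]!r}")
```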
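
The README TODO above says that running more than one instance or worker per node fails because the ZmqKvPublishers bind overlapping ports and that "some port offsetting" is needed. The sketch below only illustrates that idea under assumptions: the base port, environment variable, and endpoint format are made up for illustration and are not the project's actual publisher API.

```python
# Sketch of per-rank port offsetting for KV-event publishing (hypothetical names).
import os


def kv_events_endpoint(dp_rank: int,
                       base_port: int = 5557,
                       host: str = "127.0.0.1") -> str:
    """Return a per-rank ZMQ endpoint, e.g. tcp://127.0.0.1:5559 for rank 2."""
    port = base_port + dp_rank  # offset so co-located ranks never collide
    return f"tcp://{host}:{port}"


if __name__ == "__main__":
    local_ranks = int(os.environ.get("DATA_PARALLEL_SIZE_LOCAL", "2"))
    for rank in range(local_ranks):
        print(rank, kv_events_endpoint(rank))
```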