Skip to content

Commit ee0dad7

Browse files
committed
address reviewer comments
1 parent 0c966e1 commit ee0dad7

File tree

7 files changed

+113
-56
lines changed

7 files changed

+113
-56
lines changed

components/metrics/src/bin/mock_worker.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ fn mock_stats_handler(_stats: EndpointStats) -> serde_json::Value {
115115
let gpu_cache_usage_perc = rand::rng().random_range(0.0..=1.0);
116116
let gpu_prefix_cache_hit_rate = rand::rng().random_range(0.0..=1.0);
117117
let stats = ForwardPassMetrics {
118+
data_parallel_rank: None, // Default for backwards compatibility
118119
request_active_slots,
119120
request_total_slots,
120121
kv_active_blocks,

launch/dynamo-run/src/subprocess/vllm_v1_inc.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
# Can also be used standalone: `python3 vllm_inc.py` - lots of optional cmd line params
66

77
# Setup checklist:
8-
# - We are in a virtualenv with vllm installed - and patched if using kv routing.
9-
# - `libdynamo_llm_capi.so` is in system lib path or its containing folder is in LD_LIBRARY_PATH
10-
# It builds in target/debug/ by default.
8+
# - We are in a virtualenv with vllm installed. Must be newer than v0.9.0 (currently pre-release)
119

1210
import argparse
1311
import asyncio
@@ -56,15 +54,17 @@ class Config:
5654
model_name: Optional[str]
5755
tensor_parallel_size: int
5856
kv_block_size: int
57+
context_length: int
5958
extra_engine_args: str
6059

6160

6261
class DynamoStatLoggerPublisher(StatLoggerBase):
6362
"""Stat logger publisher. Wrapper for the KvMetricsPublisher to match the StatLoggerBase interface."""
6463

65-
def __init__(self, component: Component) -> None:
64+
def __init__(self, component: Component, dp_rank: int) -> None:
6665
self.inner = KvMetricsPublisher()
6766
self.inner.create_endpoint(component)
67+
self.dp_rank = dp_rank
6868

6969
def record(
7070
self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats]
@@ -79,7 +79,9 @@ def record(
7979
/ scheduler_stats.prefix_cache_stats.queries
8080
)
8181

82+
# TODO Manage DP Ranks in metrics aggregation.
8283
self.inner.publish(
84+
data_parallel_rank=self.dp_rank,
8385
request_active_slots=scheduler_stats.num_running_reqs,
8486
request_total_slots=0, # TODO - remove from metrics
8587
kv_active_blocks=0, # TODO - need to calculate this
@@ -99,12 +101,11 @@ class StatLoggerFactory:
99101
def __init__(self, component: Component) -> None:
100102
self.component = component
101103

102-
def create_stat_logger(self) -> StatLoggerBase:
103-
return DynamoStatLoggerPublisher(self.component)
104+
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
105+
return DynamoStatLoggerPublisher(self.component, dp_rank)
104106

105-
# TODO investigate if rank is important. Do I need to only do for rank 0?
106-
def __call__(self, vllm_config: VllmConfig, rank: int) -> StatLoggerBase:
107-
return self.create_stat_logger()
107+
def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
108+
return self.create_stat_logger(dp_rank=dp_rank)
108109

109110

110111
class RequestHandler:
@@ -172,8 +173,13 @@ async def init(runtime: DistributedRuntime, config: Config):
172173
await component.create_service()
173174

174175
endpoint = component.endpoint(config.endpoint)
176+
print(f"BLOCK SIZE: {config.kv_block_size}")
175177
await register_llm(
176-
ModelType.Backend, endpoint, config.model_path, config.model_name
178+
ModelType.Backend,
179+
endpoint,
180+
config.model_path,
181+
config.model_name,
182+
kv_cache_block_size=config.kv_block_size,
177183
)
178184

179185
arg_map = {
@@ -183,13 +189,20 @@ async def init(runtime: DistributedRuntime, config: Config):
183189
"skip_tokenizer_init": True,
184190
"disable_log_requests": True,
185191
"enable_prefix_caching": True,
186-
"block_size": config.kv_block_size,
187192
# KV routing relies on logging KV metrics
188193
"disable_log_stats": False,
189194
"kv_events_config": KVEventsConfig(
190195
enable_kv_cache_events=True, publisher="zmq"
191196
),
192197
}
198+
199+
if config.context_length:
200+
# Usually we want it to default to the max (from tokenizer_config.json)
201+
arg_map["max_model_len"] = config.context_length
202+
203+
if config.kv_block_size > 0:
204+
arg_map["block_size"] = config.kv_block_size
205+
193206
if config.extra_engine_args != "":
194207
json_map = {}
195208
# extra_engine_args is a filename
@@ -271,6 +284,12 @@ def cmd_line_args():
271284
parser.add_argument(
272285
"--kv-block-size", type=int, default=16, help="Size of a KV cache block."
273286
)
287+
parser.add_argument(
288+
"--context-length",
289+
type=int,
290+
default=None,
291+
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
292+
)
274293
parser.add_argument(
275294
"--extra-engine-args",
276295
type=str,
@@ -302,6 +321,7 @@ def cmd_line_args():
302321
config.endpoint = parsed_endpoint_name
303322
config.tensor_parallel_size = args.tensor_parallel_size
304323
config.kv_block_size = args.kv_block_size
324+
config.context_length = args.context_length
305325
config.extra_engine_args = args.extra_engine_args
306326

307327
return config

lib/bindings/python/rust/llm/kv.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ impl KvMetricsPublisher {
9797
fn publish(
9898
&self,
9999
_py: Python,
100+
data_parallel_rank: u32,
100101
request_active_slots: u64,
101102
request_total_slots: u64,
102103
kv_active_blocks: u64,
@@ -108,6 +109,7 @@ impl KvMetricsPublisher {
108109
self.inner
109110
.publish(
110111
llm_rs::kv_router::protocols::ForwardPassMetrics {
112+
data_parallel_rank: Some(data_parallel_rank),
111113
request_active_slots,
112114
request_total_slots,
113115
kv_active_blocks,

lib/bindings/python/src/dynamo/_core.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,14 @@ class KvMetricsPublisher:
364364

365365
def publish(
366366
self,
367+
data_parallel_rank: int,
367368
request_active_slots: int,
368369
request_total_slots: int,
369370
kv_active_blocks: int,
370371
kv_total_blocks: int,
372+
num_requests_waiting: int,
373+
gpu_cache_usage_perc: float,
374+
gpu_prefix_cache_hit_rate: float,
371375
) -> None:
372376
"""
373377
Update the KV metrics being reported.
@@ -637,7 +641,7 @@ class ModelType:
637641
"""What type of request this model needs: Chat, Component or Backend (pre-processed)"""
638642
...
639643

640-
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str]) -> None:
644+
async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: str, model_name: Optional[str] = None, context_length: Optional[int] = None, kv_cache_block_size: Optional[int] = None) -> None:
641645
"""Attach the model at path to the given endpoint, and advertise it as model_type"""
642646
...
643647

lib/llm/src/kv_router/protocols.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pub struct WorkerSelectionResult {
4141

4242
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
4343
pub struct ForwardPassMetrics {
44+
pub data_parallel_rank: Option<u32>, // backwards compatible
4445
pub request_active_slots: u64,
4546
pub request_total_slots: u64,
4647
pub kv_active_blocks: u64,

lib/llm/src/kv_router/publisher.rs

Lines changed: 72 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl KvEventPublisher {
6060
}
6161

6262
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
63-
tracing::info!("Publish event: {:?}", event);
63+
tracing::trace!("Publish event: {:?}", event);
6464
self.tx.send(event)
6565
}
6666

@@ -90,6 +90,7 @@ fn start_publish_task(
9090

9191
pub struct KvEventPublisherFromZmq {
9292
kv_block_size: usize,
93+
processor_handle: Option<tokio::task::JoinHandle<()>>,
9394
zmq_handle: Option<tokio::task::JoinHandle<()>>,
9495
zmq_token: Option<dynamo_runtime::CancellationToken>,
9596
}
@@ -98,6 +99,7 @@ impl KvEventPublisherFromZmq {
9899
pub fn new(kv_block_size: usize) -> Self {
99100
Self {
100101
kv_block_size,
102+
processor_handle: None,
101103
zmq_handle: None,
102104
zmq_token: None,
103105
}
@@ -126,20 +128,23 @@ impl KvEventPublisherFromZmq {
126128
zmq_endpoint,
127129
zmq_topic,
128130
raw_tx,
129-
zmq_token,
131+
zmq_token.clone(),
130132
)),
131133
);
132134

133-
component
134-
.drt()
135-
.runtime()
136-
.secondary()
137-
.spawn(start_event_processor(
138-
raw_rx,
139-
component,
140-
worker_id,
141-
kv_block_size,
142-
));
135+
self.processor_handle = Some(
136+
component
137+
.drt()
138+
.runtime()
139+
.secondary()
140+
.spawn(start_event_processor(
141+
raw_rx,
142+
component,
143+
worker_id,
144+
kv_block_size,
145+
zmq_token,
146+
))
147+
);
143148
}
144149

145150
pub fn shutdown(&mut self) {
@@ -149,6 +154,9 @@ impl KvEventPublisherFromZmq {
149154
if let Some(handle) = self.zmq_handle.take() {
150155
handle.abort();
151156
}
157+
if let Some(handle) = self.processor_handle.take() {
158+
handle.abort();
159+
}
152160
}
153161
}
154162

@@ -157,24 +165,45 @@ async fn start_event_processor<P: EventPublisher>(
157165
component: P,
158166
worker_id: i64,
159167
kv_block_size: usize,
168+
cancellation_token: dynamo_runtime::CancellationToken,
160169
) {
161-
while let Some((seq, payload)) = raw_rx.recv().await {
162-
match rmps::from_slice::<KvEventBatch>(&payload) {
163-
Ok(batch) => {
164-
for raw_evt in batch.events.into_iter() {
165-
if let Some(event) = convert_event(raw_evt, seq, kv_block_size) {
166-
let router_event = RouterEvent::new(worker_id, event);
167-
if let Err(e) = component.publish(KV_EVENT_SUBJECT, &router_event).await {
168-
tracing::warn!("Failed to publish router event: {}", e);
170+
loop {
171+
tokio::select! {
172+
// Check for cancellation
173+
_ = cancellation_token.cancelled() => {
174+
tracing::debug!("Event processor received cancellation signal");
175+
break;
176+
}
177+
178+
// Process incoming messages
179+
msg = raw_rx.recv() => {
180+
match msg {
181+
Some((seq, payload)) => {
182+
match rmps::from_slice::<KvEventBatch>(&payload) {
183+
Ok(batch) => {
184+
for raw_evt in batch.events.into_iter() {
185+
if let Some(event) = convert_event(raw_evt, seq, kv_block_size) {
186+
let router_event = RouterEvent::new(worker_id, event);
187+
if let Err(e) = component.publish(KV_EVENT_SUBJECT, &router_event).await {
188+
tracing::warn!(error=%e, "Failed to publish router event.");
189+
}
190+
}
191+
}
192+
}
193+
Err(e) => {
194+
tracing::warn!(error=%e, "Failed to decode KVEventBatch msgpack");
195+
}
169196
}
170197
}
198+
None => {
199+
tracing::debug!("Event processor channel closed");
200+
break;
201+
}
171202
}
172203
}
173-
Err(e) => {
174-
tracing::warn!("Failed to decode KVEventBatch msgpack: {}", e);
175-
}
176204
}
177205
}
206+
tracing::debug!("Event processor exiting");
178207
}
179208

180209
async fn start_zmq_listener(
@@ -183,7 +212,7 @@ async fn start_zmq_listener(
183212
raw_tx: mpsc::UnboundedSender<(u64, Vec<u8>)>,
184213
zmq_token: dynamo_runtime::CancellationToken,
185214
) {
186-
tracing::info!(
215+
tracing::debug!(
187216
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
188217
zmq_endpoint,
189218
zmq_topic
@@ -217,34 +246,34 @@ async fn start_zmq_listener(
217246
// We expect multipart frames: [topic, seq, payload]
218247
let mut frames: Vec<Vec<u8>> = msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect();
219248

220-
if frames.len() == 3 {
221-
let payload = frames.remove(2);
222-
let seq_bytes = frames.remove(1);
249+
if frames.len() != 3 {
250+
tracing::warn!(expected=3, actual=%frames.len(), "Received unexpected ZMQ frame count");
251+
continue;
252+
}
253+
let payload = frames.remove(2);
254+
let seq_bytes = frames.remove(1);
223255

224-
if seq_bytes.len() != 8 {
225-
tracing::warn!("Invalid sequence number frame len={}", seq_bytes.len());
226-
continue;
227-
}
256+
if seq_bytes.len() != 8 {
257+
tracing::warn!(expected=8, actual=%seq_bytes.len(), "Invalid sequence number byte length");
258+
continue;
259+
}
228260

229-
let seq = u64::from_be_bytes(seq_bytes.try_into().unwrap());
230-
if raw_tx.send((seq, payload)).is_err() {
231-
tracing::warn!("Failed to send message to channel - receiver dropped");
232-
break;
233-
}
234-
} else {
235-
tracing::warn!("Received unexpected ZMQ frame count: {}", frames.len());
261+
let seq = u64::from_be_bytes(seq_bytes.try_into().unwrap());
262+
if raw_tx.send((seq, payload)).is_err() {
263+
tracing::warn!("Failed to send message to channel - receiver dropped");
264+
break;
236265
}
237266
}
238267
Err(e) => {
239-
tracing::warn!("Error reading from ZMQ socket: {}", e);
268+
tracing::warn!(error=%e, "Error reading from ZMQ socket");
240269
// Brief sleep to avoid tight error loop
241270
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
242271
}
243272
}
244273
}
245274
}
246275
}
247-
tracing::info!("ZMQ listener exiting");
276+
tracing::debug!("ZMQ listener exiting");
248277
}
249278

250279
/// Convert a raw event coming from the ZMQ channel into the internal
@@ -355,17 +384,14 @@ struct KvEventBatch {
355384
#[derive(Debug, Deserialize, Serialize)]
356385
#[serde(tag = "type")] // msgspec encodes variant tag as a string when `tag=True`
357386
enum RawKvEvent {
358-
#[serde(rename = "BlockStored")]
359387
BlockStored {
360388
block_hashes: Vec<i64>,
361389
parent_block_hash: Option<i64>,
362390
token_ids: Vec<u32>,
363391
block_size: usize,
364392
lora_id: Option<u64>,
365393
},
366-
#[serde(rename = "BlockRemoved")]
367394
BlockRemoved { block_hashes: Vec<i64> },
368-
#[serde(rename = "AllBlocksCleared")]
369395
AllBlocksCleared,
370396
}
371397

@@ -620,6 +646,8 @@ mod tests_startup_helpers {
620646
};
621647
let payload = rmps::to_vec(&batch).unwrap();
622648

649+
let token = dynamo_runtime::CancellationToken::new();
650+
623651
// 2) channel feeding the processor
624652
let (tx, rx) = mpsc::unbounded_channel::<(u64, Vec<u8>)>();
625653
tx.send((123, payload.clone())).unwrap(); // seq = 123
@@ -629,7 +657,7 @@ mod tests_startup_helpers {
629657
let (comp, published) = MockComponent::new();
630658

631659
// 4) run the function under test (let it consume exactly one msg)
632-
let handle = tokio::spawn(start_event_processor(rx, comp, worker_id, kv_block_size));
660+
let handle = tokio::spawn(start_event_processor(rx, comp, worker_id, kv_block_size, token));
633661

634662
tokio::time::timeout(std::time::Duration::from_secs(1), handle)
635663
.await

0 commit comments

Comments
 (0)