
Commit c89260f

feat(router): Dynamic batch sizing
1 parent b6ee0ec commit c89260f

6 files changed (+357, -285 lines)

launcher/src/main.rs

Lines changed: 16 additions & 0 deletions
@@ -41,6 +41,10 @@ struct Args {
     max_total_tokens: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
+    #[clap(default_value = None, long, env)]
+    max_batch_weight: Option<usize>,
+    #[clap(default_value = None, long, env)]
+    max_prefill_weight: Option<usize>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -93,6 +97,8 @@ fn main() -> ExitCode {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_weight,
+        max_prefill_weight,
         max_waiting_tokens,
         port,
         shard_uds_path,
@@ -392,6 +398,16 @@ fn main() -> ExitCode {
         model_id,
     ];

+    if let Some(max_batch_weight) = max_batch_weight {
+        argv.push("--max-batch-weight".to_string());
+        argv.push(max_batch_weight.to_string())
+    }
+
+    if let Some(max_prefill_weight) = max_prefill_weight {
+        argv.push("--max-prefill-weight".to_string());
+        argv.push(max_prefill_weight.to_string())
+    }
+
     // Model optional revision
     if let Some(ref revision) = revision {
         argv.push("--revision".to_string());

router/src/infer.rs

Lines changed: 41 additions & 39 deletions
@@ -14,6 +14,7 @@ use thiserror::Error;
 use tokio::sync::{Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
+use crate::queue::BatchingConfig;

 /// Inference struct
 #[derive(Clone)]
@@ -39,19 +40,24 @@ impl Infer {
         client: ShardedClient,
         validation: Validation,
         max_batch_size: usize,
+        max_batch_weight: usize,
+        max_prefill_weight: usize,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new();
+        let queue = Queue::new(BatchingConfig {
+            size_limit: max_batch_size,
+            weight_limit: max_batch_weight,
+            prefill_weight_limit: max_prefill_weight,
+        });
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });

         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
@@ -99,6 +105,7 @@ impl Infer {
         // Append the request to the queue
         self.queue.append(Entry {
             request: valid_request,
+            generated_tokens: 0,
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -227,18 +234,11 @@ impl Infer {
 /// Batches requests and sends them to the inference server
 async fn batching_task(
     mut client: ShardedClient,
-    max_batch_size: usize,
+    // max_batch_size: usize,
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
 ) {
-    // Minimum batch size after which we try to add more requests
-    let limit_min_batch_size = if max_batch_size > 1 {
-        (max_batch_size / 2) as u32
-    } else {
-        0
-    };
-
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
@@ -247,8 +247,8 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
-        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+        while let (_, Some((mut entries, batch, span))) = queue.next_batch(None).await {
+            let (mut cached_batch, mut some_completed) = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
            let mut waiting_tokens = 1;
@@ -261,21 +261,16 @@ async fn batching_task(
                let mut batches = vec![batch];
                metrics::gauge!("tgi_batch_current_size", batch_size as f64);

-                // If the current batch is too small, we try to add more requests to it
-                if batch_size <= limit_min_batch_size {
-                    let min_size = match waiting_tokens {
-                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                        // to add a new batch even though its size might be small
-                        _ if waiting_tokens >= max_waiting_tokens => None,
-                        // Minimum size criteria
-                        _ => Some(limit_min_batch_size as usize),
-                    };
-
-                    // Try to get a new batch
-                    if let Some((mut new_entries, new_batch, span)) = queue
-                        .next_batch(min_size, max_batch_size - batch_size as usize)
-                        .await
-                    {
+                // Try to extend batch if its size reduced or enough tokens have elapsed since last one
+                if some_completed || waiting_tokens >= max_waiting_tokens {
+
+                    // Try to get a new batch - ownership of entries passed in and out
+                    let (
+                        existing_entries, new_entries
+                    ) = queue.next_batch(Some(entries)).await;
+                    entries = existing_entries.unwrap();
+
+                    if let Some((mut new_entries, new_batch, span)) = new_entries {
                        let new_batch_size = new_batch.size;
                        entries.iter_mut().for_each(|(_, entry)| {
                            // Create a new span to add the info that this entry is waiting
@@ -290,7 +285,7 @@ async fn batching_task(
                        });

                        // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        let (new_cached_batch, _) = prefill(&mut client, new_batch, &mut new_entries)
                            .instrument(span)
                            .await;
                        // Reset waiting counter
@@ -317,7 +312,7 @@ async fn batching_task(
                    entry.temp_span = Some(entry_batch_span);
                });

-                cached_batch = decode(&mut client, batches, &mut entries)
+                (cached_batch, some_completed) = decode(&mut client, batches, &mut entries)
                    .instrument(next_batch_span)
                    .await;
                waiting_tokens += 1;
@@ -332,24 +327,24 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> (Option<Batch>, bool) {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");

     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            let some_completed = send_generations(generations, entries);
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
-            next_batch
+            (next_batch, some_completed)
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
-            None
+            (None, true)
         }
     }
 }
@@ -359,22 +354,22 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> (Option<Batch>, bool) {
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            let some_completed = send_generations(generations, entries);
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            next_batch
+            (next_batch, some_completed)
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
-            None
+            (None, true)
         }
     }
 }
@@ -398,13 +393,15 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 }

 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+/// Return true if any requests completed
 #[instrument(skip_all)]
-fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) -> bool {
+    let mut some_completed = false;
     generations.into_iter().for_each(|generation| {
         // Get entry
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
-            .get(&generation.request_id)
+            .get_mut(&generation.request_id)
             .expect("ID not found in entries. This is a bug.");

         // Create and enter a span to link this function back to the entry
@@ -445,7 +442,10 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
                     start: entry.batch_time.unwrap(),
                 }))
                 .unwrap_or(());
+            some_completed = true;
         } else {
+            // Increment generated token count
+            entry.generated_tokens += 1;
             // Send message
             // unwrap_or is valid here as we don't care if the receiver is gone.
             entry
@@ -454,6 +454,8 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
                 .unwrap_or(());
         }
     });
+
+    return some_completed;
 }

 #[derive(Debug)]
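The `BatchingConfig` struct and the reworked `Queue::next_batch` come from `router/src/queue.rs`, one of the six changed files but not included in this excerpt. As the call sites above show, `next_batch` now takes the in-flight entries (`Option<IntMap<u64, Entry>>`) and hands them back alongside an optional new batch, so the queue can weigh new requests against those already running. Below is a minimal, hypothetical sketch of such weight-bounded admission; the `QueuedRequest` type, the weight formula, and the numeric values are illustrative assumptions, not the committed code.

// Hypothetical sketch of weight-bounded batch admission; not the committed
// router/src/queue.rs. Field names mirror the BatchingConfig used in infer.rs,
// but the weight formula and all values are assumptions.
struct BatchingConfig {
    size_limit: usize,           // maximum number of requests per batch (max_batch_size)
    weight_limit: usize,         // maximum summed weight of a running batch (max_batch_weight)
    prefill_weight_limit: usize, // maximum summed weight admitted in one prefill (max_prefill_weight)
}

struct QueuedRequest {
    input_tokens: usize,
    max_new_tokens: usize,
}

impl QueuedRequest {
    // Assumed weight: the number of KV-cache slots the request may occupy.
    fn weight(&self) -> usize {
        self.input_tokens + self.max_new_tokens
    }
}

// Greedily admit waiting requests while the size, batch-weight and
// prefill-weight budgets all still hold.
fn admit(
    waiting: &mut Vec<QueuedRequest>,
    cfg: &BatchingConfig,
    running_size: usize,
    running_weight: usize,
) -> Vec<QueuedRequest> {
    let mut admitted = Vec::new();
    let (mut size, mut weight, mut prefill_weight) = (running_size, running_weight, 0);
    while let Some(req) = waiting.first() {
        let w = req.weight();
        if size + 1 > cfg.size_limit
            || weight + w > cfg.weight_limit
            || prefill_weight + w > cfg.prefill_weight_limit
        {
            break;
        }
        size += 1;
        weight += w;
        prefill_weight += w;
        admitted.push(waiting.remove(0));
    }
    admitted
}

fn main() {
    let cfg = BatchingConfig { size_limit: 32, weight_limit: 20_000, prefill_weight_limit: 4_096 };
    let mut waiting = vec![
        QueuedRequest { input_tokens: 512, max_new_tokens: 512 },
        QueuedRequest { input_tokens: 3_000, max_new_tokens: 1_500 },
        QueuedRequest { input_tokens: 128, max_new_tokens: 64 },
    ];
    // No batch is running yet, so only the batch and prefill weight budgets apply.
    let batch = admit(&mut waiting, &cfg, 0, 0);
    println!("admitted {} of 3 waiting requests", batch.len());
}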

router/src/main.rs

Lines changed: 8 additions & 0 deletions
@@ -33,6 +33,10 @@ struct Args {
     max_total_tokens: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
+    #[clap(default_value = None, long, env)]
+    max_batch_weight: Option<usize>,
+    #[clap(default_value = None, long, env)]
+    max_prefill_weight: Option<usize>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_weight,
+        max_prefill_weight,
         max_waiting_tokens,
         port,
         master_shard_uds_path,
@@ -169,6 +175,8 @@ fn main() -> Result<(), std::io::Error> {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_weight,
+        max_prefill_weight,
         max_waiting_tokens,
         sharded_client,
         tokenizer,
