
Commit ba1aae3

feat(router): dynamic batch sizing
1 parent 709d893 commit ba1aae3

File tree

6 files changed: +342 -303 lines changed

launcher/src/main.rs

Lines changed: 16 additions & 0 deletions
@@ -41,6 +41,10 @@ struct Args {
     max_total_tokens: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
+    #[clap(default_value = None, long, env)]
+    max_batch_weight: Option<usize>,
+    #[clap(default_value = None, long, env)]
+    max_prefill_weight: Option<usize>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]

@@ -93,6 +97,8 @@ fn main() -> ExitCode {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_weight,
+        max_prefill_weight,
         max_waiting_tokens,
         port,
         shard_uds_path,
@@ -392,6 +398,16 @@ fn main() -> ExitCode {
         model_id,
     ];
 
+    if let Some(max_batch_weight) = max_batch_weight {
+        argv.push("--max-batch-weight".to_string());
+        argv.push(max_batch_weight.to_string())
+    }
+
+    if let Some(max_prefill_weight) = max_prefill_weight {
+        argv.push("--max-prefill-weight".to_string());
+        argv.push(max_prefill_weight.to_string())
+    }
+
     // Model optional revision
     if let Some(ref revision) = revision {
         argv.push("--revision".to_string());

router/src/infer.rs

Lines changed: 41 additions & 39 deletions
@@ -15,6 +15,7 @@ use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
+use crate::queue::BatchingConfig;
 
 /// Inference struct
 #[derive(Clone)]

@@ -40,19 +41,24 @@ impl Infer {
         client: ShardedClient,
         validation: Validation,
         max_batch_size: usize,
+        max_batch_weight: usize,
+        max_prefill_weight: usize,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new();
+        let queue = Queue::new(BatchingConfig {
+            size_limit: max_batch_size,
+            weight_limit: max_batch_weight,
+            prefill_weight_limit: max_prefill_weight,
+        });
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
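
Queue::new now takes a BatchingConfig defined in router/src/queue.rs, one of the six changed files whose diff is not included in this excerpt. Based only on the three fields populated above, a minimal sketch of that config might look like the following; the visibility, derives, and doc comments are assumptions:

    /// Sketch only: the limits the queue uses when assembling batches
    /// (the real definition lives in router/src/queue.rs, not shown here).
    #[derive(Debug, Clone)]
    pub(crate) struct BatchingConfig {
        /// Maximum number of requests in a batch (--max-batch-size)
        pub(crate) size_limit: usize,
        /// Maximum total weight of a running batch (--max-batch-weight)
        pub(crate) weight_limit: usize,
        /// Maximum total weight of a single prefill pass (--max-prefill-weight)
        pub(crate) prefill_weight_limit: usize,
    }

Note that Infer::new receives plain usize values here while the CLI exposes Option<usize>, so the defaults for an unset --max-batch-weight or --max-prefill-weight are presumably resolved elsewhere (perhaps in router/src/server.rs, which is also not shown in this excerpt).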
@@ -105,6 +111,7 @@ impl Infer {
105111
// Append the request to the queue
106112
self.queue.append(Entry {
107113
request: valid_request,
114+
generated_tokens: 0,
108115
response_tx,
109116
span: Span::current(),
110117
temp_span: None,
@@ -232,18 +239,11 @@ impl Infer {
232239
/// Batches requests and sends them to the inference server
233240
async fn batching_task(
234241
mut client: ShardedClient,
235-
max_batch_size: usize,
242+
// max_batch_size: usize,
236243
max_waiting_tokens: usize,
237244
queue: Queue,
238245
shared: Arc<Shared>,
239246
) {
240-
// Minimum batch size after which we try to add more requests
241-
let limit_min_batch_size = if max_batch_size > 1 {
242-
(max_batch_size / 2) as u32
243-
} else {
244-
0
245-
};
246-
247247
// Infinite loop
248248
loop {
249249
// Wait for a notification from the Infer struct
@@ -252,8 +252,8 @@ async fn batching_task(
252252
// Get the next batch from the queue
253253
// This batch might be smaller than the maximum batch size if there are not enough requests
254254
// waiting in the queue
255-
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
256-
let mut cached_batch = prefill(&mut client, batch, &mut entries)
255+
while let (_, Some((mut entries, batch, span))) = queue.next_batch(None).await {
256+
let (mut cached_batch, mut some_completed) = prefill(&mut client, batch, &mut entries)
257257
.instrument(span)
258258
.await;
259259
let mut waiting_tokens = 1;
@@ -266,21 +266,16 @@ async fn batching_task(
266266
let mut batches = vec![batch];
267267
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
268268

269-
// If the current batch is too small, we try to add more requests to it
270-
if batch_size <= limit_min_batch_size {
271-
let min_size = match waiting_tokens {
272-
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
273-
// to add a new batch even though its size might be small
274-
_ if waiting_tokens >= max_waiting_tokens => None,
275-
// Minimum size criteria
276-
_ => Some(limit_min_batch_size as usize),
277-
};
278-
279-
// Try to get a new batch
280-
if let Some((mut new_entries, new_batch, span)) = queue
281-
.next_batch(min_size, max_batch_size - batch_size as usize)
282-
.await
283-
{
269+
// Try to extend batch if its size reduced or enough tokens have elapsed since last one
270+
if some_completed || waiting_tokens >= max_waiting_tokens {
271+
272+
// Try to get a new batch - ownership of entries passed in and out
273+
let (
274+
existing_entries, new_entries
275+
) = queue.next_batch(Some(entries)).await;
276+
entries = existing_entries.unwrap();
277+
278+
if let Some((mut new_entries, new_batch, span)) = new_entries {
284279
entries.iter_mut().for_each(|(_, entry)| {
285280
// Create a new span to add the info that this entry is waiting
286281
// because a new batch is being computed
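
The two call sites above, queue.next_batch(None).await and queue.next_batch(Some(entries)).await, imply that next_batch no longer takes the old min_size/max_size arguments: the batching task hands the currently running entries to the queue so it can weigh them against the configured limits, and gets them back together with an optional new batch. A sketch of the signature this suggests, with the body elided (the real implementation is in router/src/queue.rs, which is not part of this excerpt):

    // Queue, Entry, Batch, Span, and IntMap are the types already in scope in
    // the router crate's infer.rs above; this is a signature sketch only.
    impl Queue {
        /// Assumed shape: take ownership of the running entries (if any) so the
        /// queue can account for their size and weight, then return them
        /// alongside an optional newly formed (entries, batch, span) triple.
        pub(crate) async fn next_batch(
            &self,
            existing_entries: Option<IntMap<u64, Entry>>,
        ) -> (
            Option<IntMap<u64, Entry>>,
            Option<(IntMap<u64, Entry>, Batch, Span)>,
        ) {
            todo!() // actual batching logic omitted
        }
    }

This matches the "ownership of entries passed in and out" comment in the hunk above: the running entries are moved into the queue for the sizing decision and handed straight back to the loop.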
@@ -293,7 +288,7 @@ async fn batching_task(
                     });
 
                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                    let (new_cached_batch, _) = prefill(&mut client, new_batch, &mut new_entries)
                         .instrument(span)
                         .await;
                     // Reset waiting counter

@@ -319,7 +314,7 @@ async fn batching_task(
                 entry.temp_span = Some(entry_batch_span);
             });
 
-            cached_batch = decode(&mut client, batches, &mut entries)
+            (cached_batch, some_completed) = decode(&mut client, batches, &mut entries)
                 .instrument(next_batch_span)
                 .await;
             waiting_tokens += 1;

@@ -334,14 +329,14 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> (Option<Batch>, bool) {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
-            filter_send_generations(generations, entries);
+            let some_completed = filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
             let next_batch = match next_batch {

@@ -360,14 +355,14 @@ async fn prefill(
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
-            next_batch
+            (next_batch, some_completed)
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
-            None
+            (None, true)
         }
     }
 }

@@ -377,14 +372,14 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> (Option<Batch>, bool) {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
-            filter_send_generations(generations, entries);
+            let some_completed = filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
             let next_batch = match next_batch {

@@ -403,7 +398,7 @@ async fn decode(
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            next_batch
+            (next_batch, some_completed)
         }
         // If we have an error, we discard the whole batch
         Err(err) => {

@@ -412,7 +407,7 @@ async fn decode(
             }
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
-            None
+            (None, true)
         }
     }
 }

@@ -431,14 +426,16 @@ fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch>
 
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
 /// and filter entries
+/// Return true if any requests completed
 #[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) -> bool {
+    let mut some_stopped = false;
     generations.into_iter().for_each(|generation| {
         let id = generation.request_id;
         // Get entry
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
-            .get(&id)
+            .get_mut(&id)
             .expect("ID not found in entries. This is a bug.");
 
         // Create and enter a span to link this function back to the entry

@@ -451,9 +448,14 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
             err
         }).unwrap_or(true);
         if stopped {
+            some_stopped = true;
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
+        } else {
+            // Increment generated token count
+            entry.generated_tokens += 1;
         }
     });
+    return some_stopped;
 }
 
 /// Send responses through the `entry` response channel
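
The new generated_tokens counter on Entry, incremented above for every request that has not yet stopped, is what lets the queue treat batch capacity as a token budget rather than a fixed request count. The actual weight computation lives in router/src/queue.rs and is not part of this excerpt; the sketch below is purely illustrative, and both the helper names and the formula are assumptions:

    /// Illustrative only: one plausible per-request weight, counting the
    /// prompt tokens plus everything generated so far (i.e. the tokens the
    /// request currently occupies in the attention cache).
    fn entry_weight(input_length: usize, generated_tokens: usize) -> usize {
        input_length + generated_tokens
    }

    /// Sum of per-request weights for a candidate batch, to be compared
    /// against BatchingConfig::weight_limit; for a prefill-only batch
    /// generated_tokens is still zero, so the same sum can be checked
    /// against prefill_weight_limit.
    fn batch_weight(requests: &[(usize, usize)]) -> usize {
        requests
            .iter()
            .map(|&(input_length, generated)| entry_weight(input_length, generated))
            .sum()
    }

Under this reading, a completed request frees weight immediately, which is why prefill and decode now report whether any request finished: the batching task only asks the queue for more work when some_completed is true or enough waiting tokens have elapsed.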

router/src/main.rs

Lines changed: 8 additions & 0 deletions
@@ -33,6 +33,10 @@ struct Args {
     max_total_tokens: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
+    #[clap(default_value = None, long, env)]
+    max_batch_weight: Option<usize>,
+    #[clap(default_value = None, long, env)]
+    max_prefill_weight: Option<usize>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]

@@ -64,6 +68,8 @@ fn main() -> Result<(), std::io::Error> {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_weight,
+        max_prefill_weight,
         max_waiting_tokens,
         port,
         master_shard_uds_path,

@@ -169,6 +175,8 @@ fn main() -> Result<(), std::io::Error> {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_weight,
+        max_prefill_weight,
         max_waiting_tokens,
         sharded_client,
         tokenizer,
