Commit 061688b

feat(router): use number of tokens in batch as input for dynamic batching
1 parent 6ded76a commit 061688b

14 files changed: +357 -147 lines


launcher/src/main.rs

Lines changed: 19 additions & 4 deletions
@@ -39,8 +39,12 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "1512", long, env)]
     max_total_tokens: usize,
-    #[clap(default_value = "32", long, env)]
-    max_batch_size: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "32000", long, env)]
+    max_batch_total_tokens: u32,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -93,6 +97,8 @@ fn main() -> ExitCode {
         max_input_length,
         max_total_tokens,
         max_batch_size,
+        max_batch_total_tokens,
+        waiting_served_ratio,
         max_waiting_tokens,
         port,
         shard_uds_path,
@@ -380,8 +386,8 @@ fn main() -> ExitCode {
         max_input_length.to_string(),
         "--max-total-tokens".to_string(),
         max_total_tokens.to_string(),
-        "--max-batch-size".to_string(),
-        max_batch_size.to_string(),
+        "--waiting-served-ratio".to_string(),
+        waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
         max_waiting_tokens.to_string(),
         "--port".to_string(),
@@ -392,6 +398,15 @@ fn main() -> ExitCode {
         model_id,
     ];

+    // Deprecate max_batch_size
+    if let Some(max_batch_size) = max_batch_size {
+        argv.push("--max-batch-size".to_string());
+        argv.push(max_batch_size.to_string())
+    } else {
+        argv.push("--max-batch-total-tokens".to_string());
+        argv.push(max_batch_total_tokens.to_string())
+    }
+
     // Model optional revision
     if let Some(ref revision) = revision {
         argv.push("--revision".to_string());

proto/generate.proto

Lines changed: 17 additions & 0 deletions
@@ -9,6 +9,8 @@ service TextGenerationService {
     rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache
     rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
+    /// Remove requests from a cached batch
+    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
     /// Prefill batch and decode first token
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
@@ -89,6 +91,8 @@ message Batch {
     repeated Request requests = 2;
     /// Batch size (==len(requests))
     uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
 }

 enum FinishReason {
@@ -134,6 +138,19 @@ message Generation {
     GeneratedText generated_text = 7;
 }

+message FilterBatchRequest {
+    /// Batch ID
+    uint64 batch_id = 1;
+    /// Requests to keep
+    repeated Request keep_requests = 2;
+}
+
+message FilterBatchResponse {
+    /// Filtered Batch (cached)
+    Batch batch = 1;
+}
+
+
 message PrefillRequest {
     /// Batch
     Batch batch = 1;
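
On the Rust side these messages are consumed through tonic/prost generated types. Here is a self-contained sketch of their shapes and of how a keep_requests list is built; the structs below are hand-written stand-ins for illustration, not the generated code.

/// Hand-written mirrors of the proto messages, for illustration only.
#[derive(Clone, Debug)]
struct Request {
    id: u64,
}

#[derive(Clone, Debug)]
struct Batch {
    id: u64,
    requests: Vec<Request>,
    /// Batch size (== requests.len())
    size: u32,
    /// Maximum number of tokens this batch will grow to
    max_tokens: u32,
}

#[derive(Debug)]
struct FilterBatchRequest {
    batch_id: u64,
    keep_requests: Vec<Request>,
}

fn main() {
    let batch = Batch {
        id: 0,
        requests: vec![Request { id: 1 }, Request { id: 2 }, Request { id: 3 }],
        size: 3,
        max_tokens: 4_096,
    };
    // Ask the shards to drop everything except request 2 from their cached batch.
    let filter = FilterBatchRequest {
        batch_id: batch.id,
        keep_requests: batch.requests.into_iter().filter(|r| r.id == 2).collect(),
    };
    println!("{filter:?}");
}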

router/client/src/client.rs

Lines changed: 16 additions & 0 deletions
@@ -70,6 +70,22 @@ impl Client {
         Ok(())
     }

+    /// Filter a cached batch
+    #[instrument(skip(self))]
+    pub async fn filter_batch(
+        &mut self,
+        batch_id: u64,
+        keep_requests: Vec<Request>,
+    ) -> Result<Option<Batch>> {
+        let request = tonic::Request::new(FilterBatchRequest {
+            batch_id,
+            keep_requests,
+        })
+        .inject_context();
+        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
+        Ok(filtered_batch.batch)
+    }
+
     /// Generate one token for each request in the given batch
     ///
     /// Returns Generation for each request in batch

router/client/src/sharded_client.rs

Lines changed: 17 additions & 1 deletion
@@ -1,6 +1,6 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation, ShardInfo};
+use crate::{Batch, Client, Generation, Request, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
@@ -59,6 +59,22 @@ impl ShardedClient {
         join_all(futures).await.into_iter().collect()
     }

+    /// Filter a cached batch
+    #[instrument(skip(self))]
+    pub async fn filter_batch(
+        &mut self,
+        batch_id: u64,
+        keep_requests: Vec<Request>,
+    ) -> Result<Option<Batch>> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.filter_batch(batch_id, keep_requests.clone())))
+            .collect();
+        // all shards return the same message
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Generate one token for each request in the given batch
     ///
     /// Returns Generation for each request in batch
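
ShardedClient::filter_batch fans the same call out to every shard and keeps a single reply, since all shards hold identical batch state. A self-contained sketch of that join_all-then-pop pattern follows, assuming the futures and tokio crates; the shard call is a stub, and only the fan-out shape mirrors the code above.

use futures::future::join_all;

/// Stand-in for the reply a shard would send back.
#[derive(Debug)]
struct Reply {
    batch_id: u64,
}

/// Stub for a per-shard RPC: every shard mutates the same cached batch
/// and therefore answers with the same message.
async fn shard_filter_batch(_shard: usize, batch_id: u64) -> Reply {
    Reply { batch_id }
}

#[tokio::main]
async fn main() {
    let shards = [0, 1, 2, 3];
    let futures: Vec<_> = shards
        .iter()
        .map(|&shard| shard_filter_batch(shard, 42))
        .collect();
    // All shards return the same message, so popping the last reply is enough.
    let reply = join_all(futures).await.pop().unwrap();
    println!("{reply:?}");
}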

router/src/infer.rs

Lines changed: 81 additions & 81 deletions
@@ -39,20 +39,23 @@ impl Infer {
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
-        max_batch_size: usize,
+        waiting_served_ratio: f32,
+        max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
+        requires_padding: bool,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new();
+        let queue = Queue::new(requires_padding);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });

         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
+            waiting_served_ratio,
+            max_batch_total_tokens,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
@@ -232,18 +235,12 @@ impl Infer {
 /// Batches requests and sends them to the inference server
 async fn batching_task(
     mut client: ShardedClient,
-    max_batch_size: usize,
+    waiting_served_ratio: f32,
+    max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
 ) {
-    // Minimum batch size after which we try to add more requests
-    let limit_min_batch_size = if max_batch_size > 1 {
-        (max_batch_size / 2) as u32
-    } else {
-        0
-    };
-
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
@@ -252,7 +249,9 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
-        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
+        while let Some((mut entries, batch, span)) =
+            queue.next_batch(None, max_batch_total_tokens).await
+        {
             let mut cached_batch = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
@@ -263,48 +262,50 @@ async fn batching_task(
             while let Some(batch) = cached_batch {
                 // Get current batch info
                 let batch_size = batch.size;
+                let batch_max_tokens = batch.max_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size", batch_size as f64);

-                // If the current batch is too small, we try to add more requests to it
-                if batch_size <= limit_min_batch_size {
-                    let min_size = match waiting_tokens {
-                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                        // to add a new batch even though its size might be small
-                        _ if waiting_tokens >= max_waiting_tokens => None,
-                        // Minimum size criteria
-                        _ => Some(limit_min_batch_size as usize),
-                    };
-
-                    // Try to get a new batch
-                    if let Some((mut new_entries, new_batch, span)) = queue
-                        .next_batch(min_size, max_batch_size - batch_size as usize)
-                        .await
-                    {
-                        entries.iter_mut().for_each(|(_, entry)| {
-                            // Create a new span to add the info that this entry is waiting
-                            // because a new batch is being computed
-                            let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                            // Add relationships
-                            span.follows_from(&entry_waiting_span);
-                            entry_waiting_span.follows_from(&span);
-                            // Update entry
-                            entry.temp_span = Some(entry_waiting_span);
-                        });
-
-                        // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                            .instrument(span)
-                            .await;
-                        // Reset waiting counter
-                        waiting_tokens = 1;
-                        // Extend current batch with the new batch
-                        if let Some(new_cached_batch) = new_cached_batch {
-                            entries.extend(new_entries);
-                            batches.push(new_cached_batch);
-                        }
+                let min_size = match waiting_tokens {
+                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                    // to add a new batch even though its size might be small
+                    _ if waiting_tokens >= max_waiting_tokens => None,
+                    // Minimum size criteria
+                    _ => Some((batch_size as f32 * waiting_served_ratio).floor() as usize),
+                };
+
+                let token_budget = max_batch_total_tokens - batch_max_tokens;
+
+                // Try to get a new batch
+                if let Some((mut new_entries, new_batch, span)) =
+                    queue.next_batch(min_size, token_budget).await
+                {
+                    // Tracking metrics
+
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                        // Add relationships
+                        span.follows_from(&entry_waiting_span);
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        .instrument(span)
+                        .await;
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        entries.extend(new_entries);
+                        batches.push(new_cached_batch);
                     }
                 }
+
                 // Create span for this batch to add context to inference calls
                 let next_batch_size = entries.len();
                 let next_batch_span =
@@ -341,22 +342,11 @@ async fn prefill(

     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
-            let next_batch = match next_batch {
-                None => None,
-                Some(batch) => {
-                    let id = batch.id;
-                    let next_batch = filter_batch(batch, entries);
-                    // Next batch is now empty
-                    // Clear it from the Python shards cache
-                    if next_batch.is_none() {
-                        let _ = client.clear_cache(Some(id)).await;
-                    }
-                    next_batch
-                }
-            };
+            let next_batch = filter_batch(client, next_batch, entries).await;

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
@@ -384,22 +374,11 @@ async fn decode(

     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
-            let next_batch = match next_batch {
-                None => None,
-                Some(batch) => {
-                    let id = batch.id;
-                    let next_batch = filter_batch(batch, entries);
-                    // Next batch is now empty
-                    // Clear it from the Python shards cache
-                    if next_batch.is_none() {
-                        let _ = client.clear_cache(Some(id)).await;
-                    }
-                    next_batch
-                }
-            };
+            let next_batch = filter_batch(client, next_batch, entries).await;

             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
@@ -419,14 +398,35 @@ async fn decode(

 /// Filter a `batch` and remove all requests not present in `entries`
 #[instrument(skip_all)]
-fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
+async fn filter_batch(
+    client: &mut ShardedClient,
+    next_batch: Option<Batch>,
+    entries: &IntMap<u64, Entry>,
+) -> Option<Batch> {
+    let mut batch = next_batch?;
+
+    // No need to filter
+    if batch.size as usize == entries.len() {
+        return Some(batch);
+    }
+
+    let id = batch.id;
+
+    // Retain only requests that are still in entries
     batch.requests.retain(|r| entries.contains_key(&r.id));
-    let size = batch.requests.len();
-    if size == 0 {
-        return None;
+
+    if batch.requests.is_empty() {
+        // All requests have been filtered out
+        // Next batch is now empty
+        // Clear it from the Python shards cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.clear_cache(Some(id)).await.unwrap();
+        None
+    } else {
+        // Filter Python shard cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.filter_batch(id, batch.requests).await.unwrap()
     }
-    batch.size = size as u32;
-    Some(batch)
 }

 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
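
Taken together, these hunks replace the old request-count heuristic with a token budget: a new prefill is only attempted when the queue can supply at least waiting_served_ratio times the running batch size (unless max_waiting_tokens decode steps have passed without onboarding anything), and the new batch may only claim the tokens the running batch has not already reserved. Below is a self-contained sketch of just that arithmetic; CachedBatch and next_batch_params are stand-ins for illustration, not the router's types.

/// Stand-in for the cached batch returned by the shards (mirrors Batch.size and Batch.max_tokens).
struct CachedBatch {
    size: u32,
    max_tokens: u32,
}

/// Illustrative helper: compute the (min_size, token_budget) pair that
/// the batching loop passes to queue.next_batch for the next prefill attempt.
fn next_batch_params(
    batch: &CachedBatch,
    waiting_tokens: usize,
    max_waiting_tokens: usize,
    waiting_served_ratio: f32,
    max_batch_total_tokens: u32,
) -> (Option<usize>, u32) {
    // If no new requests were onboarded for max_waiting_tokens decode steps,
    // accept a new batch of any size; otherwise require a minimum size.
    let min_size = if waiting_tokens >= max_waiting_tokens {
        None
    } else {
        Some((batch.size as f32 * waiting_served_ratio).floor() as usize)
    };
    // The new batch may only use tokens the running batch has not reserved.
    let token_budget = max_batch_total_tokens - batch.max_tokens;
    (min_size, token_budget)
}

fn main() {
    // Example: 8 running requests reserving 12_000 of the 32_000 token cap,
    // with the default waiting_served_ratio of 1.2.
    let batch = CachedBatch { size: 8, max_tokens: 12_000 };
    let (min_size, budget) = next_batch_params(&batch, 5, 20, 1.2, 32_000);
    assert_eq!(min_size, Some(9)); // floor(8 * 1.2)
    assert_eq!(budget, 20_000);
    println!("min_size = {min_size:?}, token_budget = {budget}");
}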
