@@ -39,20 +39,23 @@ impl Infer {
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
-        max_batch_size: usize,
+        waiting_served_ratio: f32,
+        max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
+        requires_padding: bool,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new();
+        let queue = Queue::new(requires_padding);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
+            waiting_served_ratio,
+            max_batch_total_tokens,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
@@ -232,18 +235,12 @@ impl Infer {
 /// Batches requests and sends them to the inference server
 async fn batching_task(
     mut client: ShardedClient,
-    max_batch_size: usize,
+    waiting_served_ratio: f32,
+    max_batch_total_tokens: u32,
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
 ) {
-    // Minimum batch size after which we try to add more requests
-    let limit_min_batch_size = if max_batch_size > 1 {
-        (max_batch_size / 2) as u32
-    } else {
-        0
-    };
-
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
@@ -252,7 +249,9 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
-        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
+        while let Some((mut entries, batch, span)) =
+            queue.next_batch(None, max_batch_total_tokens).await
+        {
             let mut cached_batch = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
@@ -263,48 +262,50 @@ async fn batching_task(
             while let Some(batch) = cached_batch {
                 // Get current batch info
                 let batch_size = batch.size;
+                let batch_max_tokens = batch.max_tokens;
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size", batch_size as f64);
 
-                // If the current batch is too small, we try to add more requests to it
-                if batch_size <= limit_min_batch_size {
-                    let min_size = match waiting_tokens {
-                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                        // to add a new batch even though its size might be small
-                        _ if waiting_tokens >= max_waiting_tokens => None,
-                        // Minimum size criteria
-                        _ => Some(limit_min_batch_size as usize),
-                    };
-
-                    // Try to get a new batch
-                    if let Some((mut new_entries, new_batch, span)) = queue
-                        .next_batch(min_size, max_batch_size - batch_size as usize)
-                        .await
-                    {
-                        entries.iter_mut().for_each(|(_, entry)| {
-                            // Create a new span to add the info that this entry is waiting
-                            // because a new batch is being computed
-                            let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
-                            // Add relationships
-                            span.follows_from(&entry_waiting_span);
-                            entry_waiting_span.follows_from(&span);
-                            // Update entry
-                            entry.temp_span = Some(entry_waiting_span);
-                        });
-
-                        // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                            .instrument(span)
-                            .await;
-                        // Reset waiting counter
-                        waiting_tokens = 1;
-                        // Extend current batch with the new batch
-                        if let Some(new_cached_batch) = new_cached_batch {
-                            entries.extend(new_entries);
-                            batches.push(new_cached_batch);
-                        }
+                let min_size = match waiting_tokens {
+                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                    // to add a new batch even though its size might be small
+                    _ if waiting_tokens >= max_waiting_tokens => None,
+                    // Minimum size criteria
+                    _ => Some((batch_size as f32 * waiting_served_ratio).floor() as usize),
+                };
+
+                let token_budget = max_batch_total_tokens - batch_max_tokens;
+
+                // Try to get a new batch
+                if let Some((mut new_entries, new_batch, span)) =
+                    queue.next_batch(min_size, token_budget).await
+                {
+                    // Tracking metrics
+
+                    entries.iter_mut().for_each(|(_, entry)| {
+                        // Create a new span to add the info that this entry is waiting
+                        // because a new batch is being computed
+                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
+                        // Add relationships
+                        span.follows_from(&entry_waiting_span);
+                        entry_waiting_span.follows_from(&span);
+                        // Update entry
+                        entry.temp_span = Some(entry_waiting_span);
+                    });
+
+                    // Generate one token for this new batch to have the attention past in cache
+                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        .instrument(span)
+                        .await;
+                    // Reset waiting counter
+                    waiting_tokens = 1;
+                    // Extend current batch with the new batch
+                    if let Some(new_cached_batch) = new_cached_batch {
+                        entries.extend(new_entries);
+                        batches.push(new_cached_batch);
                     }
                 }
+
                 // Create span for this batch to add context to inference calls
                 let next_batch_size = entries.len();
                 let next_batch_span =
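
Note: the hunk above drops the old `limit_min_batch_size` heuristic (top up only when the running batch held no more than half of `max_batch_size` requests) and sizes concatenation by tokens instead. Below is a minimal, self-contained sketch of that decision, under the assumption that it can be isolated as a pure function; `concat_policy` is not a function in the router, the values in `main` are illustrative, and `saturating_sub` is used defensively where the code above subtracts directly.

```rust
/// Simplified sketch of the new concatenation policy: decide the minimum size
/// and token budget to pass to `queue.next_batch` for the next top-up attempt.
fn concat_policy(
    batch_size: u32,           // requests currently decoding
    batch_max_tokens: u32,     // tokens already reserved by the running batch
    waiting_tokens: usize,     // decode steps since a batch was last added
    waiting_served_ratio: f32, // how large a waiting batch must be, relative to the running one
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
) -> (Option<usize>, u32) {
    // Waited long enough: accept a new batch of any size.
    // Otherwise require at least `waiting_served_ratio` times the running batch.
    let min_size = if waiting_tokens >= max_waiting_tokens {
        None
    } else {
        Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
    };
    // Tokens left once the running batch is accounted for.
    let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
    (min_size, token_budget)
}

fn main() {
    // 8 running requests holding 6_000 tokens, 3 decode steps since the last concat.
    let (min_size, token_budget) = concat_policy(8, 6_000, 3, 1.2, 32_000, 20);
    assert_eq!(min_size, Some(9)); // floor(8 * 1.2)
    assert_eq!(token_budget, 26_000);
    println!("min_size={min_size:?}, token_budget={token_budget}");
}
```

Budgeting by tokens rather than request count lets the router pack many short requests or a few long ones under the same memory ceiling, which is what `max_batch_total_tokens` caps.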
@@ -341,22 +342,11 @@ async fn prefill(
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
-            let next_batch = match next_batch {
-                None => None,
-                Some(batch) => {
-                    let id = batch.id;
-                    let next_batch = filter_batch(batch, entries);
-                    // Next batch is now empty
-                    // Clear it from the Python shards cache
-                    if next_batch.is_none() {
-                        let _ = client.clear_cache(Some(id)).await;
-                    }
-                    next_batch
-                }
-            };
+            let next_batch = filter_batch(client, next_batch, entries).await;
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
@@ -384,22 +374,11 @@ async fn decode(
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
-            let next_batch = match next_batch {
-                None => None,
-                Some(batch) => {
-                    let id = batch.id;
-                    let next_batch = filter_batch(batch, entries);
-                    // Next batch is now empty
-                    // Clear it from the Python shards cache
-                    if next_batch.is_none() {
-                        let _ = client.clear_cache(Some(id)).await;
-                    }
-                    next_batch
-                }
-            };
+            let next_batch = filter_batch(client, next_batch, entries).await;
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
@@ -419,14 +398,35 @@ async fn decode(
 
 /// Filter a `batch` and remove all requests not present in `entries`
 #[instrument(skip_all)]
-fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch> {
+async fn filter_batch(
+    client: &mut ShardedClient,
+    next_batch: Option<Batch>,
+    entries: &IntMap<u64, Entry>,
+) -> Option<Batch> {
+    let mut batch = next_batch?;
+
+    // No need to filter
+    if batch.size as usize == entries.len() {
+        return Some(batch);
+    }
+
+    let id = batch.id;
+
+    // Retain only requests that are still in entries
     batch.requests.retain(|r| entries.contains_key(&r.id));
-    let size = batch.requests.len();
-    if size == 0 {
-        return None;
+
+    if batch.requests.is_empty() {
+        // All requests have been filtered out
+        // Next batch is now empty
+        // Clear it from the Python shards cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.clear_cache(Some(id)).await.unwrap();
+        None
+    } else {
+        // Filter Python shard cache
+        // We unwrap here as we need to panic since we cannot recover if this method fails
+        client.filter_batch(id, batch.requests).await.unwrap()
     }
-    batch.size = size as u32;
-    Some(batch)
 }
 
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
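
Note: the rewritten `filter_batch` keeps the Python shards' cache in sync with the router's view: an untouched batch skips the RPC, a partially finished batch is filtered shard-side, and a fully finished batch is cleared. Below is a standalone sketch of that three-way decision; the real code issues `ShardedClient` RPCs, whereas here they are only modeled as a returned `ShardAction`, and `plan_filter` is not part of the router. The size-equality fast path relies on the router's invariant that `entries` only ever holds requests from the batch.

```rust
use std::collections::HashMap;

/// What the router would ask the shards to do for one cached batch.
#[derive(Debug, PartialEq)]
enum ShardAction {
    Keep,                  // every request is still running: no RPC needed
    FilterBatch(Vec<u64>), // keep only these request ids in the shard cache
    ClearCache(u64),       // nothing left: drop the whole batch from the cache
}

fn plan_filter(batch_id: u64, batch_requests: &[u64], entries: &HashMap<u64, ()>) -> ShardAction {
    // Fast path: sizes match, so nothing was stopped
    // (entries is always a subset of the batch's requests).
    if batch_requests.len() == entries.len() {
        return ShardAction::Keep;
    }
    // Retain only requests that are still in entries.
    let kept: Vec<u64> = batch_requests
        .iter()
        .copied()
        .filter(|id| entries.contains_key(id))
        .collect();
    if kept.is_empty() {
        ShardAction::ClearCache(batch_id)
    } else {
        ShardAction::FilterBatch(kept)
    }
}

fn main() {
    // Requests 1 and 3 are still generating; 2, 4 and 5 have stopped.
    let entries: HashMap<u64, ()> = HashMap::from([(1, ()), (3, ())]);
    assert_eq!(plan_filter(7, &[1, 2, 3], &entries), ShardAction::FilterBatch(vec![1, 3]));
    assert_eq!(plan_filter(7, &[2, 4, 5], &entries), ShardAction::ClearCache(7));
    assert_eq!(plan_filter(7, &[1, 3], &entries), ShardAction::Keep);
}
```

In the diff above, the `FilterBatch` and `ClearCache` cases map to `client.filter_batch(id, batch.requests)` and `client.clear_cache(Some(id))`, both of which panic on failure because the router cannot recover from a desynchronized shard cache.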