@@ -14,6 +14,7 @@ use thiserror::Error;
 use tokio::sync::{Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
+use crate::queue::BatchingConfig;
 
 /// Inference struct
 #[derive(Clone)]
@@ -39,19 +40,24 @@ impl Infer {
         client: ShardedClient,
         validation: Validation,
         max_batch_size: usize,
+        max_batch_weight: usize,
+        max_prefill_weight: usize,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new();
+        let queue = Queue::new(BatchingConfig {
+            size_limit: max_batch_size,
+            weight_limit: max_batch_weight,
+            prefill_weight_limit: max_prefill_weight,
+        });
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
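
The `BatchingConfig` handed to `Queue::new` above lives in `crate::queue` and is not shown in this diff. A minimal sketch of the struct these three fields imply, with field meanings inferred from the parameter names (the doc comments and derives here are assumptions, not part of the commit):

// Hypothetical sketch only -- the real definition is in queue.rs, outside this diff.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BatchingConfig {
    /// Maximum number of requests per batch (the old `max_batch_size`).
    pub(crate) size_limit: usize,
    /// Maximum total token weight of a running batch.
    pub(crate) weight_limit: usize,
    /// Maximum token weight admitted into a single prefill pass.
    pub(crate) prefill_weight_limit: usize,
}
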
@@ -99,6 +105,7 @@ impl Infer {
         // Append the request to the queue
         self.queue.append(Entry {
             request: valid_request,
+            generated_tokens: 0,
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -227,18 +234,11 @@ impl Infer {
 /// Batches requests and sends them to the inference server
 async fn batching_task(
     mut client: ShardedClient,
-    max_batch_size: usize,
+    // max_batch_size: usize,
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
 ) {
-    // Minimum batch size after which we try to add more requests
-    let limit_min_batch_size = if max_batch_size > 1 {
-        (max_batch_size / 2) as u32
-    } else {
-        0
-    };
-
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
@@ -247,8 +247,8 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
-        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+        while let (_, Some((mut entries, batch, span))) = queue.next_batch(None).await {
+            let (mut cached_batch, mut some_completed) = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -261,21 +261,16 @@ async fn batching_task(
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size", batch_size as f64);
 
-                // If the current batch is too small, we try to add more requests to it
-                if batch_size <= limit_min_batch_size {
-                    let min_size = match waiting_tokens {
-                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                        // to add a new batch even though its size might be small
-                        _ if waiting_tokens >= max_waiting_tokens => None,
-                        // Minimum size criteria
-                        _ => Some(limit_min_batch_size as usize),
-                    };
-
-                    // Try to get a new batch
-                    if let Some((mut new_entries, new_batch, span)) = queue
-                        .next_batch(min_size, max_batch_size - batch_size as usize)
-                        .await
-                    {
+                // Try to extend batch if its size reduced or enough tokens have elapsed since last one
+                if some_completed || waiting_tokens >= max_waiting_tokens {
+
+                    // Try to get a new batch - ownership of entries passed in and out
+                    let (
+                        existing_entries, new_entries
+                    ) = queue.next_batch(Some(entries)).await;
+                    entries = existing_entries.unwrap();
+
+                    if let Some((mut new_entries, new_batch, span)) = new_entries {
                         let new_batch_size = new_batch.size;
                         entries.iter_mut().for_each(|(_, entry)| {
                             // Create a new span to add the info that this entry is waiting
@@ -290,7 +285,7 @@ async fn batching_task(
                         });
 
                         // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        let (new_cached_batch, _) = prefill(&mut client, new_batch, &mut new_entries)
                             .instrument(span)
                             .await;
                         // Reset waiting counter
@@ -317,7 +312,7 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });
 
-                cached_batch = decode(&mut client, batches, &mut entries)
+                (cached_batch, some_completed) = decode(&mut client, batches, &mut entries)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
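
The loop above assumes a reworked `Queue::next_batch` that takes ownership of the in-flight entries and hands them back alongside an optional new batch. A signature sketch inferred from the two call sites (`next_batch(None)` and `next_batch(Some(entries))`); the parameter name and doc text are assumptions, and the real implementation lives in `queue.rs`:

impl Queue {
    /// Accepts the entries of the running batch so the queue can account for their
    /// weight, then returns them together with a new batch if one fits the limits.
    pub(crate) async fn next_batch(
        &self,
        existing_entries: Option<IntMap<u64, Entry>>,
    ) -> (
        Option<IntMap<u64, Entry>>,                // the entries passed in, handed back to the caller
        Option<(IntMap<u64, Entry>, Batch, Span)>, // new entries, their Batch, and a tracing Span
    ) {
        todo!() // sketch only
    }
}
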
@@ -332,24 +327,24 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> (Option<Batch>, bool) {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            let some_completed = send_generations(generations, entries);
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
-            next_batch
+            (next_batch, some_completed)
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
-            None
+            (None, true)
         }
     }
 }
@@ -359,22 +354,22 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> (Option<Batch>, bool) {
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
-            send_generations(generations, entries);
+            let some_completed = send_generations(generations, entries);
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            next_batch
+            (next_batch, some_completed)
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
-            None
+            (None, true)
         }
     }
 }
@@ -398,13 +393,15 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 }
 
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
+/// Return true if any requests completed
 #[instrument(skip_all)]
-fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) -> bool {
+    let mut some_completed = false;
     generations.into_iter().for_each(|generation| {
         // Get entry
         // We can `expect` here as the request id should always be in the entries
         let entry = entries
-            .get(&generation.request_id)
+            .get_mut(&generation.request_id)
             .expect("ID not found in entries. This is a bug.");
 
         // Create and enter a span to link this function back to the entry
@@ -445,7 +442,10 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
                     start: entry.batch_time.unwrap(),
                 }))
                 .unwrap_or(());
+            some_completed = true;
         } else {
+            // Increment generated token count
+            entry.generated_tokens += 1;
             // Send message
             // unwrap_or is valid here as we don't care if the receiver is gone.
             entry
@@ -454,6 +454,8 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
                 .unwrap_or(());
         }
     });
+
+    return some_completed;
 }
 
 #[derive(Debug)]
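
The new `generated_tokens` field, set to `0` on append and incremented per token in `send_generations`, implies a companion change to `Entry` in `queue.rs` that this diff does not show. A sketch of the assumed field (the integer width is a guess; the remaining fields are unchanged):

pub(crate) struct Entry {
    pub request: ValidGenerateRequest,
    /// Tokens generated so far for this request, used when weighing in-flight batches.
    pub generated_tokens: u32,
    // response_tx, span, temp_span, queue_time, batch_time as before
}
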