@@ -15,6 +15,7 @@ use thiserror::Error;
 use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Instrument, Span};
+use crate::queue::BatchingConfig;
 
 /// Inference struct
 #[derive(Clone)]
@@ -40,19 +41,24 @@ impl Infer {
         client: ShardedClient,
         validation: Validation,
         max_batch_size: usize,
+        max_batch_weight: usize,
+        max_prefill_weight: usize,
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
     ) -> Self {
         // Infer shared state
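+        // Size and weight limits for batching are now carried by the queue's BatchingConfig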
-        let queue = Queue::new();
+        let queue = Queue::new(BatchingConfig {
+            size_limit: max_batch_size,
+            weight_limit: max_batch_weight,
+            prefill_weight_limit: max_prefill_weight,
+        });
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             client,
-            max_batch_size,
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
@@ -105,6 +111,7 @@ impl Infer {
         // Append the request to the queue
         self.queue.append(Entry {
             request: valid_request,
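+            // Tokens generated so far for this request; incremented in filter_send_generations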
+            generated_tokens: 0,
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -232,18 +239,11 @@ impl Infer {
 /// Batches requests and sends them to the inference server
 async fn batching_task(
     mut client: ShardedClient,
-    max_batch_size: usize,
+    // max_batch_size: usize,
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
 ) {
-    // Minimum batch size after which we try to add more requests
-    let limit_min_batch_size = if max_batch_size > 1 {
-        (max_batch_size / 2) as u32
-    } else {
-        0
-    };
-
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
@@ -252,8 +252,8 @@ async fn batching_task(
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
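+        // next_batch now takes ownership of any existing entries and returns them alongside an
+        // optional new batch; nothing is passed in here, so the first tuple element is ignored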
-        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+        while let (_, Some((mut entries, batch, span))) = queue.next_batch(None).await {
+            let (mut cached_batch, mut some_completed) = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -266,21 +266,16 @@ async fn batching_task(
                 let mut batches = vec![batch];
                 metrics::gauge!("tgi_batch_current_size", batch_size as f64);
 
-                // If the current batch is too small, we try to add more requests to it
-                if batch_size <= limit_min_batch_size {
-                    let min_size = match waiting_tokens {
-                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-                        // to add a new batch even though its size might be small
-                        _ if waiting_tokens >= max_waiting_tokens => None,
-                        // Minimum size criteria
-                        _ => Some(limit_min_batch_size as usize),
-                    };
-
-                    // Try to get a new batch
-                    if let Some((mut new_entries, new_batch, span)) = queue
-                        .next_batch(min_size, max_batch_size - batch_size as usize)
-                        .await
-                    {
+                // Try to extend the batch if its size has reduced or enough tokens have elapsed
+                // since the last prefill
+                if some_completed || waiting_tokens >= max_waiting_tokens {
+                    // Try to get a new batch - ownership of the entries is passed in and out
+                    let (existing_entries, new_entries) = queue.next_batch(Some(entries)).await;
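+                    // Existing entries are always handed back when they are passed in, so this unwrap is safe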
+                    entries = existing_entries.unwrap();
+
+                    if let Some((mut new_entries, new_batch, span)) = new_entries {
                         entries.iter_mut().for_each(|(_, entry)| {
                             // Create a new span to add the info that this entry is waiting
                             // because a new batch is being computed
@@ -293,7 +288,7 @@ async fn batching_task(
                         });
 
                         // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                        let (new_cached_batch, _) = prefill(&mut client, new_batch, &mut new_entries)
                             .instrument(span)
                             .await;
                         // Reset waiting counter
@@ -319,7 +314,7 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });
 
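+                // some_completed feeds the batch-extension check at the top of this loop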
-                cached_batch = decode(&mut client, batches, &mut entries)
+                (cached_batch, some_completed) = decode(&mut client, batches, &mut entries)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
@@ -334,14 +329,14 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> (Option<Batch>, bool) {
     let start_time = Instant::now();
     let batch_id = batch.id;
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
-            filter_send_generations(generations, entries);
+            let some_completed = filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
             let next_batch = match next_batch {
@@ -360,14 +355,14 @@ async fn prefill(
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
-            next_batch
+            (next_batch, some_completed)
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
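+            // The whole batch is discarded here, so signal that requests completed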
-            None
+            (None, true)
         }
     }
 }
@@ -377,14 +372,14 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
-) -> Option<Batch> {
+) -> (Option<Batch>, bool) {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
-            filter_send_generations(generations, entries);
+            let some_completed = filter_send_generations(generations, entries);
 
             // Filter next batch and remove requests that were stopped
             let next_batch = match next_batch {
@@ -403,7 +398,7 @@ async fn decode(
 
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
-            next_batch
+            (next_batch, some_completed)
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
@@ -412,7 +407,7 @@ async fn decode(
             }
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
-            None
+            (None, true)
         }
     }
 }
@@ -431,14 +426,16 @@ fn filter_batch(mut batch: Batch, entries: &IntMap<u64, Entry>) -> Option<Batch>
 
 /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
 /// and filter entries
+/// Return true if any requests completed
 #[instrument(skip_all)]
-fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) -> bool {
+    let mut some_stopped = false;
     generations.into_iter().for_each(|generation| {
         let id = generation.request_id;
         // Get entry
         // We can `expect` here as the request id should always be in the entries
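+        // Mutable access is needed so that generated_tokens can be updated below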
         let entry = entries
-            .get(&id)
+            .get_mut(&id)
             .expect("ID not found in entries. This is a bug.");
 
         // Create and enter a span to link this function back to the entry
@@ -451,9 +448,14 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
             err
         }).unwrap_or(true);
         if stopped {
+            some_stopped = true;
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
+        } else {
+            // Increment generated token count
+            entry.generated_tokens += 1;
         }
     });
+    return some_stopped;
 }
 
 /// Send responses through the `entry` response channel