2222from tensorrt_llm ._torch .pyexecutor .seq_slot_manager import SeqSlotManager
2323from tensorrt_llm ._utils import (customized_gc_thresholds , global_mpi_rank ,
2424 is_trace_enabled , nvtx_range , trace_func )
25+ from tensorrt_llm .bindings .exceptions import RequestSpecificException
2526from tensorrt_llm .bindings .executor import (DisServingRequestStats ,
2627 FinishReason , InflightBatchingStats ,
2728 IterationStats , KvCacheStats ,
@@ -686,8 +687,7 @@ def _executor_loop_pp(self):
686687 logger .warning (
687688 "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
688689 )
689- self .kv_cache_transceiver .check_context_transfer_status (
690- 1 )
690+ self ._check_disagg_ctx_cache_transfer_status (1 )
691691
692692 self .num_scheduled_requests = scheduled_batch .batch_size
693693
@@ -887,7 +887,11 @@ def _prepare_and_schedule_batch(self):
887887 logger .warning (
888888 "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
889889 )
890- self .kv_cache_transceiver .check_context_transfer_status (1 )
890+ self ._check_disagg_ctx_cache_transfer_status (1 )
891+ else :
892+ assert scheduled_batch .batch_size > 0 , (
893+ "fail to schedule any pending request, "
894+ "probably run out of resource." )
891895
892896 self .num_scheduled_requests = scheduled_batch .batch_size
893897 logger .debug (
@@ -1258,7 +1262,7 @@ def _check_disagg_gen_transfer_status(self):
12581262
12591263 if need_check :
12601264 at_least_num = 1 if need_check_one else 0
1261- self .kv_cache_transceiver . check_gen_transfer_status (at_least_num )
1265+ self ._check_disagg_gen_cache_transfer_status (at_least_num )
12621266
12631267 return
12641268
@@ -1361,8 +1365,7 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
13611365 req .is_disagg_generation_transmission_in_progress
13621366 for req in self .active_requests
13631367 ])
1364- self .kv_cache_transceiver .check_gen_transfer_status (
1365- 1 if block_transfer else 0 )
1368+ self ._check_disagg_gen_cache_transfer_status (1 if block_transfer else 0 )
13661369
13671370 return
13681371
@@ -1382,7 +1385,7 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
13821385 self .resource_manager .resource_managers [
13831386 resource_mgr_type ].free_resources (req )
13841387
1385- self .kv_cache_transceiver . check_context_transfer_status (0 )
1388+ self ._check_disagg_ctx_cache_transfer_status (0 )
13861389
13871390 # Keep track of ctx requests that are in transmission
13881391 ctx_transmission_reqs = [
@@ -1392,6 +1395,38 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
13921395
13931396 return ctx_transmission_reqs
13941397
1398+ def _check_cache_transfer_status_helper (self ,
1399+ method_name : str ,
1400+ method_call ,
1401+ atLeastNum : int = 0 ):
1402+ """Helper method to handle cache transfer status checking with error handling."""
1403+ try :
1404+ method_call (atLeastNum )
1405+ except RequestSpecificException as e :
1406+ error_msg = str (e )
1407+ logger .error (
1408+ f"Encountered a request-specific error in { method_name } : { error_msg } "
1409+ )
1410+ request_ids = [e .request_id ]
1411+ self ._handle_errors (error_msg , request_ids )
1412+ except Exception as e :
1413+ error_msg = str (e )
1414+ logger .error (
1415+ f"Encountered a system error in { method_name } : { error_msg } " )
1416+ self ._handle_errors (error_msg )
1417+
@nvtx_range("_check_disagg_ctx_cache_transfer_status")
def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):
    """Run the transceiver's context-transfer status check via the shared helper.

    ``atLeastNum`` is forwarded unchanged to
    ``kv_cache_transceiver.check_context_transfer_status``; error handling is
    delegated to ``_check_cache_transfer_status_helper``.
    """
    self._check_cache_transfer_status_helper(
        method_name="checking context transfer status",
        method_call=self.kv_cache_transceiver.check_context_transfer_status,
        atLeastNum=atLeastNum,
    )
1423+
@nvtx_range("_check_disagg_gen_cache_transfer_status")
def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0):
    """Run the transceiver's generation-transfer status check via the shared helper.

    ``atLeastNum`` is forwarded unchanged to
    ``kv_cache_transceiver.check_gen_transfer_status``; error handling is
    delegated to ``_check_cache_transfer_status_helper``.
    """
    self._check_cache_transfer_status_helper(
        method_name="checking generation transfer status",
        method_call=self.kv_cache_transceiver.check_gen_transfer_status,
        atLeastNum=atLeastNum,
    )
1429+
13951430 def _forward_step (self ,
13961431 scheduled_requests ,
13971432 new_tensors_device : Optional [SampleStateTensors ] = None ):
@@ -1501,27 +1536,26 @@ def _update_requests(self, sample_state: SampleState):
15011536
def _handle_errors(self,
                   error_msg: Optional[str] = None,
                   request_ids: Optional[List[int]] = None):
    """Fail active requests with an error response and drop them from the queue.

    Args:
        error_msg: Message placed in every error response; falls back to
            ``"error"`` when empty or ``None``.
        request_ids: Ids (``py_request_id``) of the requests to fail. When
            ``None``, every request in ``self.active_requests`` is failed.
            Ids that do not match any active request are ignored.

    Each failed request is marked ``GENERATION_COMPLETE``, passed to
    ``_terminate_request``, and gets an ``LlmResponse`` carrying the error
    message; the collected responses are then handed to
    ``_enqueue_responses``.
    """
    error_responses = {}
    error_msg = error_msg or "error"
    # A set gives O(1) membership tests while scanning the active requests.
    failed_ids = None if request_ids is None else set(request_ids)
    surviving_requests = []
    for request in self.active_requests:
        req_id = request.py_request_id
        if failed_ids is not None and req_id not in failed_ids:
            surviving_requests.append(request)
            continue
        request.state = LlmRequestState.GENERATION_COMPLETE
        self._terminate_request(request)
        error_responses[req_id] = LlmResponse(
            request_id=req_id,
            error_msg=error_msg,
            client_id=request.py_client_id)

    if failed_ids is None:
        self.active_requests.clear()
    else:
        # active_requests holds request objects, not ids, so filter by id
        # instead of calling list.remove(req_id) (which raises ValueError).
        # Slice assignment keeps the same list object, like the in-place
        # clear() above.
        self.active_requests[:] = surviving_requests
    # NOTE(review): the responses are passed as a dict keyed by request id;
    # confirm this matches what _enqueue_responses expects.
    self._enqueue_responses(error_responses)
15251559
15261560 def _terminate_request (self , request : LlmRequest ):
15271561 self .resource_manager .free_resources (request )
0 commit comments