 import weakref
 from collections import deque, namedtuple
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union

 import torch

@@ -308,7 +308,7 @@ def __init__(self,
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)

-        if self.draft_model_engine is not None:
+        if self.drafter is not None:
             if self.event_loop.__name__ != self._executor_loop.__name__:
                 raise NotImplementedError(
                     "Drafting is not supported for selected executor loop. "
@@ -905,10 +905,6 @@ def _executor_loop_pp(self):

     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
-        is_ngram = hasattr(
-            self.model_engine, "spec_config"
-        ) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram(
-        )
         with self._profiler() as profile_step:
             sample_state = None
             iter_start_time = time.time()
@@ -931,7 +927,7 @@ def _executor_loop(self):

                 self._pad_attention_dp_dummy_request()

-                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
+                if self.drafter is not None:
                     self._prepare_draft_requests(self.active_requests)

                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -971,11 +967,9 @@ def _executor_loop(self):
                             scheduled_batch)

                     self.resource_manager.prepare_resources(scheduled_batch)
-                    if self.draft_model_engine is not None:
-                        self._prepare_draft_tokens(scheduled_batch)
-
                     if self.drafter is not None:
-                        self.drafter.prepare_draft_tokens(scheduled_batch)
+                        self.drafter.prepare_draft_tokens(
+                            scheduled_batch, self.resource_manager)

                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
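The call above is now the only drafting hook left in the executor loop. A minimal sketch of the interface that call site assumes; the class below is illustrative only, not the actual TensorRT-LLM Drafter definition.

    from abc import ABC, abstractmethod

    class Drafter(ABC):
        """Illustrative drafter interface assumed by self.drafter above."""

        @abstractmethod
        def prepare_draft_tokens(self, scheduled_requests, resource_manager) -> None:
            """Populate py_draft_tokens on the scheduled generation requests,
            using resource_manager for any KV-cache / sequence-slot bookkeeping."""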
@@ -1798,188 +1792,6 @@ def _update_requests(self, sample_state: SampleState):
             logger.error(f"Encountered an error in sampling: {error_msg}")
             self._handle_errors(error_msg)

-    @nvtx_range("_prepare_draft_batch")
-    def _prepare_draft_batch(
-        self, scheduled_requests: ScheduledRequests
-    ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
-        """
-        Prepares a batch for the draft model engine. Draft tokens are only produced
-        for generation requests.
-
-        The requests are prepared as follows:
-        1. The first time the draft engine sees a request, it's a context request.
-        2. Otherwise, if draft tokens were accepted on the last target model decoding
-        step, it's a chunked context request (we process all the accepted tokens together).
-        3. Otherwise, it's a generation request.
-        """
-        try:
-            draft_batch = ScheduledRequests()
-
-            for request in scheduled_requests.generation_requests:
-                if request.py_draft_pages_allocated == 0:
-                    # No space for draft tokens.
-                    continue
-
-                # Stop drafting when we hit the max seqlen. We still need dummy draft
-                # tokens attached to the requests to make sure everything works properly
-                # with CUDA graph. These dummy tokens are already added by
-                # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
-                # that we want to do spec decoding, so no need to do anything else here.
-                # This makes the perf for this case suboptimal, but that's OK - this is
-                # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
-                if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
-                    continue
-
-                num_draft_tokens = len(
-                    request.py_last_draft_tokens
-                ) if request.py_last_draft_tokens is not None else 0
-                request.py_draft_tokens = []
-
-                num_accepted_tokens = request.py_num_accepted_draft_tokens
-                num_rejected_tokens = num_draft_tokens - num_accepted_tokens
-                assert num_rejected_tokens >= 0
-
-                spec_config = self.model_engine.spec_config
-                beam_idx = 0
-                input_tokens = spec_config.get_draft_model_prompt(
-                    request.get_tokens()[beam_idx])
-
-                def create_new_request(input_tokens):
-                    return LlmRequest(
-                        request_id=request.py_request_id,
-                        max_new_tokens=request.py_max_new_tokens,
-                        input_tokens=input_tokens,
-                        sampling_config=request.sampling_config,
-                        return_perf_metrics=request.return_perf_metrics,
-                        is_streaming=False,
-                        is_draft=True)
-
-                if request.max_beam_num_tokens - 1 == request.py_prompt_len:
-                    # This is the first time the draft model is seeing this request.
-                    # Prepare a context request. We discard the first token and take
-                    # the newly decoded one - this is the convention for EAGLE 2 and 3.
-                    new_request = create_new_request(input_tokens)
-                    draft_batch.context_requests.append(new_request)
-                elif num_accepted_tokens == 0:
-                    new_request = create_new_request(input_tokens[:-1])
-                    # Explicitly add the last token so get_last_tokens() returns
-                    # the right value
-                    new_request.add_new_token(input_tokens[-1], beam_idx)
-                    new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
-                    draft_batch.generation_requests.append(new_request)
-                else:
-                    new_request = create_new_request(input_tokens)
-                    new_request.context_chunk_size = num_accepted_tokens + 1
-                    new_request.context_current_position = len(
-                        input_tokens) - num_accepted_tokens - 1
-                    new_request.context_chunk_size = num_accepted_tokens + 1
-                    new_request.context_current_position = len(
-                        input_tokens) - num_accepted_tokens - 1
-
-                    draft_batch.context_requests.append(new_request)
-
-                new_request.py_stop_words_list = request.py_stop_words_list
-
-            return draft_batch
-
-        except Exception as e:
-            traceback.print_exc()
-            error_msg = str(e)
-            logger.error(f"Encountered an error in decode: {error_msg}")
-            self._handle_errors(error_msg)
-
-    @nvtx_range("_prepare_draft_tokens")
-    def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
-        if not self.draft_model_engine:
-            raise ValueError("Draft model engine is not set")
-
-        try:
-            draft_batch = self._prepare_draft_batch(scheduled_requests)
-
-            if draft_batch.batch_size == 0:
-                return
-            self.draft_seq_slot_manager.prepare_resources(draft_batch)
-
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_requests.all_requests()
-            }
-
-            # Disable cuda graph for the 1st draft model forward
-            if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
-            ):
-                with self.draft_model_engine.no_cuda_graph():
-                    outputs = self.draft_model_engine.forward(
-                        draft_batch, self.resource_manager)
-            else:
-                outputs = self.draft_model_engine.forward(
-                    draft_batch, self.resource_manager)
-            if hasattr(self.draft_model_engine.model.model, 'd2t'):
-                outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
-
-            sample_state = self._sample_async(draft_batch, outputs)
-            previous_batch = sample_state
-
-            self._update_request_states(draft_batch)
-
-            def _process_decoded_tokens(draft_batch):
-                new_requests = []
-                for req in draft_batch.all_requests():
-                    target_model_req = req_id_to_old_request[req.py_request_id]
-                    target_model_req.py_draft_tokens.append(
-                        req.get_last_tokens(0))
-                    if req.state != LlmRequestState.GENERATION_COMPLETE and len(
-                            target_model_req.py_draft_tokens
-                    ) < target_model_req.py_draft_pages_allocated:
-                        new_requests.append(req)
-                    else:
-                        self.draft_seq_slot_manager.free_resources(req)
-
-                return new_requests
-
-            # The TRTLLM attention kernels cannot handle generation requests with
-            # different seqlens. No issues with flashinfer, should we look into removing
-            # this? Just needs proper kernel support.
-            def _pad_to_max_draft_tokens():
-                for req in scheduled_requests.generation_requests:
-                    max_draft_len = self.max_draft_len
-                    num_draft_tokens = len(req.py_draft_tokens)
-                    req.py_draft_tokens.extend(
-                        0 for _ in range(max_draft_len - num_draft_tokens))
-
-            draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests
-            draft_batch.context_requests = []
-
-            for i in range(self.max_draft_len - 1):
-                if len(draft_batch.generation_requests) == 0:
-                    break
-
-                outputs = self.draft_model_engine.forward(
-                    draft_batch,
-                    self.resource_manager,
-                    new_tensors_device=previous_batch.device)
-
-                if hasattr(self.draft_model_engine.model.model, 'd2t'):
-                    outputs[
-                        'd2t'] = self.draft_model_engine.model.model.d2t.data
-                sample_state = self._sample_async(draft_batch, outputs)
-                self._update_request_states(draft_batch)
-                self._update_requests(previous_batch)
-                new_requests = _process_decoded_tokens(
-                    previous_batch.scheduled_requests)
-                draft_batch.generation_requests = new_requests
-                previous_batch = sample_state
-            self._update_requests(previous_batch)
-            new_requests = _process_decoded_tokens(
-                previous_batch.scheduled_requests)
-            _pad_to_max_draft_tokens()
-
-        except Exception as e:
-            traceback.print_exc()
-            error_msg = str(e)
-            logger.error(f"Encountered an error in decode: {error_msg}")
-            self._handle_errors(error_msg)
-
     def _handle_errors(self, error_msg: Optional[str] = None):
         error_responses = {}
         error_msg = error_msg or "error"