
LORA_WARMUP_RANK = 8

+DUMMY_TOKEN_ID = -1
+

class Singleton(type):
    _instances: Dict[type, object] = {}
@@ -668,6 +670,9 @@ def __init__(

        # For multi-step scheduling
        self.cached_step_outputs: List[torch.Tensor] = []
+        # For delayed sampling
+        self.cached_step_inputs: List[
+            ModelInputForHPUWithSamplingMetadata] = []

    def _set_gc_threshold(self) -> None:
        # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -771,6 +776,12 @@ def load_model(self) -> None:
        msg = f"Loading model weights took in total {m.get_summary_string()}"
        logger.info(msg)

+    def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
+        return htorch.hpu.wrap_in_hpu_graph(
+            HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
+        ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
+            *args, **kwargs)
+
    def get_model(self) -> nn.Module:
        return self.model

@@ -2020,6 +2031,21 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],

        return lora_mask, lora_logits_mask

+    def _get_seq_ids(self, model_input):
+        return ([
+            sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
+        ])
+
+    def _pad_to_max_num_seqs(self, tensor, value):
+        padding_needed = self.max_num_seqs - tensor.size(0)
+        if padding_needed:
+            padding = torch.full((padding_needed, *tensor.shape[1:]),
+                                 value,
+                                 device=tensor.device,
+                                 dtype=tensor.dtype)
+            tensor = torch.cat([tensor, padding])
+        return tensor
+
    @torch.inference_mode()
    def execute_model(
        self,
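A minimal, standalone sketch of the padding scheme behind _pad_to_max_num_seqs above: decode batches are padded up to a fixed number of rows so the HPU graph shapes stay stable, and the filler rows carry a dummy value that is ignored later. Names outside the diff (pad_rows) are illustrative only.

import torch

def pad_rows(t: torch.Tensor, target_rows: int, fill_value: int) -> torch.Tensor:
    # Append rows filled with `fill_value` until the tensor has `target_rows` rows.
    missing = target_rows - t.size(0)
    if missing <= 0:
        return t
    pad = torch.full((missing, *t.shape[1:]), fill_value,
                     dtype=t.dtype, device=t.device)
    return torch.cat([t, pad])

# Three live sequences padded to a bucket of 8; the -1 rows are placeholders.
tokens = torch.tensor([[101], [202], [303]])
print(pad_rows(tokens, 8, -1).shape)  # torch.Size([8, 1])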
@@ -2030,6 +2056,37 @@ def execute_model(
        warmup_mode=False,
        seqs=None,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
+        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
+        assert not (use_delayed_sampling and num_steps != 1), \
+            'Delayed sampling is not compatible with MSS!'
+        assert model_input.input_tokens is not None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                self.is_driver_worker:
+            num_cached = len(self.cached_step_outputs)
+            assert num_cached > 0
+            cur_seq_ids = self._get_seq_ids(model_input)
+            cur_seq_id_pos = {
+                sid: idx
+                for idx, sid in enumerate(cur_seq_ids) if sid >= 0
+            }
+            htorch.core.mark_step()
+            for i in range(num_cached):
+                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
+                target_indices = [
+                    cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
+                ]
+                padding = self.cached_step_outputs[i].size(0) - len(
+                    target_indices)
+                target_indices.extend([-1] * padding)
+                target_indices = torch.tensor(
+                    target_indices,
+                    device=model_input.input_tokens.device,
+                    dtype=model_input.input_tokens.dtype)
+                model_input.input_tokens.index_copy_(
+                    0, target_indices, self.cached_step_outputs[i])
+                htorch.core.mark_step()
+
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                # not first or last multi-step
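The loop above writes each cached sampled token into the row of input_tokens that its sequence now occupies, using index_copy_. A toy version of that scatter, with made-up sequence IDs and no vLLM types, assuming every previous sequence is still present in the current batch:

import torch

# Tokens sampled for sequences [7, 9, 4] on the previous decode step.
prev_seq_ids = [7, 9, 4]
prev_tokens = torch.tensor([[11], [22], [33]])

# Current batch order: seq 4 at row 0, seq 7 at row 1, seq 9 at row 2.
cur_seq_id_pos = {4: 0, 7: 1, 9: 2}
target_indices = torch.tensor([cur_seq_id_pos[s] for s in prev_seq_ids])

input_tokens = torch.zeros((3, 1), dtype=torch.int64)
input_tokens.index_copy_(0, target_indices, prev_tokens)
# input_tokens is now [[33], [11], [22]]: each row holds its own last token.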
@@ -2045,7 +2102,21 @@ def execute_model(
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)
-            input_tokens = model_input.input_tokens
+            # Rank!=0 workers have is_prompt==None
+            if use_delayed_sampling and not model_input.is_prompt and \
+                    model_input.input_tokens.size(1) == 1:
+                if self.is_driver_worker:
+                    model_kwargs_broadcast_data = {
+                        "input_tokens": model_input.input_tokens
+                    }
+                    broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                    input_tokens = model_input.input_tokens
+
+                else:
+                    model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
+                    input_tokens = model_kwargs_broadcast_data["input_tokens"]
+            else:
+                input_tokens = model_input.input_tokens
            input_positions = model_input.input_positions
            attn_metadata = model_input.attn_metadata
            sampling_metadata = model_input.sampling_metadata
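Because only the driver worker patches input_tokens with real sampled IDs, the hunk above broadcasts the decode tokens from rank 0 before the forward pass. The same driver/worker pattern expressed with plain torch.distributed (a sketch, not the vLLM broadcast_tensor_dict helper; it assumes the process group is already initialized):

import torch
import torch.distributed as dist

def sync_decode_tokens(input_tokens: torch.Tensor) -> torch.Tensor:
    # Rank 0 sends its (already patched) tokens; the other ranks
    # overwrite their local copy in place with the received values.
    dist.broadcast(input_tokens, src=0)
    return input_tokens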
@@ -2092,7 +2163,7 @@ def execute_model(
                                    f"graphs{'T' if use_graphs else 'F'}")
            else:
                model_event_name = 'model_executable'
-            if num_steps > 1:
+            if num_steps > 1 or use_delayed_sampling:
                # in case of multi-step scheduling
                # we only want to pythonize in the last step
                sampling_metadata.skip_sampler_cpu_output = True
@@ -2152,9 +2223,9 @@ def try_revert_dummy_output_tokens():
                if not self.is_driver_worker:
                    continue

-                if model_input.async_callback is not None:
-                    model_input.async_callback()
-                # Sample the next token.
+                if use_delayed_sampling:
+                    fake_output = self._delayed_sampler_outputs(model_input)
+
                with self.profiler.record_event(
                        'internal', ('sample_'
                                     f'{"prompt" if is_prompt else "decode"}_'
@@ -2166,9 +2237,16 @@ def try_revert_dummy_output_tokens():
                    )
                    if num_steps > 1:
                        output = output.sampled_token_ids
-                        self.cached_step_outputs.append(
-                            output.detach().clone())
+                        self.cached_step_outputs.append(output)
+                    if use_delayed_sampling and self.is_driver_worker:
+                        self._patch_prev_output()
+                        output = self._pad_to_max_num_seqs(
+                            output.sampled_token_ids, DUMMY_TOKEN_ID)
+                        self.cached_step_outputs.append(output)
+                        self.cached_step_inputs.append(model_input)
                htorch.core.mark_step()
+                if model_input.async_callback is not None:
+                    model_input.async_callback()
                if i < num_steps - 1:
                    if i == 0:
                        if model_input.async_callback is not None:
@@ -2241,11 +2319,30 @@ def try_revert_dummy_output_tokens():
                    is_prompt=is_prompt)
                self.profiler.record_counter(self.event_start, counters)
            if num_steps == 1:
+                if self.return_hidden_states:
+                    # we only need to pass hidden states of most recent token
+                    assert model_input.sampling_metadata is not None
+                    if model_input.is_prompt:
+                        output.prefill_hidden_states = hidden_states
+                    output.hidden_states = hidden_states
+                if use_delayed_sampling:
+                    if self.is_driver_worker:
+                        return [fake_output]
+                    else:
+                        return []
+
                return [output] if self.is_driver_worker else []
            else:
                return []
        return output if type(output) is list else [output]

+    def _delayed_sampler_outputs(self, model_input):
+        next_token_ids = [[DUMMY_TOKEN_ID]] * len(
+            model_input.sampling_metadata.seq_groups)
+        sampler_output = self._make_decode_output(
+            next_token_ids, model_input.sampling_metadata.seq_groups)
+        return sampler_output
+
    def _decode_sampler_outputs(self, model_input):
        use_async_out_proc = model_input.async_callback is not None
        sampler_outputs = []
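_delayed_sampler_outputs above hands the engine placeholder tokens so output processing can proceed before the sampled IDs reach the host; _patch_prev_output (next hunk) swaps in the real values one step later. The core idea in isolation, using only plain Python (helper names here are illustrative, not from the PR):

DUMMY_TOKEN_ID = -1

def make_placeholder_output(num_seq_groups: int) -> list:
    # One independent [dummy] slot per sequence group.
    return [[DUMMY_TOKEN_ID] for _ in range(num_seq_groups)]

def patch_placeholders(placeholders: list, real_token_ids: list) -> None:
    # Overwrite each dummy with the token that was actually sampled.
    for slot, token_id in zip(placeholders, real_token_ids):
        slot[0] = token_id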
@@ -2312,3 +2409,32 @@ def shutdown_inc(self):

    def __del__(self):
        self.shutdown_inc()
+
+    def _patch_prev_output(self):
+        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
+            f'''Inputs and outputs are out of sync!
+            {len(self.cached_step_inputs)} {len(self.cached_step_outputs)}'''
+        if len(self.cached_step_inputs) == 0:
+            return
+        model_input = self.cached_step_inputs.pop(0)
+        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
+            -1).tolist()
+        ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
+        # If there's no output to patch with, return early; this usually
+        # happens when a new request starts after all requests have completed.
+        if len(ctx.output_queue) == 0:
+            return
+        assert len(
+            ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
+        output_data = ctx.output_queue[0]
+        assert len(output_data.outputs) == 1
+        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
+            fake_out.samples[0].output_token = real_out
+        for sg, real_out in zip(output_data.seq_group_metadata_list,
+                                delayed_output):
+            assert len(sg.seq_data) == 1
+            seq_data = list(sg.seq_data.values())[0]
+            # This is a hack: assigning output_token_ids triggers
+            # a cache recomputation, and we only need to update the last token.
+            seq_data.output_token_ids_array[-1] = real_out
+            seq_data._cached_all_token_ids[-1] = real_out