 
 LORA_WARMUP_RANK = 8
 
+DUMMY_TOKEN_ID = -1
+
 
 class Singleton(type):
     _instances: Dict[type, object] = {}
@@ -668,6 +670,9 @@ def __init__(
 
         # For multi-step scheduling
         self.cached_step_outputs: List[torch.Tensor] = []
+        # For delayed sampling
+        self.cached_step_inputs: List[
+            ModelInputForHPUWithSamplingMetadata] = []
 
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -771,6 +776,12 @@ def load_model(self) -> None:
         msg = f"Loading model weights took in total {m.get_summary_string()}"
         logger.info(msg)
 
+    def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
+        return htorch.hpu.wrap_in_hpu_graph(
+            HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
+        ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
+            *args, **kwargs)
+
     def get_model(self) -> nn.Module:
         return self.model
 
@@ -2020,6 +2031,21 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
 
         return lora_mask, lora_logits_mask
 
+    def _get_seq_ids(self, model_input):
+        return ([
+            sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
+        ])
+
+    def _pad_to_max_num_seqs(self, tensor, value):
+        padding_needed = self.max_num_seqs - tensor.size(0)
+        if padding_needed:
+            padding = torch.full((padding_needed, *tensor.shape[1:]),
+                                 value,
+                                 device=tensor.device,
+                                 dtype=tensor.dtype)
+            tensor = torch.cat([tensor, padding])
+        return tensor
+
     @torch.inference_mode()
     def execute_model(
         self,
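For context on the `_pad_to_max_num_seqs` helper added above: a minimal standalone sketch of the same padding pattern in plain PyTorch (the function name and values below are illustrative, not part of the PR). Padding the batch dimension with a dummy value keeps decode batch shapes fixed, which matters for HPU graph reuse.

import torch

def pad_to_max_num_seqs(tensor: torch.Tensor, value: int,
                        max_num_seqs: int) -> torch.Tensor:
    # Pad dim 0 with `value` rows so every step sees the same batch size.
    padding_needed = max_num_seqs - tensor.size(0)
    if padding_needed > 0:
        padding = torch.full((padding_needed, *tensor.shape[1:]),
                             value,
                             device=tensor.device,
                             dtype=tensor.dtype)
        tensor = torch.cat([tensor, padding])
    return tensor

# Example: pad 3 sampled token ids up to a max batch of 8, using -1 as dummy.
tokens = torch.tensor([[11], [42], [7]])
print(pad_to_max_num_seqs(tokens, -1, 8).shape)  # torch.Size([8, 1])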
@@ -2030,6 +2056,37 @@ def execute_model(
         warmup_mode=False,
         seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
+        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
+        assert not (use_delayed_sampling and num_steps != 1), \
+            'Delayed sampling is not compatible with MSS!'
+        assert model_input.input_tokens is not None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                self.is_driver_worker:
+            num_cached = len(self.cached_step_outputs)
+            assert num_cached > 0
+            cur_seq_ids = self._get_seq_ids(model_input)
+            cur_seq_id_pos = {
+                sid: idx
+                for idx, sid in enumerate(cur_seq_ids) if sid >= 0
+            }
+            htorch.core.mark_step()
+            for i in range(num_cached):
+                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
+                target_indices = [
+                    cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
+                ]
+                padding = self.cached_step_outputs[i].size(0) - len(
+                    target_indices)
+                target_indices.extend([-1] * padding)
+                target_indices = torch.tensor(
+                    target_indices,
+                    device=model_input.input_tokens.device,
+                    dtype=model_input.input_tokens.dtype)
+                model_input.input_tokens.index_copy_(
+                    0, target_indices, self.cached_step_outputs[i])
+                htorch.core.mark_step()
+
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
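A standalone sketch of the scatter step in the hunk above: tokens sampled in the previous step are written into the current batch's input_tokens at rows looked up by sequence id, using index_copy_. All tensors and the id-to-row map below are made up for illustration, and the real code additionally pads target_indices with -1 for sequences that are no longer in the batch, which this sketch omits.

import torch

# Current decode batch: one input-token slot per sequence, plus the row each
# live sequence id occupies in the current batch.
input_tokens = torch.zeros((4, 1), dtype=torch.long)
cur_seq_id_pos = {17: 0, 21: 1, 30: 2}  # seq_id -> row in current batch

# Tokens sampled for the previous step, in that step's batch order.
prev_seq_ids = [21, 17, 30]
cached_step_output = torch.tensor([[101], [202], [303]])

# Remap previous rows onto current rows and scatter the sampled tokens in.
target_indices = torch.tensor([cur_seq_id_pos[sid] for sid in prev_seq_ids])
input_tokens.index_copy_(0, target_indices, cached_step_output)
print(input_tokens.squeeze(-1).tolist())  # [202, 101, 303, 0]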
@@ -2045,7 +2102,21 @@ def execute_model(
             assert model_input.lora_mapping is not None
             self.set_active_loras(model_input.lora_requests,
                                   model_input.lora_mapping)
-        input_tokens = model_input.input_tokens
+        # Rank!=0 workers have is_prompt==None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                model_input.input_tokens.size(1) == 1:
+            if self.is_driver_worker:
+                model_kwargs_broadcast_data = {
+                    "input_tokens": model_input.input_tokens
+                }
+                broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                input_tokens = model_input.input_tokens
+
+            else:
+                model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
+                input_tokens = model_kwargs_broadcast_data["input_tokens"]
+        else:
+            input_tokens = model_input.input_tokens
         input_positions = model_input.input_positions
         attn_metadata = model_input.attn_metadata
         sampling_metadata = model_input.sampling_metadata
@@ -2092,7 +2163,7 @@ def execute_model(
20922163 f"graphs{ 'T' if use_graphs else 'F' } " )
20932164 else :
20942165 model_event_name = 'model_executable'
2095- if num_steps > 1 :
2166+ if num_steps > 1 or use_delayed_sampling :
20962167 # in case of multi-step scheduling
20972168 # we only want to pythonize in the last step
20982169 sampling_metadata .skip_sampler_cpu_output = True
@@ -2152,9 +2223,9 @@ def try_revert_dummy_output_tokens():
                 if not self.is_driver_worker:
                     continue
 
-                if model_input.async_callback is not None:
-                    model_input.async_callback()
-                # Sample the next token.
+                if use_delayed_sampling:
+                    fake_output = self._delayed_sampler_outputs(model_input)
+
                 with self.profiler.record_event(
                         'internal', ('sample_'
                                      f'{"prompt" if is_prompt else "decode"}_'
@@ -2166,9 +2237,16 @@ def try_revert_dummy_output_tokens():
                     )
                     if num_steps > 1:
                         output = output.sampled_token_ids
-                        self.cached_step_outputs.append(
-                            output.detach().clone())
+                        self.cached_step_outputs.append(output)
+                    if use_delayed_sampling and self.is_driver_worker:
+                        self._patch_prev_output()
+                        output = self._pad_to_max_num_seqs(
+                            output.sampled_token_ids, DUMMY_TOKEN_ID)
+                        self.cached_step_outputs.append(output)
+                        self.cached_step_inputs.append(model_input)
                 htorch.core.mark_step()
+                if model_input.async_callback is not None:
+                    model_input.async_callback()
                 if i < num_steps - 1:
                     if i == 0:
                         if model_input.async_callback is not None:
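To illustrate the bookkeeping the hunk above introduces: with delayed sampling, each decode step immediately returns placeholder tokens and caches the real sampled tokens, which are only surfaced one step later when the previous output is patched. A toy, framework-free sketch of that one-step-behind pipeline; the Step type, execute_step function, and token values are invented purely for illustration.

from dataclasses import dataclass
from typing import List

DUMMY_TOKEN_ID = -1

@dataclass
class Step:
    returned: List[int]   # what the caller sees for this step (may hold dummies)
    sampled: List[int]    # the real tokens, produced on-device

cached_steps: List[Step] = []

def execute_step(real_sampled: List[int]) -> List[int]:
    # Patch the previous step's placeholders with its now-ready real tokens.
    if cached_steps:
        prev = cached_steps.pop(0)
        prev.returned[:] = prev.sampled
    # Return dummies for this step and cache the real result for later.
    step = Step(returned=[DUMMY_TOKEN_ID] * len(real_sampled),
                sampled=real_sampled)
    cached_steps.append(step)
    return step.returned

out1 = execute_step([11, 12])   # -> [-1, -1]; real tokens not exposed yet
out2 = execute_step([21, 22])   # -> [-1, -1]; out1 has now been patched
print(out1, out2)               # [11, 12] [-1, -1]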
@@ -2241,11 +2319,30 @@ def try_revert_dummy_output_tokens():
                 is_prompt=is_prompt)
             self.profiler.record_counter(self.event_start, counters)
         if num_steps == 1:
+            if self.return_hidden_states:
+                # we only need to pass hidden states of most recent token
+                assert model_input.sampling_metadata is not None
+                if model_input.is_prompt:
+                    output.prefill_hidden_states = hidden_states
+                output.hidden_states = hidden_states
+            if use_delayed_sampling:
+                if self.is_driver_worker:
+                    return [fake_output]
+                else:
+                    return []
+
             return [output] if self.is_driver_worker else []
         else:
             return []
         return output if type(output) is list else [output]
 
+    def _delayed_sampler_outputs(self, model_input):
+        next_token_ids = [[DUMMY_TOKEN_ID]] * len(
+            model_input.sampling_metadata.seq_groups)
+        sampler_output = self._make_decode_output(
+            next_token_ids, model_input.sampling_metadata.seq_groups)
+        return sampler_output
+
     def _decode_sampler_outputs(self, model_input):
         use_async_out_proc = model_input.async_callback is not None
         sampler_outputs = []
@@ -2312,3 +2409,32 @@ def shutdown_inc(self):
 
     def __del__(self):
         self.shutdown_inc()
+
+    def _patch_prev_output(self):
+        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
+            f'''Inputs and outputs are out of sync!
+            {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
+        if len(self.cached_step_inputs) == 0:
+            return
+        model_input = self.cached_step_inputs.pop(0)
+        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
+            -1).tolist()
+        ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
+        # If there's no output to patch with, do nothing; this is usually the
+        # case when we're starting a new request after all requests completed.
+        if len(ctx.output_queue) == 0:
+            return
+        assert len(
+            ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
+        output_data = ctx.output_queue[0]
+        assert len(output_data.outputs) == 1
+        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
+            fake_out.samples[0].output_token = real_out
+        for sg, real_out in zip(output_data.seq_group_metadata_list,
+                                delayed_output):
+            assert len(sg.seq_data) == 1
+            seq_data = list(sg.seq_data.values())[0]
+            # This is a hack: assigning output_token_ids would trigger a cache
+            # recomputation, and we only need to update the last token.
+            seq_data.output_token_ids_array[-1] = real_out
+            seq_data._cached_all_token_ids[-1] = real_out
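The core of `_patch_prev_output` is an in-place rewrite of the last token of each queued sequence: the cached device tensor is moved to the CPU, flattened to a list, and each dummy placeholder is overwritten with the real token id. A minimal sketch of just that rewrite, with stand-in data structures (the real method also updates the queued SamplerOutput samples and vLLM's seq_data caches; names and values here are placeholders):

import torch

DUMMY_TOKEN_ID = -1

# Stand-ins for the cached state of one delayed step: the padded sampled-token
# tensor from the device, and per-sequence token histories that currently end
# with a dummy placeholder.
cached_output = torch.tensor([[123], [456], [DUMMY_TOKEN_ID]])  # last row is batch padding
output_token_ids = [[5, 9, DUMMY_TOKEN_ID], [3, DUMMY_TOKEN_ID]]

# Mirror of the patching step: move to CPU, drop the trailing dim, and
# overwrite only the last (dummy) token of each live sequence in place.
delayed = cached_output.cpu().squeeze(-1).tolist()
for seq_tokens, real_token in zip(output_token_ids, delayed):
    seq_tokens[-1] = real_token

print(output_token_ids)  # [[5, 9, 123], [3, 456]]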