@@ -791,8 +791,18 @@ def execute_model(
             arange)
         selected_token_ids = self.sample_from_logits(logits,
                                                      tpu_sampling_metadata)
+
+        # NOTE (NickLucche) Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs. We can't enforce it due
+        # to recompilations outside torch.compiled code, so just make sure
+        # `sample_from_logits` does not modify the logits in-place.
+        logprobs = self.gather_logprobs(logits, selected_token_ids) \
+            if tpu_sampling_metadata.logprobs else None
+
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+        logprobs_lists = logprobs.tolists() \
+            if tpu_sampling_metadata.logprobs else None
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
@@ -862,7 +872,7 @@ def execute_model(
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=None,
-            logprobs=None,
+            logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
 
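For context on the `logprobs_lists` wired in here: the device-side tensors keep the padded batch shape, and only the CPU-side conversion trims them to the real request count. A toy illustration under assumed shapes (hypothetical sizes, not vLLM's types):

```python
import torch

padded_num_reqs, num_reqs, k = 8, 5, 6   # hypothetical: top-5 + sampled token
padded_logprobs = torch.zeros((padded_num_reqs, k))

# The dynamic slice stays on CPU so the XLA graph keeps a fixed shape.
per_request = padded_logprobs.cpu()[:num_reqs].tolist()
assert len(per_request) == num_reqs
```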
@@ -1121,6 +1131,22 @@ def _precompile_sample_from_logits(self) -> None:
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info(" -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
@@ -1131,6 +1157,7 @@ def capture_model(self) -> None:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
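The new `_precompile_gather_logprobs` follows the same warm-up pattern as its siblings: with `dynamic=False`, each padded batch size traces its own XLA graph, so every shape is visited once at startup. A self-contained sketch of that pattern (generic function and example paddings, not vLLM's actual values):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def greedy_tokens(logits: torch.Tensor) -> torch.Tensor:
    return torch.argmax(logits, dim=-1, keepdim=True)

# One call per padded batch size compiles and caches one graph per shape,
# so the serving hot path never pays compilation latency.
for num_reqs in (8, 16, 32):   # example paddings only
    greedy_tokens(torch.zeros((num_reqs, 128), device=device))
xm.wait_device_ops()           # block until device-side compilation finishes
```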
@@ -1254,13 +1281,31 @@ def compute_logits(self,
     def sample_from_logits(
             self, logits: torch.Tensor,
             sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
+        """
+        Sample with an XLA-friendly function. This function is traced
+        separately from `forward` for lighter compilation overhead.
+        """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
             out_tokens = self.sampler(logits,
                                       sampling_metadata).sampled_token_ids
         return out_tokens
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top logprobs with their corresponding tokens. Use a fixed
+        number of logprobs as an alternative to multiple pre-compiled graphs.
+        Select the number of logprobs actually demanded by each request on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
                           grammar_bitmask: torch.Tensor, logits: torch.Tensor,
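The fixed-`max_logprobs` design in `gather_logprobs` trades a slightly wider device-side gather for a single compiled graph per batch size. A plain-PyTorch sketch of the idea (not the `Sampler` API; `MAX_LOGPROBS` and the per-request counts are made up):

```python
import torch

MAX_LOGPROBS = 5  # hypothetical stand-in for model_config.max_logprobs

def gather_topk(logits: torch.Tensor):
    # Always gather the same k on device: one static output shape.
    logprobs = torch.log_softmax(logits, dim=-1)
    return logprobs.topk(MAX_LOGPROBS, dim=-1)

values, token_ids = gather_topk(torch.randn(4, 32))

# Each request's actual k is applied on CPU, outside the compiled graph.
requested_k = [1, 3, 5, 2]   # made-up per-request logprob counts
trimmed = [values[i, :k].tolist() for i, k in enumerate(requested_k)]
```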