 from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import repeat
-from typing import Any, List, Literal, Optional, cast
+from typing import Any, List, Literal, Optional, TypeVar, cast

 import torch
 import torch.nn.functional as F
                                 GptDecoderBatched)
 from tensorrt_llm.executor.result import Logprob
 from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.sampling_params import SamplingParams

 from ..speculative.spec_tree_manager import SpecTreeManager
 from .finish_reason import FinishedState
@@ -195,106 +196,104 @@ def is_generation_model(self) -> bool:

 def top_k_sampling_batch(
     logits,
-    top_k=50,
-    generator: Optional[torch.Generator] = None
+    *,
+    top_k: int,
+    temperature: float,
+    generator: Optional[torch.Generator] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    logits_dim = logits.dim()
-    if logits_dim == 1:
-        logits = logits.unsqueeze(0)
-    # logits should be 2D: [batch_size, vocab_size]
-    batch_size, vocab_size = logits.size()
+    # NB: To be replaced by a more efficient implementation.
+    return top_k_top_p_sampling_batch(
+        logits,
+        top_k=top_k,
+        temperature=temperature,
+        generator=generator,
+        top_p=1,
+    )

-    # get first top_k logits of each sample and their indices
-    if top_k > 0:
-        values, indices = torch.topk(logits, top_k, dim=-1)
-        min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)

-        # set the logits who is less than first top_k logits to -inf
-        logits = torch.where(logits < min_values,
-                             torch.full_like(logits, float('-inf')), logits)
+def top_p_sampling_batch(
+    logits: torch.Tensor,
+    *,
+    top_p: float,
+    temperature: float,
+    generator: Optional[torch.Generator] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # NB: To be replaced by a more efficient implementation.
+    return top_k_top_p_sampling_batch(
+        logits,
+        top_p=top_p,
+        top_k=logits.size(1),
+        temperature=temperature,
+        generator=generator,
+    )

-    # compute probability distribution
-    softmax = torch.softmax(logits, dim=-1)

-    # sample from the distribution and generate result of [batch_size, 1]
-    next_tokens = torch.multinomial(softmax, num_samples=1,
-                                    generator=generator).squeeze(-1)
-    return next_tokens, softmax
+def temperature_sampling_batch(
+    logits: torch.Tensor,
+    *,
+    temperature: float,
+    generator: Optional[torch.Generator] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # NB: To be replaced by a more efficient implementation.
+    return top_k_top_p_sampling_batch(
+        logits,
+        top_p=1,
+        top_k=logits.size(1),
+        temperature=temperature,
+        generator=generator,
+    )


-def top_p_sampling_batch(
+def top_k_top_p_sampling_batch(
     logits: torch.Tensor,
     *,
-    top_p: float = 0.9,
-    temperature: float = 1.0,
+    top_k: int,
+    top_p: float,
+    temperature: float,
     generator: Optional[torch.Generator] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
     logits_dim = logits.dim()
     assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
+    assert temperature > 0, "non-greedy sampling requires valid temperature"
+    logits = logits / max(temperature, 1e-5)
+    batch_size, vocab_size = logits.size()

-    if temperature != 0:
-        logits = logits / max(temperature, 1e-5)
-
-    # sort the logits of each sample in descending order
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-
-    # compute cumulative probability distribution of each sample
-    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
-                                    dim=-1)
-    # get the location of top_p
-    sorted_indices_to_remove = cumulative_probs > top_p
-    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-    sorted_indices_to_remove[:, 0] = 0
-
-    # set the logits to -inf whose is outside top_p
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        1, sorted_indices, sorted_indices_to_remove)
-    logits = logits.masked_fill(indices_to_remove, float('-inf'))
-
-    # compute probability distribution
-    softmax = torch.softmax(logits, dim=-1)
-
-    # sample from the distribution and generate result of [batch_size, 1]
-    next_tokens = torch.multinomial(softmax, num_samples=1,
-                                    generator=generator).squeeze(-1)
-    return next_tokens, softmax
-
+    assert top_k > 1, "non-greedy sampling requires valid top_k"
+    need_top_k = top_k < vocab_size
+    assert top_p > 0, "non-greedy sampling requires valid top_p"
+    need_top_p = top_p < 1

-def top_k_top_p_sampling_batch(logits: torch.Tensor,
-                               *,
-                               top_k: int,
-                               top_p: float,
-                               temperature: float = 1.0,
-                               generator: Optional[torch.Generator] = None):
-    logits_dim = logits.dim()
-    assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
-    if temperature != 0:
-        logits = logits / max(temperature, 1e-5)
-    batch_size, vocab_size = logits.size()
-    # get first top_k logits of each sample and their indices
-    if top_k > 0:
+    # top-K: mask out logits not belonging to the top-K for each sample
+    if need_top_k:
         values, _ = torch.topk(logits, top_k, dim=-1)
         min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)

         # set logits smaller than the k-th largest logit of each sample to -inf
         logits = torch.where(logits < min_values,
                              torch.full_like(logits, float('-inf')), logits)

-    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-
-    # compute cumulative probability distribution of each sample
-    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
-                                    dim=-1)
-
-    # get the location of top_p
-    sorted_indices_to_remove = cumulative_probs > top_p
-    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-    sorted_indices_to_remove[:, 0] = 0
-
-    # set the logits to -inf whose is outside top_p
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        1, sorted_indices, sorted_indices_to_remove)
-    logits = logits.masked_fill(indices_to_remove, float('-inf'))
+    # top-p: mask out logits outside the nucleus
+    if need_top_p:
+        sorted_logits, sorted_indices = torch.sort(logits,
+                                                   descending=True,
+                                                   dim=-1)
+
+        # compute cumulative probability distribution of each sample
+        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
+                                        dim=-1)
+
+        # get the location of top_p
+        # NB: Currently selecting the smallest index with cumulative_probs > top_p.
+        #     Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
+        sorted_indices_to_remove = cumulative_probs > top_p
+        sorted_indices_to_remove[:,
+                                 1:] = sorted_indices_to_remove[:, :-1].clone()
+        sorted_indices_to_remove[:, 0] = 0
+
+        # set the logits to -inf for token indices outside top_p
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove)
+        logits = logits.masked_fill(indices_to_remove, float('-inf'))

     # compute probability distribution
     softmax = torch.softmax(logits, dim=-1)
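For reference, a minimal self-contained sketch of the masking that top_k_top_p_sampling_batch performs above: a top-k cutoff, followed by top-p (nucleus) truncation of the sorted cumulative distribution, followed by a multinomial draw. The batch shape, seed, and parameter values below are illustrative only.

import torch

torch.manual_seed(0)
logits = torch.randn(2, 8)  # [batch_size, vocab_size]
top_k, top_p, temperature = 4, 0.9, 0.8

logits = logits / temperature
# top-k: keep only the k largest logits per row
kth_largest = torch.topk(logits, top_k, dim=-1).values[:, -1:]
logits = logits.masked_fill(logits < kth_largest, float('-inf'))
# top-p: drop the tail of the sorted cumulative distribution
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
remove = cumulative_probs > top_p
remove[:, 1:] = remove[:, :-1].clone()
remove[:, 0] = False
logits = logits.masked_fill(remove.scatter(1, sorted_indices, remove),
                            float('-inf'))
# sample one token per row from the renormalized distribution
probs = torch.softmax(logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)  # shape [2]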
@@ -359,48 +358,78 @@ def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor,
     return new_token


-TopK = tuple[Literal["top_k"], int]
+TemperatureOnly = tuple[Literal["temperature"], float]
+TopK = tuple[Literal["top_k"], int, float]
 TopP = tuple[Literal["top_p"], float, float]
 TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
 Greedy = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy | TopKTopP
-
-
-def _request_strategy(request: LlmRequest) -> Strategy:
-    # top_p and top_K with temperature=0.0 reduces to greedy
-    # sampling
-    temperature = request.sampling_config.temperature
-    if temperature is not None:
-        temperature = temperature[0]
-        if temperature == 0.0:
-            return GREEDY
-
-    if request.sampling_config.top_k is not None and len(
-            request.sampling_config.top_k
-    ) > 0 and request.sampling_config.top_p is not None and len(
-            request.sampling_config.top_p) > 0:
-        return ("top_k_top_p", request.sampling_config.top_k[0],
-                request.sampling_config.top_p[0], temperature)
-    elif request.sampling_config.top_p is not None and len(
-            request.sampling_config.top_p) > 0:
-        top_p = request.sampling_config.top_p[0]
-        return ("top_p", top_p, temperature)
-    elif request.sampling_config.top_k is not None and len(
-            request.sampling_config.top_k) > 0:
-        return ("top_k", request.sampling_config.top_k[0])
-    else:
+Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
+
+T = TypeVar('T')
+
+
+# Because tensorrt_llm::runtime::SamplingConfig uses vectors, params in
+# LlmRequest.sampling_params are either None or single-element lists.
+# This helper simplifies code that uses such params.
+def _unwrap_singleton(p: Optional[List[T]]) -> Optional[T]:
+    if p is None:
+        return None
+    t, = p
+    return t
+
+
+def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
+    # The semantics are specified in the doc-string of SamplingParams
+
+    sampling_config = request.sampling_config
+    temperature = _unwrap_singleton(
+        cast(Optional[List[float]], sampling_config.temperature))
+    top_p = _unwrap_singleton(cast(Optional[List[float]],
+                                   sampling_config.top_p))
+    top_k = _unwrap_singleton(cast(Optional[List[int]], sampling_config.top_k))
+
+    if SamplingParams.params_imply_greedy_decoding(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+    ):
         return GREEDY

+    # --- resolving default values
+    # NB: not greedy, hence temperature != 0 if specified
+    temperature = temperature or 1.0
+
+    # NB: not greedy, hence top_p != 0 if specified
+    top_p = top_p or 1.0
+    # NB: not greedy, hence top_k != 1 if specified
+    #     (0 and vocab_size are equivalent)
+    top_k = top_k or vocab_size
+
+    assert top_k > 1, "non-greedy sampling requires valid top_k"
+    need_top_k = top_k < vocab_size
+    assert top_p > 0, "non-greedy sampling requires valid top_p"
+    need_top_p = top_p < 1
+
+    if need_top_p:
+        if need_top_k:
+            return ("top_k_top_p", top_k, top_p, temperature)
+        return ("top_p", top_p, temperature)
+    if need_top_k:
+        return ("top_k", top_k, temperature)
+    return ("temperature", temperature)
+

 def _group_requests_by_sampling_strategy(
         requests: Iterable[LlmRequest],
         *,
-        pin_memory: bool = False) -> dict[Strategy, torch.Tensor]:
+        pin_memory: bool = False,
+        vocab_size: int) -> dict[Strategy, torch.Tensor]:
     # NB: Client code relies on request indices in returned torch.Tensor being sorted.
     strategy_dict: dict[Strategy, list[int]] = defaultdict(list)
     for req_index, req in enumerate(requests):
-        strategy_dict[_request_strategy(req)].append(req_index)
+        strategy_dict[_request_strategy(
+            req, vocab_size=vocab_size)].append(req_index)
     return {
         strategy: torch.tensor(indices,
                                pin_memory=pin_memory,
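To illustrate the resolution order implemented in _request_strategy, here is a hypothetical standalone helper (the name resolve, the vocab_size default, and the omission of the greedy short-circuit are assumptions for illustration) together with the strategy tuples it yields:

def resolve(temperature, top_p, top_k, vocab_size=32000):
    # mirror the default resolution above: unset values fall back to "disabled"
    temperature = temperature or 1.0
    top_p = top_p or 1.0
    top_k = top_k or vocab_size
    need_top_k = top_k < vocab_size
    need_top_p = top_p < 1
    if need_top_p and need_top_k:
        return ("top_k_top_p", top_k, top_p, temperature)
    if need_top_p:
        return ("top_p", top_p, temperature)
    if need_top_k:
        return ("top_k", top_k, temperature)
    return ("temperature", temperature)

assert resolve(0.7, 0.9, 40) == ("top_k_top_p", 40, 0.9, 0.7)
assert resolve(None, 0.95, None) == ("top_p", 0.95, 1.0)
assert resolve(1.2, None, 0) == ("temperature", 1.2)  # top_k=0 disables top-k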
@@ -418,23 +447,32 @@ def sample(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     filter_softmax = True
     match strategy:
-        case ("top_k", top_k):
-            tokens, softmax = top_k_sampling_batch(logits, top_k, generator)
+        case ("top_k", top_k, temperature):
+            tokens, softmax = top_k_sampling_batch(logits,
+                                                   top_k=top_k,
+                                                   temperature=temperature,
+                                                   generator=generator)
         case ("top_p", top_p, temperature):
             tokens, softmax = top_p_sampling_batch(
                 logits,
                 top_p=top_p,
                 generator=generator,
-                **(dict(temperature=temperature)
-                   if temperature is not None else dict()))
+                temperature=temperature,
+            )
         case ("top_k_top_p", top_k, top_p, temperature):
             tokens, softmax = top_k_top_p_sampling_batch(
                 logits,
                 top_k=top_k,
                 top_p=top_p,
+                temperature=temperature,
                 generator=generator,
-                **(dict(temperature=temperature)
-                   if temperature is not None else dict()))
+            )
+        case ("temperature", temperature):
+            tokens, softmax = temperature_sampling_batch(
+                logits,
+                temperature=temperature,
+                generator=generator,
+            )
         case ("greedy", None):
             tokens, softmax = greedy_search_sampling_batch(
                 logits, softmax_indices=softmax_indices)
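The generator argument threaded through every non-greedy case above is what makes the stochastic draws reproducible. A small sketch of the underlying torch.multinomial behaviour with an explicit torch.Generator (the seed and probabilities are illustrative):

import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
generator_a = torch.Generator().manual_seed(1234)
generator_b = torch.Generator().manual_seed(1234)
tokens_a = torch.multinomial(probs, num_samples=1, generator=generator_a)
tokens_b = torch.multinomial(probs, num_samples=1, generator=generator_b)
assert torch.equal(tokens_a, tokens_b)  # identical seeds give identical draws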
@@ -1323,7 +1361,7 @@ def _sample_batched_by_strategy(
                                           dim=-1)

         requests_by_strategy = _group_requests_by_sampling_strategy(
-            requests, pin_memory=True)
+            requests, pin_memory=True, vocab_size=logits_cuda.size(1))
         generator_cuda = self.get_generator(cuda_device)

         # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens