
Commit 90ef0c7

chore: refine sampling strategy selection
Signed-off-by: ixlmar <[email protected]>
1 parent ab07b4c commit 90ef0c7

6 files changed, +458 -125 lines changed


pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,7 @@ extend_skip_glob = [
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
     "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
+    "tests/unittest/_torch/sampler/test_torch_sampler.py",
 ]

 [tool.yapf]
@@ -65,6 +66,7 @@ ignore_patterns = [
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
     "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
+    "tests/unittest/_torch/sampler/test_torch_sampler.py",
 ]

 [tool.codespell]
@@ -144,6 +146,7 @@ include = [
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
     "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
+    "tests/unittest/_torch/sampler/test_torch_sampler.py",
 ]
 exclude = [
     "**3rdparty/**",

tensorrt_llm/_torch/auto_deploy/shim/demollm.py

Lines changed: 1 addition & 1 deletion
@@ -235,7 +235,7 @@ def _sample(
         logits_shape = logits.shape
         logits = logits.view(-1, logits_shape[-1])  # sampling_batch expects 2D logits
         if isinstance(sampling_params.top_k, int):
-            idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k)
+            idx_next, probs = top_k_sampling_batch(logits, top_k=sampling_params.top_k)
         else:
            idx_next, probs = greedy_search_sampling_batch(logits)
         idx_next = idx_next.view(logits_shape[:-1])

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 153 additions & 115 deletions
@@ -6,7 +6,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import repeat
-from typing import Any, List, Literal, Optional, cast
+from typing import Any, List, Literal, Optional, TypeVar, cast

 import torch
 import torch.nn.functional as F
@@ -26,6 +26,7 @@
                                      GptDecoderBatched)
 from tensorrt_llm.executor.result import Logprob
 from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.sampling_params import SamplingParams

 from ..speculative.spec_tree_manager import SpecTreeManager
 from .finish_reason import FinishedState
@@ -195,106 +196,104 @@ def is_generation_model(self) -> bool:

 def top_k_sampling_batch(
     logits,
-    top_k=50,
-    generator: Optional[torch.Generator] = None
+    *,
+    top_k: int,
+    temperature: float,
+    generator: Optional[torch.Generator] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    logits_dim = logits.dim()
-    if logits_dim == 1:
-        logits = logits.unsqueeze(0)
-    # logits should be 2D :[batch_size, vocab_size]
-    batch_size, vocab_size = logits.size()
+    # NB: To be replaced by a more efficient implementation.
+    return top_k_top_p_sampling_batch(
+        logits,
+        top_k=top_k,
+        temperature=temperature,
+        generator=generator,
+        top_p=1,
+    )

-    # get first top_k logits of each sample and their indices
-    if top_k > 0:
-        values, indices = torch.topk(logits, top_k, dim=-1)
-        min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)

-        # set the logits who is less than first top_k logits to -inf
-        logits = torch.where(logits < min_values,
-                             torch.full_like(logits, float('-inf')), logits)
+def top_p_sampling_batch(
+    logits: torch.Tensor,
+    *,
+    top_p: float,
+    temperature: float,
+    generator: Optional[torch.Generator] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # NB: To be replaced by a more efficient implementation.
+    return top_k_top_p_sampling_batch(
+        logits,
+        top_p=top_p,
+        top_k=logits.size(1),
+        temperature=temperature,
+        generator=generator,
+    )

-    # compute probability distribution
-    softmax = torch.softmax(logits, dim=-1)

-    # sample from the distribution and generate result of [batch_size, 1]
-    next_tokens = torch.multinomial(softmax, num_samples=1,
-                                    generator=generator).squeeze(-1)
-    return next_tokens, softmax
+def temperature_sampling_batch(
+    logits: torch.Tensor,
+    *,
+    temperature: float,
+    generator: Optional[torch.Generator] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # NB: To be replaced by a more efficient implementation.
+    return top_k_top_p_sampling_batch(
+        logits,
+        top_p=1,
+        top_k=logits.size(1),
+        temperature=temperature,
+        generator=generator,
+    )


-def top_p_sampling_batch(
+def top_k_top_p_sampling_batch(
     logits: torch.Tensor,
     *,
-    top_p: float = 0.9,
-    temperature: float = 1.0,
+    top_k: int,
+    top_p: float,
+    temperature: float,
     generator: Optional[torch.Generator] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
     logits_dim = logits.dim()
     assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
+    assert temperature > 0, "non-greedy sampling requires valid temperature"
+    logits = logits / max(temperature, 1e-5)
+    batch_size, vocab_size = logits.size()

-    if temperature != 0:
-        logits = logits / max(temperature, 1e-5)
-
-    # sort the logits of each sample in descending order
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-
-    # compute cumulative probability distribution of each sample
-    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
-                                    dim=-1)
-    # get the location of top_p
-    sorted_indices_to_remove = cumulative_probs > top_p
-    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-    sorted_indices_to_remove[:, 0] = 0
-
-    # set the logits to -inf whose is outside top_p
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        1, sorted_indices, sorted_indices_to_remove)
-    logits = logits.masked_fill(indices_to_remove, float('-inf'))
-
-    # compute probability distribution
-    softmax = torch.softmax(logits, dim=-1)
-
-    # sample from the distribution and generate result of [batch_size, 1]
-    next_tokens = torch.multinomial(softmax, num_samples=1,
-                                    generator=generator).squeeze(-1)
-    return next_tokens, softmax
-
+    assert top_k > 1, "non-greedy sampling requires valid top_k"
+    need_top_k = top_k < vocab_size
+    assert top_p > 0, "non-greedy sampling requires valid top_p"
+    need_top_p = top_p < 1

-def top_k_top_p_sampling_batch(logits: torch.Tensor,
-                               *,
-                               top_k: int,
-                               top_p: float,
-                               temperature: float = 1.0,
-                               generator: Optional[torch.Generator] = None):
-    logits_dim = logits.dim()
-    assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
-    if temperature != 0:
-        logits = logits / max(temperature, 1e-5)
-    batch_size, vocab_size = logits.size()
-    # get first top_k logits of each sample and their indices
-    if top_k > 0:
+    # top-K: mask out logits not belonging to the top-K for each sample
+    if need_top_k:
         values, _ = torch.topk(logits, top_k, dim=-1)
         min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)

         # set the logits who is less than first top_k logits to -inf
         logits = torch.where(logits < min_values,
                              torch.full_like(logits, float('-inf')), logits)

-    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-
-    # compute cumulative probability distribution of each sample
-    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
-                                    dim=-1)
-
-    # get the location of top_p
-    sorted_indices_to_remove = cumulative_probs > top_p
-    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-    sorted_indices_to_remove[:, 0] = 0
-
-    # set the logits to -inf whose is outside top_p
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        1, sorted_indices, sorted_indices_to_remove)
-    logits = logits.masked_fill(indices_to_remove, float('-inf'))
+    # top-p: mask out logits outside the nucleus
+    if need_top_p:
+        sorted_logits, sorted_indices = torch.sort(logits,
+                                                   descending=True,
+                                                   dim=-1)
+
+        # compute cumulative probability distribution of each sample
+        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
+                                        dim=-1)
+
+        # get the location of top_p
+        # NB: Currently selecting the smallest index with cumulative_probs > top_p.
+        # Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
+        sorted_indices_to_remove = cumulative_probs > top_p
+        sorted_indices_to_remove[:,
+                                 1:] = sorted_indices_to_remove[:, :-1].clone()
+        sorted_indices_to_remove[:, 0] = 0
+
+        # set the logits to -inf for token indices outside top_p
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove)
+        logits = logits.masked_fill(indices_to_remove, float('-inf'))

     # compute probability distribution
     softmax = torch.softmax(logits, dim=-1)
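
The consolidated top_k_top_p_sampling_batch above applies temperature scaling, then a top-k cutoff, then nucleus (top-p) filtering, and only then samples. For reference, the same masking sequence can be reproduced in isolation with plain PyTorch. The sketch below is illustrative only and not part of this commit: the helper name filter_top_k_top_p, the random logits, and the tiny 6-token vocabulary are made up for the example.

import torch


def filter_top_k_top_p(logits: torch.Tensor, *, top_k: int, top_p: float,
                       temperature: float) -> torch.Tensor:
    # Mirrors the masking steps shown in the hunk above: temperature scaling,
    # top-k cutoff, then nucleus (top-p) filtering applied to the survivors.
    batch_size, vocab_size = logits.size()
    logits = logits / max(temperature, 1e-5)

    if top_k < vocab_size:  # top-k: keep only the k largest logits per row
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
        logits = torch.where(logits < min_values,
                             torch.full_like(logits, float("-inf")), logits)

    if top_p < 1:  # top-p: drop tokens outside the smallest set with mass > top_p
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is kept.
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, float("-inf"))

    return torch.softmax(logits, dim=-1)


# Example usage with made-up shapes: batch of 2, vocabulary of 6 tokens.
probs = filter_top_k_top_p(torch.randn(2, 6), top_k=4, top_p=0.8, temperature=0.7)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
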
@@ -359,48 +358,78 @@ def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor,
     return new_token


-TopK = tuple[Literal["top_k"], int]
+TemperatureOnly = tuple[Literal["temperature"], float]
+TopK = tuple[Literal["top_k"], int, float]
 TopP = tuple[Literal["top_p"], float, float]
 TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
 Greedy = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy | TopKTopP
-
-
-def _request_strategy(request: LlmRequest) -> Strategy:
-    # top_p and top_K with temperature=0.0 reduces to greedy
-    # sampling
-    temperature = request.sampling_config.temperature
-    if temperature is not None:
-        temperature = temperature[0]
-        if temperature == 0.0:
-            return GREEDY
-
-    if request.sampling_config.top_k is not None and len(
-            request.sampling_config.top_k
-    ) > 0 and request.sampling_config.top_p is not None and len(
-            request.sampling_config.top_p) > 0:
-        return ("top_k_top_p", request.sampling_config.top_k[0],
-                request.sampling_config.top_p[0], temperature)
-    elif request.sampling_config.top_p is not None and len(
-            request.sampling_config.top_p) > 0:
-        top_p = request.sampling_config.top_p[0]
-        return ("top_p", top_p, temperature)
-    elif request.sampling_config.top_k is not None and len(
-            request.sampling_config.top_k) > 0:
-        return ("top_k", request.sampling_config.top_k[0])
-    else:
+Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
+
+T = TypeVar('T')
+
+
+# Due to tensorrt_llm::runtime::SamplingConfig using vectors, params
+# in LlmRequest.sampling_params are either None or single-element lists.
+# This helper method simplifies code using such params.
+def _unwrap_singleton(p: Optional[List[T]]) -> Optional[T]:
+    if p is None:
+        return None
+    t, = p
+    return t
+
+
+def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
+    # The semantics are specified in the doc-string of SamplingParams
+
+    sampling_config = request.sampling_config
+    temperature = _unwrap_singleton(
+        cast(Optional[List[float]], sampling_config.temperature))
+    top_p = _unwrap_singleton(cast(Optional[List[float]],
+                                   sampling_config.top_p))
+    top_k = _unwrap_singleton(cast(Optional[List[int]], sampling_config.top_k))
+
+    if SamplingParams.params_imply_greedy_decoding(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+    ):
         return GREEDY

+    # --- resolving default values
+    # NB: not greedy, hence temperature != 0 if specified
+    temperature = temperature or 1.0
+
+    # NB: not greedy, hence top_p != 0 if specified
+    top_p = top_p or 1.0
+    # NB: not greedy, hence top_k != 1 if specified
+    # (0 and vocab_size are equivalent)
+    top_k = top_k or vocab_size
+
+    assert top_k > 1, "non-greedy sampling requires valid top_k"
+    need_top_k = top_k < vocab_size
+    assert top_p > 0, "non-greedy sampling requires valid top_p"
+    need_top_p = top_p < 1
+
+    if need_top_p:
+        if need_top_k:
+            return ("top_k_top_p", top_k, top_p, temperature)
+        return ("top_p", top_p, temperature)
+    if need_top_k:
+        return ("top_k", top_k, temperature)
+    return ("temperature", temperature)
+

 def _group_requests_by_sampling_strategy(
         requests: Iterable[LlmRequest],
         *,
-        pin_memory: bool = False) -> dict[Strategy, torch.Tensor]:
+        pin_memory: bool = False,
+        vocab_size: int) -> dict[Strategy, torch.Tensor]:
     # NB: Client code relies on request indices in returned torch.Tensor being sorted.
     strategy_dict: dict[Strategy, list[int]] = defaultdict(list)
     for req_index, req in enumerate(requests):
-        strategy_dict[_request_strategy(req)].append(req_index)
+        strategy_dict[_request_strategy(
+            req, vocab_size=vocab_size)].append(req_index)
     return {
         strategy: torch.tensor(indices,
                                pin_memory=pin_memory,
@@ -418,23 +447,32 @@ def sample(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     filter_softmax = True
     match strategy:
-        case ("top_k", top_k):
-            tokens, softmax = top_k_sampling_batch(logits, top_k, generator)
+        case ("top_k", top_k, temperature):
+            tokens, softmax = top_k_sampling_batch(logits,
+                                                   top_k=top_k,
+                                                   temperature=temperature,
+                                                   generator=generator)
         case ("top_p", top_p, temperature):
             tokens, softmax = top_p_sampling_batch(
                 logits,
                 top_p=top_p,
                 generator=generator,
-                **(dict(temperature=temperature)
-                   if temperature is not None else dict()))
+                temperature=temperature,
+            )
         case ("top_k_top_p", top_k, top_p, temperature):
             tokens, softmax = top_k_top_p_sampling_batch(
                 logits,
                 top_k=top_k,
                 top_p=top_p,
+                temperature=temperature,
                 generator=generator,
-                **(dict(temperature=temperature)
-                   if temperature is not None else dict()))
+            )
+        case ("temperature", temperature):
+            tokens, softmax = temperature_sampling_batch(
+                logits,
+                temperature=temperature,
+                generator=generator,
+            )
         case ("greedy", None):
             tokens, softmax = greedy_search_sampling_batch(
                 logits, softmax_indices=softmax_indices)
@@ -1323,7 +1361,7 @@ def _sample_batched_by_strategy(
             dim=-1)

         requests_by_strategy = _group_requests_by_sampling_strategy(
-            requests, pin_memory=True)
+            requests, pin_memory=True, vocab_size=logits_cuda.size(1))
         generator_cuda = self.get_generator(cuda_device)

         # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
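
To summarize the selection order that the new _request_strategy follows: greedy decoding short-circuits everything, defaults are then resolved for the non-greedy case, and the remaining combinations map to top_k_top_p, top_p, top_k, or temperature-only sampling. The sketch below restates that decision tree in plain Python; it is not the repository's code, the name resolve_strategy is made up for illustration, and the greedy test (temperature == 0, top_k == 1, or top_p == 0) is an assumed simplification standing in for SamplingParams.params_imply_greedy_decoding.

from typing import Optional


def resolve_strategy(temperature: Optional[float], top_p: Optional[float],
                     top_k: Optional[int], vocab_size: int) -> tuple:
    # Assumed stand-in for SamplingParams.params_imply_greedy_decoding.
    if temperature == 0 or top_k == 1 or top_p == 0:
        return ("greedy", None)

    # Resolve defaults for the non-greedy case, as in the diff above.
    temperature = temperature or 1.0
    top_p = top_p or 1.0
    top_k = top_k or vocab_size  # 0 and vocab_size are equivalent

    need_top_k = top_k < vocab_size
    need_top_p = top_p < 1

    if need_top_p and need_top_k:
        return ("top_k_top_p", top_k, top_p, temperature)
    if need_top_p:
        return ("top_p", top_p, temperature)
    if need_top_k:
        return ("top_k", top_k, temperature)
    return ("temperature", temperature)


# A few spot checks of the decision order (vocab size is arbitrary here).
assert resolve_strategy(0.0, None, None, vocab_size=32000) == ("greedy", None)
assert resolve_strategy(0.8, None, 40, vocab_size=32000) == ("top_k", 40, 0.8)
assert resolve_strategy(None, 0.9, None, vocab_size=32000) == ("top_p", 0.9, 1.0)
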
