diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 5d8b3f423b02..4a5fbb10d408 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -89,18 +89,18 @@ def forward_cuda(
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
         if k is None and p is None:
             # We prefer `random_sample` over `flashinfer_sample` when sorting is
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
             return random_sample(probs, generators)
         if generators:
             logger.warning("FlashInfer 0.2.3+ does not support "
                            "per-request generators. Falling back to "
                            "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
-        return flashinfer_sample(probs, k, p, generators)
+        return flashinfer_sample(logits, k, p, generators)
 
     def forward_tpu(
         self,
@@ -254,17 +254,17 @@ def random_sample(
 
 
 def flashinfer_sample(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     k: Optional[torch.Tensor],
     p: Optional[torch.Tensor],
     generators: dict[int, torch.Generator],
 ) -> torch.Tensor:
-    """Sample from the probabilities using FlashInfer.
+    """Sample from the logits using FlashInfer.
 
     Statistically, this function is equivalent to the `random_sample` function.
     However, this function is faster because it avoids sorting the logits tensor
     via rejection sampling.
-    
+
     NOTE: The outputs of this function do not necessarily match the outputs of
     the `random_sample` function. It only guarantees that the outputs are
     statistically equivalent.
@@ -274,18 +274,19 @@ def flashinfer_sample(
     the synchronization overhead.
     """
     assert not (k is None and p is None)
-
     if k is None:
         # Top-p only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
             probs, p, deterministic=True)
     elif p is None:
         # Top-k only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
             probs, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            probs, k, p, deterministic=True))
+        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+            logits, k, p, deterministic=True)
     return next_token_ids.view(-1)
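
For reference, a minimal usage sketch of the updated `flashinfer_sample` entry point (not part of the diff): the change defers the fp32 softmax to the branches that still need probabilities, so the combined top-k/top-p path hands raw logits to FlashInfer directly. The sketch assumes a CUDA device with FlashInfer 0.2.3+ installed; the batch size, vocabulary size, and per-request k/p values are illustrative only.

import torch

from vllm.v1.sample.ops.topk_topp_sampler import flashinfer_sample

# Illustrative batch: 4 requests over a 32k-token vocabulary.
logits = torch.randn(4, 32000, device="cuda", dtype=torch.float32)
k = torch.tensor([50, 100, 40, 20], device="cuda")
p = torch.tensor([0.90, 0.95, 1.00, 0.80], device="cuda")

# With both k and p set, the raw logits go straight to
# top_k_top_p_sampling_from_logits; no softmax is computed on this path.
next_token_ids = flashinfer_sample(logits, k, p, generators={})
print(next_token_ids.shape)  # torch.Size([4])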