84 changes: 81 additions & 3 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -75,6 +75,18 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
self.forward = self.forward_native
elif current_platform.is_cpu():
self.forward = self.forward_cpu
elif (
logprobs_mode not in ("processed_logits", "processed_logprobs")
and current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
):
import aiter.ops.sampling # noqa: F401

self.aiter_ops = torch.ops.aiter
logger.info_once(
"Using aiter sampler on ROCm (lazy import, sampling-only)."
)
self.forward = self.forward_hip
else:
self.forward = self.forward_native

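In short, the new branch activates only when all three guards hold; a condensed sketch of the condition (same names as in the diff):

    use_aiter_sampler = (
        logprobs_mode not in ("processed_logits", "processed_logprobs")
        and current_platform.is_rocm()
        and envs.VLLM_ROCM_USE_AITER
    )
    # aiter cannot return processed logits/logprobs, so those modes keep the
    # PyTorch-native sampler even with VLLM_ROCM_USE_AITER=1. The aiter import
    # is deferred into the branch, so non-ROCm builds never require the package.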
@@ -120,9 +132,10 @@ def forward_cuda(
"PyTorch-native implementation."
)
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
"FlashInfer does not support returning logits/logprobs"
)
assert self.logprobs_mode not in (
"processed_logits",
"processed_logprobs",
), "FlashInfer does not support returning logits/logprobs"
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor.
@@ -167,6 +180,64 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:

return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return

def forward_hip(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Optimized ROCm/aiter path (same structure as forward_cuda)."""
if (k is None and p is None) or generators:
if generators:
logger.warning_once(
"aiter sampler does not support per-request generators; "
"falling back to PyTorch-native."
)
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in (
"processed_logits",
"processed_logprobs",
), "aiter sampler does not support returning logits/logprobs."
return self.aiter_sample(logits, k, p, generators), None
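A quick sketch of the dispatch contract above (illustrative only; `sampler` is a hypothetical TopKTopPSampler instance on a ROCm build with aiter enabled):

    sampler.forward_hip(logits, {}, k=None, p=None)
    # -> forward_native(...)            # no filtering requested, nothing for aiter
    sampler.forward_hip(logits, {0: gen}, k, p)
    # -> warns once, forward_native(...)  # per-request generators unsupported
    sampler.forward_hip(logits, {}, k, p)
    # -> (aiter_sample(logits, k, p, {}), None)   # common case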

def aiter_sample(
self,
logits: torch.Tensor,
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from logits using aiter ops."""
use_top_k = k is not None
use_top_p = p is not None
# Joint k+p path
if use_top_p and use_top_k:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
probs,
None,
*_to_tensor_scalar_tuple(k),
*_to_tensor_scalar_tuple(p),
deterministic=True,
)
return next_token_ids.view(-1)
# Top-p only path
elif use_top_p:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
)
return next_token_ids.view(-1)
# Top-k only path
elif use_top_k:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
renorm_probs = self.aiter_ops.top_k_renorm_probs(
probs, *_to_tensor_scalar_tuple(k)
)
return torch.multinomial(renorm_probs, num_samples=1).view(-1)
raise RuntimeError("aiter_sample was called with no active top-k or top-p.")


def apply_top_k_top_p(
logits: torch.Tensor,
@@ -300,3 +371,10 @@ def flashinfer_sample(
)

return next_token_ids.view(-1)


def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)
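Illustrative behavior of the helper (a sketch; presumably matching the aiter op convention of passing each threshold as a (tensor, scalar) pair with the unused slot set to None/0):

    _to_tensor_scalar_tuple(torch.tensor([40, 50]))  # -> (tensor([40, 50]), 0)
    _to_tensor_scalar_tuple(0.9)                     # -> (None, 0.9)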