73 changes: 46 additions & 27 deletions vllm/v1/sample/rejection_sampler.py
@@ -427,8 +427,9 @@ def sample_recovered_tokens(
         target_probs,
         q,
         vocab_size,
-        triton.next_power_of_2(vocab_size),
         NO_DRAFT_PROBS=draft_probs is None,
+        BLOCK_SIZE=16384,
+        num_warps=8,
     )
     return recovered_token_ids
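
For context, a minimal standalone sketch (not from this PR; the kernel and tensor names are made up for illustration) of how a Triton kernel receives a `tl.constexpr` block size and the `num_warps` launch hint as keyword arguments at the call site, the same pattern the updated call above uses:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(x_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
        # BLOCK_SIZE is a compile-time constant; Triton specializes the
        # kernel for each distinct value it is launched with.
        pid = tl.program_id(0)
        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x * 2.0, mask=mask)

    x = torch.randn(100_000, device="cuda")
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 16384),)
    # constexpr args and launch hints (num_warps) are passed as keywords.
    scale_kernel[grid](x, out, x.numel(), BLOCK_SIZE=16384, num_warps=8)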

@@ -579,8 +580,8 @@ def sample_recovered_tokens_kernel(
     target_probs_ptr,  # [num_tokens, vocab_size]
     q_ptr,  # [batch_size, vocab_size]
     vocab_size,
-    PADDED_VOCAB_SIZE: tl.constexpr,
     NO_DRAFT_PROBS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     if req_idx == 0:
@@ -595,29 +596,47 @@ def sample_recovered_tokens_kernel(
     if pos >= num_draft_tokens:
         return
 
-    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
-    if NO_DRAFT_PROBS:
-        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-        prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
-                       vocab_offset,
-                       mask=((vocab_offset < vocab_size) &
-                             (vocab_offset != draft_token_id)),
-                       other=0)
-    else:
-        draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
-                             vocab_offset,
-                             mask=vocab_offset < vocab_size,
-                             other=0)
-        target_prob = tl.load(target_probs_ptr +
-                              (start_idx + pos) * vocab_size + vocab_offset,
-                              mask=vocab_offset < vocab_size,
-                              other=0)
-        prob = tl.maximum(target_prob - draft_prob, 0)
-    # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
-    # `tl.argmax` will select the maximum value.
-
-    q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
-                mask=vocab_offset < vocab_size,
-                other=float("-inf"))
-    recovered_id = tl.argmax(prob / q, axis=-1)
-    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
+    # online max
+    max_val = -float("inf")
+    max_idx = -1
+    for off in range(0, vocab_size, BLOCK_SIZE):
+        vocab_offset = off + tl.arange(0, BLOCK_SIZE)
+        if NO_DRAFT_PROBS:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            prob = tl.load(
+                target_probs_ptr + (start_idx + pos) * vocab_size +
+                vocab_offset,
+                mask=((vocab_offset < vocab_size) &
+                      (vocab_offset != draft_token_id)),
+                other=0,
+            )
+        else:
+            draft_prob = tl.load(
+                draft_probs_ptr + (start_idx + pos) * vocab_size +
+                vocab_offset,
+                mask=vocab_offset < vocab_size,
+                other=0,
+            )
+            target_prob = tl.load(
+                target_probs_ptr + (start_idx + pos) * vocab_size +
+                vocab_offset,
+                mask=vocab_offset < vocab_size,
+                other=0,
+            )
+            prob = tl.maximum(target_prob - draft_prob, 0)
+
+        q = tl.load(
+            q_ptr + req_idx * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=float("-inf"),
+        )
+        scores = prob / q
+        local_val = tl.max(scores, axis=-1)
+        local_idx = tl.argmax(scores, axis=-1) + off
+
+        # update global max
+        better = local_val > max_val
+        max_val = tl.where(better, local_val, max_val)
+        max_idx = tl.where(better, local_idx, max_idx)
Contributor comment on lines +602 to +640 (severity: high):
For performance, draft_token_id should be loaded only once before the loop since its value doesn't change across iterations. Loading it inside the loop results in redundant reads from global memory, which can negatively impact kernel performance.

    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
    for off in range(0, vocab_size, BLOCK_SIZE):
        vocab_offset = off + tl.arange(0, BLOCK_SIZE)
        if NO_DRAFT_PROBS:
            prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=((vocab_offset < vocab_size) &
                      (vocab_offset != draft_token_id)),
                other=0,
            )
        else:
            draft_prob = tl.load(
                draft_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            target_prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            prob = tl.maximum(target_prob - draft_prob, 0)

        q = tl.load(
            q_ptr + req_idx * vocab_size + vocab_offset,
            mask=vocab_offset < vocab_size,
            other=float("-inf"),
        )
        scores = prob / q
        local_val = tl.max(scores, axis=-1)
        local_idx = tl.argmax(scores, axis=-1) + off

        # update global max
        better = local_val > max_val
        max_val = tl.where(better, local_val, max_val)
        max_idx = tl.where(better, local_idx, max_idx)


+    tl.store(output_token_ids_ptr + start_idx + pos, max_idx)
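
For intuition, here is a small NumPy sketch (not part of the PR; function and variable names are illustrative) of what the rewritten kernel computes for one draft position: score each vocabulary entry as max(target_prob − draft_prob, 0) / q and take the argmax, scanning the vocabulary in fixed-size blocks while carrying a running (value, index) pair, exactly as the `online max` loop above does. The exponential draws for `q` are a stand-in for whatever noise the caller provides.

    import numpy as np

    def recovered_token_blockwise(target_prob, draft_prob, q, block_size=16384):
        """Blocked online argmax over scores = max(target - draft, 0) / q.

        Equivalent to np.argmax over the full vocab row, but only touches
        block_size entries at a time, mirroring the Triton loop.
        """
        vocab_size = target_prob.shape[0]
        max_val, max_idx = -np.inf, -1
        for off in range(0, vocab_size, block_size):
            blk = slice(off, min(off + block_size, vocab_size))
            prob = np.maximum(target_prob[blk] - draft_prob[blk], 0.0)
            scores = prob / q[blk]
            local = int(np.argmax(scores))
            if scores[local] > max_val:  # strict '>' keeps the first max on
                max_val = scores[local]  # ties, matching np.argmax semantics
                max_idx = off + local
        return max_idx

    # Sanity check against the single-pass argmax the old kernel computed.
    rng = np.random.default_rng(0)
    vocab = 50_000
    target = rng.dirichlet(np.ones(vocab))
    draft = rng.dirichlet(np.ones(vocab))
    q = rng.exponential(size=vocab)  # stand-in noise; always positive
    full = int(np.argmax(np.maximum(target - draft, 0.0) / q))
    assert recovered_token_blockwise(target, draft, q) == full

The apparent motivation for the change, judging from the diff itself: the old kernel sized its block as `triton.next_power_of_2(vocab_size)`, so a single program had to materialize an entire padded vocab row at once, while the new loop bounds the working set to `BLOCK_SIZE` regardless of vocabulary size.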