diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 8f0b38ecb34d..59ac1331e39e 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -427,8 +427,9 @@ def sample_recovered_tokens(
         target_probs,
         q,
         vocab_size,
-        triton.next_power_of_2(vocab_size),
         NO_DRAFT_PROBS=draft_probs is None,
+        BLOCK_SIZE=16384,
+        num_warps=8,
     )
     return recovered_token_ids
 
@@ -579,8 +580,8 @@ def sample_recovered_tokens_kernel(
     target_probs_ptr,  # [num_tokens, vocab_size]
     q_ptr,  # [batch_size, vocab_size]
     vocab_size,
-    PADDED_VOCAB_SIZE: tl.constexpr,
     NO_DRAFT_PROBS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     if req_idx == 0:
@@ -595,29 +596,47 @@
     if pos >= num_draft_tokens:
         return
 
-    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
-    if NO_DRAFT_PROBS:
-        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-        prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
-                       vocab_offset,
-                       mask=((vocab_offset < vocab_size) &
-                             (vocab_offset != draft_token_id)),
-                       other=0)
-    else:
-        draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
-                             vocab_offset,
-                             mask=vocab_offset < vocab_size,
-                             other=0)
-        target_prob = tl.load(target_probs_ptr +
-                              (start_idx + pos) * vocab_size + vocab_offset,
-                              mask=vocab_offset < vocab_size,
-                              other=0)
-        prob = tl.maximum(target_prob - draft_prob, 0)
-        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
-        # `tl.argmax` will select the maximum value.
-
-    q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
+    # Online max: scan the vocab in BLOCK_SIZE chunks, tracking the running best.
+    max_val = -float("inf")
+    max_idx = -1
+    for off in range(0, vocab_size, BLOCK_SIZE):
+        vocab_offset = off + tl.arange(0, BLOCK_SIZE)
+        if NO_DRAFT_PROBS:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            prob = tl.load(
+                target_probs_ptr + (start_idx + pos) * vocab_size +
+                vocab_offset,
+                mask=((vocab_offset < vocab_size) &
+                      (vocab_offset != draft_token_id)),
+                other=0,
+            )
+        else:
+            draft_prob = tl.load(
+                draft_probs_ptr + (start_idx + pos) * vocab_size +
+                vocab_offset,
+                mask=vocab_offset < vocab_size,
+                other=0,
+            )
+            target_prob = tl.load(
+                target_probs_ptr + (start_idx + pos) * vocab_size +
+                vocab_offset,
                 mask=vocab_offset < vocab_size,
-                other=float("-inf"))
-    recovered_id = tl.argmax(prob / q, axis=-1)
-    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
+                other=0,
+            )
+            prob = tl.maximum(target_prob - draft_prob, 0)
+
+        q = tl.load(
+            q_ptr + req_idx * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=float("-inf"),
+        )
+        scores = prob / q
+        local_val = tl.max(scores, axis=-1)
+        local_idx = tl.argmax(scores, axis=-1) + off
+
+        # Update the running global max and its argmax index.
+        better = local_val > max_val
+        max_val = tl.where(better, local_val, max_val)
+        max_idx = tl.where(better, local_idx, max_idx)
+
+    tl.store(output_token_ids_ptr + start_idx + pos, max_idx)
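
The rewritten kernel replaces a single argmax over a power-of-two-padded row with a blockwise scan that carries a running (value, index) pair across chunks. A minimal NumPy sketch of why the two are equivalent (illustrative only, not part of the patch; `online_argmax` and the sizes are made up):

```python
import numpy as np

def online_argmax(scores: np.ndarray, block_size: int) -> int:
    """Blockwise argmax, mirroring the kernel's running-max loop."""
    max_val, max_idx = -np.inf, -1
    for off in range(0, len(scores), block_size):
        block = scores[off:off + block_size]
        local_idx = int(np.argmax(block))  # argmax within this block
        local_val = block[local_idx]
        if local_val > max_val:  # strict '>' keeps the earliest maximum,
            max_val = local_val  # matching np.argmax's first-occurrence rule
            max_idx = off + local_idx
    return max_idx

rng = np.random.default_rng(0)
scores = rng.random(152064)  # a vocab-sized row that is not a power of two
assert online_argmax(scores, 16384) == int(np.argmax(scores))
```

The apparent motivation: for 100K+ vocabularies, `triton.next_power_of_2(vocab_size)` forces each program to materialize a 131072-wide `tl.arange` at once, while a fixed `BLOCK_SIZE=16384` bounds the per-program on-chip footprint regardless of vocab size; `num_warps=8` sizes the program for that block.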