perf: optimize rejection sampling triton kernel #25791
Conversation
Code Review
This pull request optimizes the sample_recovered_tokens_kernel Triton kernel by unrolling the computation over the vocabulary size. This is a good performance optimization for large vocabularies. The implementation of the online maximum calculation is correct. However, I've identified a performance issue where a value is redundantly loaded within a loop. My review includes a suggestion to fix this.
for off in range(0, vocab_size, BLOCK_SIZE):
    vocab_offset = off + tl.arange(0, BLOCK_SIZE)
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        prob = tl.load(
            target_probs_ptr + (start_idx + pos) * vocab_size +
            vocab_offset,
            mask=((vocab_offset < vocab_size) &
                  (vocab_offset != draft_token_id)),
            other=0,
        )
    else:
        draft_prob = tl.load(
            draft_probs_ptr + (start_idx + pos) * vocab_size +
            vocab_offset,
            mask=vocab_offset < vocab_size,
            other=0,
        )
        target_prob = tl.load(
            target_probs_ptr + (start_idx + pos) * vocab_size +
            vocab_offset,
            mask=vocab_offset < vocab_size,
            other=0,
        )
        prob = tl.maximum(target_prob - draft_prob, 0)
    q = tl.load(
        q_ptr + req_idx * vocab_size + vocab_offset,
        mask=vocab_offset < vocab_size,
        other=float("-inf"),
    )
    scores = prob / q
    local_val = tl.max(scores, axis=-1)
    local_idx = tl.argmax(scores, axis=-1) + off
    # update global max
    better = local_val > max_val
    max_val = tl.where(better, local_val, max_val)
    max_idx = tl.where(better, local_idx, max_idx)
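The online maximum calculation the review refers to can be sanity-checked outside Triton. Below is a minimal NumPy sketch (sizes are made up for illustration) showing that the per-block max/argmax merge reproduces a one-shot argmax over the whole vector:

import numpy as np

# Illustrative sizes only; the real kernel iterates over the model vocab.
vocab_size, BLOCK_SIZE = 1000, 256
scores = np.random.default_rng(0).random(vocab_size)

max_val, max_idx = float("-inf"), -1
for off in range(0, vocab_size, BLOCK_SIZE):
    block = scores[off:off + BLOCK_SIZE]
    local_val = block.max()
    local_idx = block.argmax() + off   # shift block-local index to global
    if local_val > max_val:            # same "better" predicate as the kernel
        max_val, max_idx = local_val, local_idx

assert max_idx == scores.argmax()      # matches the one-shot argmax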
For performance, draft_token_id should be loaded only once, before the loop, since its value doesn't change across iterations. Loading it inside the loop results in redundant reads from global memory, which can negatively impact kernel performance.
if NO_DRAFT_PROBS:
    draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
for off in range(0, vocab_size, BLOCK_SIZE):
    vocab_offset = off + tl.arange(0, BLOCK_SIZE)
    if NO_DRAFT_PROBS:
        prob = tl.load(
            target_probs_ptr + (start_idx + pos) * vocab_size +
            vocab_offset,
            mask=((vocab_offset < vocab_size) &
                  (vocab_offset != draft_token_id)),
            other=0,
        )
    else:
        draft_prob = tl.load(
            draft_probs_ptr + (start_idx + pos) * vocab_size +
            vocab_offset,
            mask=vocab_offset < vocab_size,
            other=0,
        )
        target_prob = tl.load(
            target_probs_ptr + (start_idx + pos) * vocab_size +
            vocab_offset,
            mask=vocab_offset < vocab_size,
            other=0,
        )
        prob = tl.maximum(target_prob - draft_prob, 0)
    q = tl.load(
        q_ptr + req_idx * vocab_size + vocab_offset,
        mask=vocab_offset < vocab_size,
        other=float("-inf"),
    )
    scores = prob / q
    local_val = tl.max(scores, axis=-1)
    local_idx = tl.argmax(scores, axis=-1) + off
    # update global max
    better = local_val > max_val
    max_val = tl.where(better, local_val, max_val)
    max_idx = tl.where(better, local_idx, max_idx)
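To make the hoisting point concrete in isolation, here is a toy Triton kernel (invented names and shapes, not the vLLM kernel) that reads a loop-invariant scalar once before the loop, in the same way the suggestion moves draft_token_id:

import torch
import triton
import triton.language as tl

@triton.jit
def scale_sum_kernel(x_ptr, scale_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Hoisted: `scale` never changes across iterations, so it is read
    # from global memory exactly once instead of once per block.
    scale = tl.load(scale_ptr)
    acc = 0.0
    for off in range(0, n, BLOCK_SIZE):
        idx = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(x_ptr + idx, mask=idx < n, other=0.0)
        acc += tl.sum(x * scale, axis=-1)
    tl.store(out_ptr, acc)

x = torch.randn(10000, device="cuda")
scale = torch.tensor([2.0], device="cuda")
out = torch.empty(1, device="cuda")
scale_sum_kernel[(1,)](x, scale, out, x.numel(), BLOCK_SIZE=1024)
torch.testing.assert_close(out[0], (x * 2.0).sum(), rtol=1e-3, atol=1e-3)

The compiler may or may not hoist such a load on its own; doing it explicitly in the source guarantees the single read.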
Purpose

Optimize the sample_recovered_tokens_kernel implementation by unrolling the CTA over the vocab_size dimension. Tuning num_warps could also be helpful for small batch sizes.
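For context on the num_warps remark: num_warps is a per-launch tuning knob in Triton. A minimal sweep on a throwaway copy kernel (illustrative only; not the PR's benchmark) would look like:

import torch
import triton
import triton.testing
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < n
    tl.store(y_ptr + idx, tl.load(x_ptr + idx, mask=mask), mask=mask)

x = torch.randn(1 << 20, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 8192),)
for nw in (1, 2, 4, 8):
    ms = triton.testing.do_bench(
        lambda: copy_kernel[grid](x, y, x.numel(), BLOCK_SIZE=8192,
                                  num_warps=nw))
    print(f"num_warps={nw}: {ms:.4f} ms")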
Test Plan

Test Result

Essential Elements of an Effective PR Description Checklist

(Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.