
Conversation

@happierpig commented Sep 26, 2025

Purpose

  1. Optimize the sample_recovered_tokens_kernel implementation by unrolling the CTA over the vocab_size dimension.
  2. Add more parallelism by using more num_warps, which can help at small batch sizes (see the sketch below).
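
For context, here is a minimal sketch of both ideas, chunked scanning plus a larger num_warps at launch, using an illustrative Triton kernel; blocked_max_kernel and its arguments are hypothetical stand-ins, not the actual vLLM kernel:

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def blocked_max_kernel(x_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
        # Scan the row in BLOCK_SIZE chunks instead of one giant block,
        # carrying a running maximum across iterations (the same unrolling
        # idea this PR applies over vocab_size).
        max_val = float("-inf")
        for off in range(0, n, BLOCK_SIZE):
            offsets = off + tl.arange(0, BLOCK_SIZE)
            x = tl.load(x_ptr + offsets, mask=offsets < n, other=float("-inf"))
            max_val = tl.maximum(max_val, tl.max(x, axis=-1))
        tl.store(out_ptr, max_val)


    x = torch.randn(1 << 20, device="cuda")
    out = torch.empty(1, device="cuda")
    # num_warps raises intra-CTA parallelism; with a small launch grid
    # (e.g. at small batch sizes) it helps keep the SM occupied.
    blocked_max_kernel[(1,)](x, out, x.numel(), BLOCK_SIZE=4096, num_warps=8)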

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request optimizes the sample_recovered_tokens_kernel Triton kernel by unrolling the computation over the vocabulary size. This is a good performance optimization for large vocabularies. The implementation of the online maximum calculation is correct. However, I've identified a performance issue where a value is redundantly loaded within a loop. My review includes a suggestion to fix this.
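
To make the "online maximum" concrete, here is a plain-Python sketch of the blockwise running argmax the unrolled kernel performs; blocked_argmax is an illustrative stand-in, not the vLLM source:

    import math

    def blocked_argmax(scores, block_size):
        # Carry a running (max_val, max_idx) pair and merge one block's
        # local max into it per iteration, as the kernel does with tl.where.
        max_val, max_idx = -math.inf, -1
        for off in range(0, len(scores), block_size):
            block = scores[off:off + block_size]
            local_val = max(block)
            local_idx = block.index(local_val) + off  # convert to global index
            if local_val > max_val:
                max_val, max_idx = local_val, local_idx
        return max_idx

    # Matches a full argmax regardless of the block size chosen.
    assert blocked_argmax([0.1, 0.9, 0.4, 0.7], block_size=2) == 1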

Comment on lines +602 to +640
    for off in range(0, vocab_size, BLOCK_SIZE):
        vocab_offset = off + tl.arange(0, BLOCK_SIZE)
        if NO_DRAFT_PROBS:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=((vocab_offset < vocab_size) &
                      (vocab_offset != draft_token_id)),
                other=0,
            )
        else:
            draft_prob = tl.load(
                draft_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            target_prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            prob = tl.maximum(target_prob - draft_prob, 0)

        q = tl.load(
            q_ptr + req_idx * vocab_size + vocab_offset,
            mask=vocab_offset < vocab_size,
            other=float("-inf"),
        )
        scores = prob / q
        local_val = tl.max(scores, axis=-1)
        local_idx = tl.argmax(scores, axis=-1) + off

        # update global max
        better = local_val > max_val
        max_val = tl.where(better, local_val, max_val)
        max_idx = tl.where(better, local_idx, max_idx)
Severity: high

For performance, draft_token_id should be loaded only once before the loop since its value doesn't change across iterations. Loading it inside the loop results in redundant reads from global memory, which can negatively impact kernel performance.

    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
    for off in range(0, vocab_size, BLOCK_SIZE):
        vocab_offset = off + tl.arange(0, BLOCK_SIZE)
        if NO_DRAFT_PROBS:
            prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=((vocab_offset < vocab_size) &
                      (vocab_offset != draft_token_id)),
                other=0,
            )
        else:
            draft_prob = tl.load(
                draft_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            target_prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                vocab_offset,
                mask=vocab_offset < vocab_size,
                other=0,
            )
            prob = tl.maximum(target_prob - draft_prob, 0)

        q = tl.load(
            q_ptr + req_idx * vocab_size + vocab_offset,
            mask=vocab_offset < vocab_size,
            other=float("-inf"),
        )
        scores = prob / q
        local_val = tl.max(scores, axis=-1)
        local_idx = tl.argmax(scores, axis=-1) + off

        # update global max
        better = local_val > max_val
        max_val = tl.where(better, local_val, max_val)
        max_idx = tl.where(better, local_idx, max_idx)
