Conversation

@simondanielsson (Contributor) commented Oct 14, 2025

Purpose

Part of #26201.

Adds Automatic Prefix Caching (APC) for GDN, following the approach used for Mamba2 APC introduced in #25752.

Specifically:

  • Extends the gated-delta chunk kernel to optionally return per-chunk intermediate states, flattening them into a contiguous stream so callers can repopulate prefix-cache blocks.
  • Updates Qwen3NextGatedDeltaNet to reuse cached states: during decode it copies the last computed block into the newly scheduled slot, and during prefill it replays the returned chunk history into persistent SSM cache blocks so later requests can hit the prefix cache (a rough sketch of the prefill-side replay follows below).

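As a rough illustration of the prefill-side mechanism, the replay boils down to copying the state at every block-aligned chunk boundary into the cache slot given by the block table. This is only a sketch, not the PR's actual code: the function name, tensor shapes, and default sizes below are assumptions, and it assumes the mamba block size is a multiple of the kernel chunk size (as the PR enforces).

# Sketch only (not the PR's implementation). Assumes block_size % chunk_size == 0,
# so every cache-block boundary coincides with a chunk boundary.
import torch

def replay_chunk_states_into_cache(
    chunk_states: torch.Tensor,  # [num_chunks, *state_dims], flattened stream returned by the kernel
    ssm_cache: torch.Tensor,     # [num_blocks, *state_dims], persistent per-block SSM states
    block_table: torch.Tensor,   # [num_seq_blocks], cache-block ids for this sequence's prefix
    seq_len: int,
    chunk_size: int = 64,
    block_size: int = 256,
) -> None:
    assert block_size % chunk_size == 0
    chunks_per_block = block_size // chunk_size
    num_full_blocks = seq_len // block_size
    for blk in range(num_full_blocks):
        # The state after the last chunk of a block is the state at that block boundary.
        chunk_idx = (blk + 1) * chunks_per_block - 1
        ssm_cache[block_table[blk]].copy_(chunk_states[chunk_idx])

# A later request whose prefix hits these blocks can then resume from
# ssm_cache[block_table[num_cached_blocks - 1]] instead of recomputing the prefix.
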
Latency benchmark (APC ("default") vs no-APC ("default-noapc")):
[latency benchmark plot]

TODOs:

  • Add better logic for making the kernel return intermediate states rather than using GDN_RECOMPUTE_SUPPRESS_LEVEL=4.
  • Make it work with fullgraph (decode)
  • Extend APC test suite to also run on qwen3-next (tiny random)
  • Run latency benchmarks on small model
  • Benchmark on 80B-A3B (I will need help from someone here)

Outstanding tasks, not captured here:

  • Support speculative decoding (specdec)

Test Plan

Note: this was run only with the tiny tiny-random/qwen3-next-moe model, since I only have an L4 with 20 GB of VRAM. It would be great if someone could also try Qwen3-Next-80B-A3B.

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
import time

if __name__ == "__main__":
    # Note: should be tested with Qwen/Qwen3-Next-80B-A3B-Instruct
    MODEL = "tiny-random/qwen3-next-moe"
    PROMPT_MULTIPLE = 310
    sampling_params = SamplingParams(temperature=0.0)
    prefix = (  # examples/offline_inference/prefix_caching.py
        "You are an expert school principal, skilled in effectively managing "
        "faculty and staff. Draft 10-15 questions for a potential first grade "
        "Head Teacher for my K-12, all-girls', independent school that emphasizes "
        "community, joyful discovery, and life-long learning. The candidate is "
        "coming in for a first-round panel interview for a 8th grade Math "
        "teaching role. They have 5 years of previous teaching experience "
        "as an assistant teacher at a co-ed, public school with experience "
        "in middle school math teaching. "
    )
    prefix2 = "Based on these information, fulfill " "the following paragraph: "
    prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is"
    print("Prompt length:", len(prompt))
    for APC in [True, False]:
        engine = LLM(
            model=MODEL,
            enable_prefix_caching=APC,
            gpu_memory_utilization=0.3,
            disable_log_stats=False,
        )
        for i in range(3):
            if i == 0:
                print("Warm-up")
            if i == 1:
                print("Measuring")
                start_time = time.time()
            outputs = engine.generate(prompt, sampling_params)
            print("APC:", APC, i, f"Generated text: {outputs[0].outputs[0].text!r}")
            for m in engine.llm_engine.get_metrics():
                if "vllm:prefix_cache_hits" in m.name:
                    print(m.name, m.value)
        print("APC:", APC, "loop took --- %s seconds ---" % (time.time() - start_time))
        del engine
        cleanup_dist_env_and_memory()

Test Result

Note: the generated text is gibberish because the model has random weights.

No cudagraphs (enforce_eager=True):

Warm-up
APC: True 0 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
Measuring
APC: True 1 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 31680
APC: True 2 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 63360
APC: True loop took --- 0.7412824630737305 seconds ---

Warm-up
APC: False 0 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
Measuring
APC: False 1 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
APC: False 2 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
APC: False loop took --- 0.9228880405426025 seconds ---

With cudagraphs (enforce_eager=False):

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:02<00:00, 24.18it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:03<00:00,  8.98it/s]
INFO 10-14 13:44:50 [gpu_model_runner.py:3821] Graph capturing finished in 7 secs, took 0.34 GiB
INFO 10-14 13:44:50 [core.py:242] init engine (profile, create kv cache, warmup model) took 25.02 seconds
INFO 10-14 13:44:51 [loggers.py:191] Engine 000: vllm cache_config_info with initialization after num_gpu_blocks is: 10969
INFO 10-14 13:44:51 [llm.py:335] Supported tasks: ('generate',)
Warm-up
APC: True 0 Generated text: ' estado Bernie阿拉 remotelySr春晚 ứngibelENCYcancel scientificallyResidentsnah Stout__))荁'
vllm:prefix_cache_hits 0
Measuring
APC: True 1 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 31680
APC: True 2 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 63360
APC: True loop took --- 0.3312194347381592 seconds ---

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:00<00:00, 72.41it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 17.99it/s]
INFO 10-14 13:54:26 [gpu_model_runner.py:3821] Graph capturing finished in 3 secs, took 0.20 GiB
INFO 10-14 13:54:26 [core.py:242] init engine (profile, create kv cache, warmup model) took 8.07 seconds
INFO 10-14 13:54:27 [loggers.py:191] Engine 000: vllm cache_config_info with initialization after num_gpu_blocks is: 11615
INFO 10-14 13:54:27 [llm.py:335] Supported tasks: ('generate',)
Warm-up
APC: False 0 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
Measuring
APC: False 1 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
APC: False 2 Generated text: ' estado Bernieatial oggi_five뉼หน้าที่wordpressหน้าที่ibelENCY荁=x Color Gh [],\r\n'
vllm:prefix_cache_hits 0
APC: False loop took --- 0.5677089691162109 seconds ---

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

mergify bot added the qwen (Related to Qwen models) and v1 labels on Oct 14, 2025

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @simondanielsson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Oct 14, 2025
Signed-off-by: simondanielsson <[email protected]>
mergify bot removed the needs-rebase label on Oct 14, 2025
Signed-off-by: simondanielsson <[email protected]>
Signed-off-by: simondanielsson <[email protected]>
simondanielsson changed the title from "[Feature] GatedDeltaNet Automatic Prefix Caching" to "[V1][Hybrid] GatedDeltaNet Automatic Prefix Caching" on Oct 16, 2025
simondanielsson marked this pull request as ready for review on October 16, 2025 15:40
@simondanielsson (Contributor, Author) commented:

@codex review

@chatgpt-codex-connector bot left a comment:

💡 Codex Review

Here are some automated review suggestions for this pull request.


@simondanielsson (Contributor, Author) commented:

@codex review

@chatgpt-codex-connector bot commented:

Codex Review: Didn't find any major issues. Another round soon, please!


Signed-off-by: simondanielsson <[email protected]>
Signed-off-by: simondanielsson <[email protected]>
@tdoublep (Member) left a comment:

Thank you for working on this! I have some initial comments and questions.

I can help benchmark this on an H100.

Comment on lines +30 to +31
GDN_MODELS = ["tiny-random/qwen3-next-moe"]

Is there any specific reason to split it off into GDN models?

):
raise ValueError(
"GDN prefix caching requires the mamba block size to be a "
"multiple of the kernel chunk size."

Maybe include self.chunk_size in the error message to help guide the user to set it correctly?

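For example, something along these lines (a sketch; block_size here is a placeholder for however the configured mamba block size is exposed at this point in the code):

raise ValueError(
    "GDN prefix caching requires the mamba block size to be a "
    f"multiple of the kernel chunk size ({self.chunk_size}), "
    f"but got mamba block size {block_size}."
)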

# Decode-side APC metadata
state_indices_tensor_d: torch.Tensor | None = None
state_indices_tensor_p: torch.Tensor | None = None

Move the tensor with the _p suffix to the section below?

Comment on lines +174 to +178
self.state_indices_tensor_p_buf = torch.empty(
(self.decode_cudagraph_max_bs, self._max_cached_blocks),
dtype=torch.int32,
device=device,
)

I don't think we need these static buffers for tensors that relate to prefill, because we don't use full CUDA graphs for batches that contain prefills.

Comment on lines +190 to +193
max_num_prefill_chunks = (
cdiv(vllm_config.model_config.max_model_len, self.chunk_size)
* self.decode_cudagraph_max_bs
)

Not sure why the number of prefill chunks could be related to the maximum decode-only batch size

Comment on lines +199 to +228
self.cu_chunk_seqlen_p_buf = torch.empty(
(max_num_prefill_chunks + 1,),
dtype=torch.int32,
device=device,
)
self.last_chunk_indices_p_buf = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.num_computed_tokens_p_buf = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.block_idx_first_scheduled_token_p_buf = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.block_idx_last_computed_token_p_buf = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.block_idx_last_scheduled_token_p_buf = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)

Same question for all of these prefill tensors - why do we need to use static buffers?

torch.int32
)

if spec_sequence_masks is not None:

If I understand correctly, this PR does not attempt to support APC + spec dec. Could we simplify this logic by just raising if spec decode is enabled?
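For instance, a guard along these lines near the top of the metadata build (a sketch; enable_prefix_caching stands in for however the APC setting is visible at this point):

if spec_sequence_masks is not None and enable_prefix_caching:
    raise NotImplementedError(
        "Automatic prefix caching for GDN is not yet supported together "
        "with speculative decoding."
    )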

num_computed_tokens_cpu_non_spec = m.num_computed_tokens_cpu

if num_decodes > 0:
state_indices_tensor_d = non_spec_block_table[:num_decodes].contiguous()

Why do we need the .contiguous on these?

Comment on lines +475 to +478
cu_chunk_seqlen: list[int] = []
seq_idx_list: list[int] = []
last_chunk_indices_list: list[int] = []
seqlen_pos = 0

Where do these actually get used by the model?


mergify bot commented Nov 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @simondanielsson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Nov 7, 2025