[Kernel] Chunk-aligned mamba2 #24683
Conversation
Signed-off-by: Thomas Parnell <[email protected]>
[Collapsed benchmark comment: server and client commands plus results for `main` vs the `tpa-mamba-aligned` branch.]
More benchmarking data. [Collapsed: server and client commands plus results from both branches.]
Looks great overall! It really simplifies the code.
I added a few nit comments about the need for comments and for documenting expected shapes.
I also think it's worth adding a general comment somewhere noting that this implementation assumes each chunk contains only a single sequence, since this is a significant change from the original implementation.
Thanks @tdoublep. LGTM now. The descriptions and shapes really help
PR looks great at first pass. Love to see more red than green.
In the figure in the PR description, why does A1.a fall at the beginning of the chunk rather than the end? I thought A0 should be ahead of it rather than behind.
@tlrmchlsmth A0 isn't actually added to the chunk; it has already been prefilled and doesn't need to be computed again. We just need to partition A1 in such a way that the intermediate states land exactly on the chunk boundaries within the sequence.
Do the padded regions get loaded at all?
Makes sense. So then the A0-sized padded region could overlap with another chunk, or it could fall off the end of the KV cache tensor, right? Do we mask off the loads of the padded region as well?
No, padding is maybe the wrong word. There isn't any actual padding of tensors in memory here; masking would probably be a better word. If we have 5 chunks like in the above example, we would launch a Triton kernel with a grid size equal to the number of chunks (5 in this case). We are basically trading off a bit of extra compute in order to get intermediate states at exactly where we want them within each sequence. It turns out it isn't really a trade-off: since it strips out so much complexity, it is a net win.
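To make the grid-size trade-off concrete, here is a small sketch (illustrative only, not the vLLM code): with per-sequence alignment the grid covers one program per possibly-padded chunk, computed per sequence, versus packing all sequences back-to-back. The function names are hypothetical.

```python
import math

def num_padded_chunks(seq_lens, chunk_size):
    # One set of chunks per sequence; the last chunk of each
    # sequence is masked up to the boundary (the "virtual" padding).
    return sum(math.ceil(n / chunk_size) for n in seq_lens)

def num_unpadded_chunks(seq_lens, chunk_size):
    # Old-style packing: sequences concatenated, chunks may span
    # sequence boundaries.
    return math.ceil(sum(seq_lens) / chunk_size)

lens = [20, 12]  # two sequences in the batch
print(num_padded_chunks(lens, 8))    # 3 + 2 = 5 programs in the grid
print(num_unpadded_chunks(lens, 8))  # 4 chunks if packed without alignment
```

The difference between the two counts is at most one chunk per sequence, which is the overhead bound mentioned in the PR description.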
Yes. If we don't introduce the padding/masking, it will lead to (a) having multiple sequences within the same chunk and (b) needing this whole mapping between "logical" and "physical" chunks to track where everything is.
Yes, we mask off the loads exactly (example: https://github.com/tdoublep/vllm/blob/tpa-aligned-mamba/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py#L231)
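The masking idea can be illustrated outside of Triton with a NumPy stand-in for a masked `tl.load` (this is a conceptual sketch, not the linked kernel): lanes beyond the valid sequence length are never read from memory and receive a neutral value instead.

```python
import numpy as np

chunk_size = 8
seq_len_in_chunk = 5                  # only 5 valid tokens in this chunk
offs = np.arange(chunk_size)          # per-lane offsets within the chunk
mask = offs < seq_len_in_chunk        # True only for valid positions

data = np.arange(100, 100 + chunk_size)   # backing memory for the chunk
loaded = np.where(mask, data, 0)          # masked "load": padded lanes -> 0
print(loaded)  # [100 101 102 103 104   0   0   0]
```

In the real Triton kernel the same `mask` argument is passed to `tl.load`, so the out-of-bounds lanes are never dereferenced at all.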
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
Purpose
This PR changes the way that the mamba2 kernels split the batch into "chunks". The change ensures that (a) no chunk ever contains more than one sequence, and (b) all intermediate states are computed at the chunk boundaries within each sequence.
This change is useful for three reasons:
The downside is that it introduces some "virtual" padding inside the chunks. We don't actually pad anything in GPU memory, we just potentially need to use a larger grid when launching kernels and may do some redundant compute. However, this padding is bounded to at most one chunk per sequence, and my initial experiments suggest it really doesn't hurt a lot. In fact, we actually see a significant speedup because we skip the call to the final "varlen" kernel. We follow a very similar approach for working with varlen batches in the Triton attention kernels, so this kind of technique is not without precedent.
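A quick back-of-envelope check (with made-up numbers) shows why the bound of at most one padded chunk per sequence keeps the redundant compute small: each sequence wastes at most `chunk_size - 1` token slots, so the overhead shrinks as sequences grow relative to the chunk size.

```python
def padding_overhead(seq_lens, chunk_size):
    # Fraction of extra (masked) token slots relative to real tokens.
    padded = sum(-(-n // chunk_size) * chunk_size for n in seq_lens)  # ceil-div
    real = sum(seq_lens)
    return (padded - real) / real

# Hypothetical batch: three prefill sequences, chunk size 256.
print(f"{padding_overhead([1000, 2000, 3000], 256):.3f}")  # 0.024 -> ~2.4% extra
```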
TODO:
- `seq_idx` can be made simpler - we just need to keep track of the seq_idx per chunk
- remove `chunk_indices` and `chunk_offsets`

A simple example for two sequences A and B is shown below. A0 and B0 represent the chunks that were prefilled at the previous step, and A1 and B1 are the new chunks we want to prefill in this iteration.
The idea is that for sequence A, we first take enough tokens from the new part (A1) to ensure that, when taken together with the precomputed part (A0), the state is chunk-aligned. Then we fill chunks with new tokens (from A1) until we run out, at which point we pad to the chunk boundary. Then we repeat for B.
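The per-sequence partitioning described above can be sketched in a few lines of Python. This is a hypothetical illustration of the scheme, not the vLLM implementation: `prefilled` plays the role of A0 (tokens already computed), `new` the role of A1, and the returned ranges are the chunk-aligned pieces, with the last piece implicitly masked up to the chunk boundary.

```python
def chunk_aligned_partition(prefilled: int, new: int, chunk_size: int):
    """Return (start, end) token ranges within the new part for each chunk."""
    chunks = []
    pos = 0
    # First piece: just enough new tokens so that prefilled + piece
    # lands on a chunk boundary.
    head = (-prefilled) % chunk_size
    if head:
        take = min(head, new)
        chunks.append((pos, pos + take))
        pos += take
    # Remaining pieces: full chunks, with the final partial chunk
    # masked (not physically padded) up to the boundary.
    while pos < new:
        take = min(chunk_size, new - pos)
        chunks.append((pos, pos + take))
        pos += take
    return chunks

# Example: 5 tokens already prefilled (A0), 20 new tokens (A1), chunk size 8.
# First piece takes 3 tokens (5 + 3 = 8, aligned), then 8, 8, and a final 1.
print(chunk_aligned_partition(5, 20, 8))  # [(0, 3), (3, 11), (11, 19), (19, 20)]
```

Note how every range boundary, offset by `prefilled`, falls on a multiple of `chunk_size`, which is exactly the property that lets the kernels skip the logical-to-physical chunk mapping.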
Test Plan
See correctness + benchmarking below.
Test Result
See correctness + benchmarking below.
Essential Elements of an Effective PR Description Checklist
- Update `supported_models.md` and `examples` for a new model.