-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Runtime][KVCache] AttentionWithFusedQKV and RoPE mode #16456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
tqchen
merged 1 commit into
apache:main
from
MasterJH5574:tvm-dev/2024-01-23-kv-cache-rope-mode
Jan 24, 2024
Merged
[Runtime][KVCache] AttentionWithFusedQKV and RoPE mode #16456
tqchen
merged 1 commit into
apache:main
from
MasterJH5574:tvm-dev/2024-01-23-kv-cache-rope-mode
Jan 24, 2024
+627
−220
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
This PR introduces two changes to the (paged) KV cache: The first is introducing RoPE mode to PagedKVCache. Right now there are two modes: normal/inline. In "normal" mode, RoPE will be applied to input Q/K/V data before appending the K/V data to cache. In "inline" mode, the input K/V data is directly appending to cache, and the RoPE will be on-the-fly applied inside attention kernel. The main purpose of introducing RoPE mode is to balance the need of on-the-fly RoPE (in cases like Mistral where positions can cahnge) and the attention kernel performance. The second is introducing a new interface `AttentionWithFusedQKV` to KV cache. This function takes the input QKV data that is fused along the head dimension. And the fused QKV will be split into separate Q/K/V internally (note: requiring external workspace passed in). We introduce this function since in practice we note that when RoPE mode is "normal," it offers better performance if we fuse the QKV split and RoPE application.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
12ad561 to
95ccc52
Compare
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 23, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
tqchen
approved these changes
Jan 23, 2024
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 24, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to MasterJH5574/mlc-llm
that referenced
this pull request
Jan 25, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574
added a commit
to mlc-ai/mlc-llm
that referenced
this pull request
Jan 25, 2024
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
smickey040404
added a commit
to smickey040404/mlc-llm
that referenced
this pull request
Feb 11, 2025
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
tristankincaid
added a commit
to tristankincaid/mlc-llm
that referenced
this pull request
Feb 16, 2025
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR introduces two changes to the (paged) KV cache:
The first is introducing RoPE mode to PagedKVCache. Right now there are two modes: normal/inline. In "normal" mode, RoPE will be applied to input Q/K/V data before appending the K/V data to cache. In "inline" mode, the input K/V data is directly appending to cache, and the RoPE will be on-the-fly applied inside attention kernel. The main purpose of introducing RoPE mode is to balance the need of on-the-fly RoPE (in cases like Mistral where positions can cahnge) and the attention kernel performance.
The second is introducing a new interface
AttentionWithFusedQKVto KV cache. This function takes the input QKV data that is fused along the head dimension. And the fused QKV will be split into separate Q/K/V internally (note: requiring external workspace passed in). We introduce this function since in practice we note that when RoPE mode is "normal," it offers better performance if we fuse the QKV split and RoPE application.