Conversation

@MasterJH5574
Contributor

This PR introduces two changes to the (paged) KV cache:

The first is introducing a RoPE mode to PagedKVCache. Right now there are two modes: normal/inline. In "normal" mode, RoPE is applied to the input Q/K/V data before the K/V data is appended to the cache. In "inline" mode, the input K/V data is appended to the cache directly, and RoPE is applied on the fly inside the attention kernel. The main purpose of introducing the RoPE mode is to balance the need for on-the-fly RoPE (in cases like Mistral, where positions can change) against attention kernel performance.
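
To make the distinction concrete, here is a minimal NumPy sketch (not the actual TVM kernels; all names are illustrative assumptions) of where the rotation happens in each mode: "normal" rotates K once before it is written to the cache, while "inline" stores the raw K and leaves the rotation to the attention kernel, which can use whatever positions are current at attention time.

```python
import numpy as np

def apply_rope(x: np.ndarray, pos: np.ndarray, theta: float = 10000.0) -> np.ndarray:
    """Rotate x of shape (seq_len, num_heads, head_dim) by positions pos of shape (seq_len,)."""
    half = x.shape[-1] // 2
    freqs = theta ** (-np.arange(half) / half)   # (half,)
    angles = pos[:, None] * freqs[None, :]       # (seq_len, half)
    cos = np.cos(angles)[:, None, :]             # broadcast over heads
    sin = np.sin(angles)[:, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

# "normal" mode: rotate K with its append-time position, then store the rotated K.
def append_kv_normal(k_cache: list, k: np.ndarray, pos: np.ndarray) -> None:
    k_cache.append(apply_rope(k, pos))

# "inline" mode: store the raw K; the attention kernel applies RoPE on the fly
# with whatever positions are current at attention time, which is what
# Mistral-style sliding windows need when positions shift after appending.
def append_kv_inline(k_cache: list, k: np.ndarray) -> None:
    k_cache.append(k)
```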

The second is introducing a new interface `AttentionWithFusedQKV` to the KV cache. This function takes input QKV data fused along the head dimension and splits it into separate Q/K/V internally (note: this requires an external workspace to be passed in). We introduce this function because, in practice, when the RoPE mode is "normal", fusing the QKV split with the RoPE application offers better performance.
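
As a rough illustration of the fused entry point, here is a hedged NumPy sketch (reusing `apply_rope` from the sketch above; the signature and layout are assumptions for illustration, not the actual runtime API) of splitting a head-dimension-fused QKV tensor and applying RoPE to Q/K when the mode is "normal":

```python
import numpy as np

def attention_with_fused_qkv_sketch(qkv: np.ndarray,
                                    num_q_heads: int,
                                    num_kv_heads: int,
                                    pos: np.ndarray,
                                    rope_mode: str = "normal"):
    """qkv has shape (seq_len, num_q_heads + 2 * num_kv_heads, head_dim),
    i.e. Q, K and V fused along the head dimension."""
    q = qkv[:, :num_q_heads, :]
    k = qkv[:, num_q_heads:num_q_heads + num_kv_heads, :]
    v = qkv[:, num_q_heads + num_kv_heads:, :]
    if rope_mode == "normal":
        # The split and the rotation touch the same data, so fusing them into
        # one kernel is what makes the fused-QKV entry point pay off.
        q, k = apply_rope(q, pos), apply_rope(k, pos)
    # ... append k/v to the paged cache and run attention; in "inline" mode the
    # attention kernel applies RoPE on the fly instead ...
    return q, k, v
```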

MasterJH5574 added a commit to MasterJH5574/mlc-llm that referenced this pull request on Jan 23, 2024:
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574 marked this pull request as draft on January 23, 2024, 16:40.
MasterJH5574 force-pushed the tvm-dev/2024-01-23-kv-cache-rope-mode branch from 12ad561 to 95ccc52 on January 23, 2024, 20:44.
MasterJH5574 marked this pull request as ready for review on January 23, 2024, 21:17.
tqchen merged commit 20b08a5 into apache:main on Jan 24, 2024.
MasterJH5574 added a commit to MasterJH5574/mlc-llm that referenced this pull request on Jan 24, 2024:
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574 added a commit to MasterJH5574/mlc-llm that referenced this pull request on Jan 25, 2024:
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
MasterJH5574 added a commit to mlc-ai/mlc-llm that referenced this pull request on Jan 25, 2024:
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
smickey040404 added a commit to smickey040404/mlc-llm that referenced this pull request on Feb 11, 2025:
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.
tristankincaid added a commit to tristankincaid/mlc-llm that referenced this pull request on Feb 16, 2025:
Following apache/tvm#16456, this PR leverages the RoPE mode and AttentionWithFusedQKV function in llama.