
Commit c0c0624

zyongye, youkaichao, mgoin, yewentao256, and heheda12345 authored
Adding pytorch impl for Paged Indexer (vllm-project#9)
* code from ds
  Signed-off-by: youkaichao <[email protected]>
* doc from ds
  Signed-off-by: youkaichao <[email protected]>
* Fixes for support_materials/2-tilelang/
  Signed-off-by: mgoin <[email protected]>
* Fix example 1
  Signed-off-by: mgoin <[email protected]>
* Fix Einsum in deepgemm
* Fix `libc10.so` unimported error
* fix reference code
  Signed-off-by: youkaichao <[email protected]>
* adding missing indexer args
* passing index args into the module
* init
  Signed-off-by: Chen Zhang <[email protected]>
* build indexer k cache metadata
* prefill indexer, but weight_proj will output -inf
* unquantized paged indexer, still has -inf issue
* remove support material
* adding topk_indices mask
* add weight scale
* unittest infrastructure and fix weight_proj numeric error due to quantization
* varlen prefill passed
* paged prefill
* add indices mask

---------

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
1 parent 0eba9f1 commit c0c0624

File tree

7 files changed: +727 −17 lines changed

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
import random

import torch

from vllm.utils.tile_lang_kernels import act_quant, fp8_index
from vllm import _custom_ops as ops


def ref_compute_logits_fp8(q, kv, weights, mask, block_size):
    # Quantize q and kv to fp8 (ue8m0 scales) and score them with the
    # tile-lang fp8_index kernel; an optional additive mask is applied last.
    q_fp8, q_scale = act_quant(q, block_size, "ue8m0")
    k_fp8, k_scale = act_quant(kv, block_size, "ue8m0")

    weights = weights.unsqueeze(-1) * q_scale
    weights = weights * (128**(-0.5))
    index_score = fp8_index(
        q_fp8.contiguous(), weights,
        k_fp8.contiguous(),
        k_scale.contiguous())
    if mask is not None:
        index_score += mask
    return index_score


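# Reference (varlen) path: score each sequence separately with the fp8 kernel
# under a causal (upper-triangular -inf) mask.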
def ref_indexer(seq_len, q, kv, weights, block_size, topk):
    B = seq_len.shape[0]
    varlen_logits = []

    for i in range(B):
        S = seq_len[i]
        q_s = q[i][:S].contiguous().unsqueeze(0)
        kv_s = kv[i][:S].contiguous().unsqueeze(0)
        weights_s = weights[i][:S].contiguous().unsqueeze(0)
        mask = torch.full(
            (S, S), float("-inf"),
            device="cuda").triu_(1)
        logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask, block_size)
        varlen_logits.append(logits)
    # topk_indices = index_score.topk(topk,
    #                                 dim=-1)[1]
    return varlen_logits


def kv_spans_from_batches(start_seq_loc: torch.Tensor,
                          seq_len_per_batch: torch.Tensor):
    """
    Args:
        start_seq_loc: 1D long tensor [B+1], cumulative counts of selected
            tokens per batch. Example: [0, 2, 4, 7] -> batch sizes (selected)
            [2, 2, 3], N=7 tokens total.
        seq_len_per_batch: 1D long tensor [B], full sequence length (KV
            length) of each batch. Example: [5, 9, 4].

    Returns:
        start_tensor: 1D long tensor [N], start offset in the concatenated KV
            cache for each token's batch.
        end_location: 1D long tensor [N], **exclusive** end = start + token's
            local position. (So the attended KV slice is kv[start:end].)

    Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the
    KV cache, and the selected tokens within a batch are the **last**
    `counts[i]` positions of that sequence.
    """
    q = start_seq_loc.to(dtype=torch.long)
    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
    assert q.dim() == 1 and L.dim() == 1
    assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch and totals
    counts = q[1:] - q[:-1]  # [B]
    N = int(q[-1].item())    # total selected tokens
    B = L.numel()
    device = L.device

    if N == 0:
        return (torch.empty(0, dtype=torch.long, device=device),
                torch.empty(0, dtype=torch.long, device=device))

    # KV start offsets per batch in the concatenated KV cache
    kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

    # For each selected token, which batch does it belong to?
    batch_id = torch.repeat_interleave(torch.arange(B, device=device), counts)  # [N]

    # Map batch KV start to each token
    start_tensor = kv_starts_per_batch[batch_id]  # [N]

    # End-align local positions inside each batch:
    # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
    L_expand = torch.repeat_interleave(L, counts)       # [N]
    m_expand = torch.repeat_interleave(counts, counts)  # [N]
    # position within the selected block: 1..counts[b]
    pos_within = (torch.arange(N, device=device, dtype=torch.long)
                  - torch.repeat_interleave(q[:-1], counts) + 1)

    local_pos = L_expand - m_expand + pos_within  # [N], 1-based
    end_location = start_tensor + local_pos       # exclusive end

    return start_tensor, end_location

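# Worked example (illustrative), using the docstring values:
#   start_seq_loc = [0, 2, 4, 7], seq_len_per_batch = [5, 9, 4]
#   counts = [2, 2, 3], per-batch KV starts = [0, 5, 14]
#   start_tensor = [0, 0, 5, 5, 14, 14, 14]
#   end_location = [4, 5, 13, 14, 16, 17, 18]
# i.e. the last 2 tokens of batch 0 attend to kv[0:4] and kv[0:5], the last 2
# tokens of batch 1 to kv[5:13] and kv[5:14], and the last 3 tokens of batch 2
# to kv[14:16], kv[14:17] and kv[14:18].
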
def ref_fp8_mqa_logits(
    q: torch.Tensor,             # [M, H, D] packed query tokens
    kv: torch.Tensor,            # [N, D] packed (gathered) keys
    weights: torch.Tensor,       # [M, H] per-head index weights
    cu_seqlen_ks: torch.Tensor,  # [M] inclusive start of each token's KV span
    cu_seqlen_ke: torch.Tensor,  # [M] exclusive end of each token's KV span
):
    k = kv
    q = q.float()
    k = k.float()

    # Key j is visible to query i iff cu_seqlen_ks[i] <= j < cu_seqlen_ke[i].
    seq_len_kv = kv.shape[0]
    mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
               >= cu_seqlen_ks[:, None])
    mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
               < cu_seqlen_ke[:, None])
    mask = mask_lo & mask_hi

    # Per-head scores, ReLU, then a weighted sum over heads -> [M, N] logits.
    score = torch.einsum("mhd,nd->hmn", q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    cost = mask.sum()
    return logits, cost


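# Paged path: pack the per-batch tokens, scatter KV into a paged cache with
# concat_and_cache_mla, gather it back through the slot mapping, and score
# each token against its own KV span.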
def torch_indexer(seq_len, q, kv, weights, block_size, topk):
    NUM_BLOCKS = 8
    BLOCK_SIZE = 32

    B = seq_len.shape[0]
    concat_q = []
    concat_kv = []
    concat_weights = []
    total_slots = NUM_BLOCKS * BLOCK_SIZE
    head_dim = kv.shape[-1]
    max_num_block_per_batch = torch.max(seq_len)
    block_table = torch.empty((B, max_num_block_per_batch),
                              dtype=torch.int32,
                              device="cuda")

    # Pack the valid tokens of every batch into flat [sum(seq_len), ...] tensors.
    for i in range(B):
        S = seq_len[i]
        q_s = q[i][:S].contiguous()
        kv_s = kv[i][:S].contiguous()
        weight_s = weights[i][:S].contiguous()
        concat_q.append(q_s)
        concat_kv.append(kv_s)
        concat_weights.append(weight_s)

    q = torch.cat(concat_q, dim=0)
    kv = torch.cat(concat_kv, dim=0)
    weights = torch.cat(concat_weights, dim=0)

    # Write to the KV cache based on a random slot mapping. Each cache entry
    # holds 2 * head_dim elements (the two inputs concat_and_cache_mla
    # concatenates per token); only the first head_dim is read back below.
    entry_size = head_dim * 2
    num_tokens = q.size(0)
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device="cuda")
    kv_cache = torch.zeros(
        NUM_BLOCKS,
        BLOCK_SIZE,
        entry_size,
        dtype=torch.bfloat16,
        device="cuda"
    )
    scale = torch.tensor(1, dtype=torch.float32, device="cuda")
    ops.concat_and_cache_mla(
        kv,
        kv.clone(),
        kv_cache,
        slot_mapping,
        "auto",
        scale
    )

    # Record each token's slot in a per-batch table.
    current_index = 0
    for i in range(B):
        S = seq_len[i]
        block_table[i][:S] = slot_mapping[current_index: current_index + S]
        current_index += S

    weights = weights * (128**(-0.5))
    query_start_loc = torch.empty((B + 1), dtype=torch.int32, device="cuda")
    query_start_loc[0] = 0
    query_start_loc[1:] = seq_len.cumsum(dim=0).to(dtype=torch.int32)

    # Gather KV back from the paged cache and check the round trip.
    kv_gathered = kv_cache.view(-1, entry_size)[slot_mapping][..., :head_dim]
    torch.testing.assert_close(kv, kv_gathered)

    cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_len)

    logits, _ = ref_fp8_mqa_logits(
        q,
        kv_gathered,
        weights,
        cu_seqlen_ks,
        cu_seqlen_ke
    )
    # Top-k over all KV positions, then invalidate indices that fall outside
    # each token's own [ks, ke) span.
    topk_indices = logits.topk(topk, dim=-1)[1]
    mask_lo = topk_indices >= cu_seqlen_ks[:, None]
    mask_hi = topk_indices < cu_seqlen_ke[:, None]
    mask = mask_lo & mask_hi
    topk_indices = topk_indices.masked_fill(~mask, -1)
    return logits


def test_paged_indexer_python():
    B = 2
    S = 128
    SKV = S
    H = 64
    HKV = 1
    D = 128
    block_size = 128
    topk = 64
    device = "cuda"
    seq_len = torch.randint(low=64, high=S, size=(B,))

    q = torch.randn(B, S, H, D, device="cuda",
                    dtype=torch.bfloat16)
    kv = torch.randn(B, SKV, D, device="cuda",
                     dtype=torch.bfloat16)
    weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5

    ref_indices = ref_indexer(seq_len, q, kv, weights, block_size, topk)
    torch_indices = torch_indexer(seq_len, q, kv, weights, block_size, topk)
    print(ref_indices)


if __name__ == "__main__":
    test_paged_indexer_python()
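
Inside `torch_indexer`, the masked `topk_indices` are positions into the gathered KV layout, with entries outside a token's `[ks, ke)` span set to -1. A minimal sketch of how downstream code might map the valid entries to positions local to each token's own sequence (illustrative only; `topk_indices` and `cu_seqlen_ks` as computed in the function above):

    # Illustrative only: convert global KV positions to batch-local positions,
    # leaving masked (-1) entries unchanged.
    local_topk = torch.where(topk_indices >= 0,
                             topk_indices - cu_seqlen_ks[:, None],
                             topk_indices)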

vllm/model_executor/layers/mla.py

Lines changed: 6 additions & 0 deletions
@@ -108,6 +108,7 @@ def __init__(
             qk_head_dim=self.qk_head_dim,
             v_head_dim=self.v_head_dim,
             kv_b_proj=self.kv_b_proj,
+            indexer=self.indexer,
         )

         self.prefix = prefix
@@ -153,6 +154,11 @@ def forward_native(

         q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
             positions, q[..., self.qk_nope_head_dim:], k_pe)
+
+        if self.indexer:
+            topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
+            # if topk_indices is not None:
+            #     print(topk_indices)

         if self.use_sparse:
             topk_indices = torch.zeros(q.shape[0], self.topk_tokens)
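
For orientation, the hook added above implies an indexer module invoked as `indexer(hidden_states, q_c, positions, rotary_emb)` that returns per-token top-k indices. A hypothetical minimal stand-in with that call shape (names and the placeholder return are assumptions for illustration, not the project's API):

    import torch

    # Hypothetical stand-in for `self.indexer`, matching only the call shape
    # used in forward_native above; the returned indices are placeholder
    # logic, not a real top-k selection.
    class _DummyIndexer(torch.nn.Module):
        def __init__(self, topk_tokens: int):
            super().__init__()
            self.topk_tokens = topk_tokens

        def forward(self, hidden_states, q_c, positions, rotary_emb):
            num_tokens = hidden_states.shape[0]
            # One row of candidate token indices per query token.
            return torch.arange(self.topk_tokens,
                                device=hidden_states.device).expand(num_tokens, -1)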

0 commit comments
