import random

import torch

from vllm import _custom_ops as ops
from vllm.utils.tile_lang_kernels import act_quant, fp8_index

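# Overview: `ref_indexer` computes per-batch index logits with the tile-lang
# `fp8_index` kernel, while `torch_indexer` scatters KV into a paged cache
# via `ops.concat_and_cache_mla`, gathers it back through the slot mapping,
# and recomputes the logits with the pure-PyTorch `ref_fp8_mqa_logits`.
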
def ref_compute_logits_fp8(q, kv, weights, mask, block_size):
    # Quantize activations to FP8 with per-block power-of-two (ue8m0) scales.
    q_fp8, q_scale = act_quant(q, block_size, "ue8m0")
    k_fp8, k_scale = act_quant(kv, block_size, "ue8m0")

    # Fold the per-query quantization scale and the 1/sqrt(128) head-dim
    # normalization into the per-head weights.
    weights = weights.unsqueeze(-1) * q_scale
    weights = weights * (128**(-0.5))
    index_score = fp8_index(q_fp8.contiguous(), weights, k_fp8.contiguous(),
                            k_scale.contiguous())
    if mask is not None:
        index_score += mask
    return index_score

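# A minimal sketch of what a per-block "ue8m0" activation quantizer could
# look like, for readers without the tile-lang kernels. This is an
# illustrative assumption only: the real `act_quant` in
# vllm.utils.tile_lang_kernels may use different fp8 limits or scale
# rounding. Requires a PyTorch build with torch.float8_e4m3fn.
def act_quant_sketch(x: torch.Tensor, block_size: int):
    # Split the last dim into blocks and take per-block absolute maxima.
    blocks = x.float().reshape(*x.shape[:-1], -1, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    # "ue8m0" scales are unsigned powers of two (8 exponent bits, no
    # mantissa), so round each scale up to the next power of two.
    scale = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))  # 448 = e4m3 max
    x_fp8 = (blocks / scale).to(torch.float8_e4m3fn)
    return x_fp8.reshape_as(x), scale.squeeze(-1)
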
def ref_indexer(seq_len, q, kv, weights, block_size, topk):
    B = seq_len.shape[0]
    varlen_logits = []

    for i in range(B):
        S = int(seq_len[i])
        q_s = q[i][:S].contiguous().unsqueeze(0)
        kv_s = kv[i][:S].contiguous().unsqueeze(0)
        weights_s = weights[i][:S].contiguous().unsqueeze(0)
        # Causal mask: query m may only attend to keys at positions <= m.
        mask = torch.full((S, S), float("-inf"), device="cuda").triu_(1)
        logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask,
                                        block_size)
        varlen_logits.append(logits)
        # topk_indices = logits.topk(topk, dim=-1)[1]
    return varlen_logits

def kv_spans_from_batches(start_seq_loc: torch.Tensor,
                          seq_len_per_batch: torch.Tensor):
    """
    Args:
      start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch.
        Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total.
      seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch.
        Example: [5, 9, 4].

    Returns:
      start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch.
      end_location: 1D long tensor [N], **exclusive** end = start + token's local position.
        (So the attended KV slice is kv[start:end].)

    Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the KV cache, and
    the selected tokens within a batch are the **last** `counts[i]` positions of that sequence.
    """
    q = start_seq_loc.to(dtype=torch.long)
    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
    assert q.dim() == 1 and L.dim() == 1
    assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch and totals
    counts = q[1:] - q[:-1]  # [B]
    N = int(q[-1].item())  # total selected tokens
    B = L.numel()
    device = L.device

    if N == 0:
        return (torch.empty(0, dtype=torch.long, device=device),
                torch.empty(0, dtype=torch.long, device=device))

    # KV start offsets per batch in the concatenated KV cache
    kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

    # For each selected token, which batch does it belong to?
    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
                                       counts)  # [N]

    # Map batch KV start to each token
    start_tensor = kv_starts_per_batch[batch_id]  # [N]

    # End-align local positions inside each batch:
    #   local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
    L_expand = torch.repeat_interleave(L, counts)  # [N]
    m_expand = torch.repeat_interleave(counts, counts)  # [N]
    # position within the selected block: 1..counts[b]
    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
                  torch.repeat_interleave(q[:-1], counts) + 1)

    local_pos = L_expand - m_expand + pos_within  # [N], 1-based
    end_location = start_tensor + local_pos  # exclusive end

    return start_tensor, end_location

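# Illustrative sanity check (not part of the original test): reproduces the
# example from the kv_spans_from_batches docstring on CPU. Can be run
# manually to see the span layout.
def _demo_kv_spans():
    starts, ends = kv_spans_from_batches(torch.tensor([0, 2, 4, 7]),
                                         torch.tensor([5, 9, 4]))
    # Batch KV starts are [0, 5, 14]; the selected tokens are end-aligned,
    # so the two tokens of batch 0 attend kv[0:4] and kv[0:5], etc.
    assert starts.tolist() == [0, 0, 5, 5, 14, 14, 14]
    assert ends.tolist() == [4, 5, 13, 14, 16, 17, 18]
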
def ref_fp8_mqa_logits(
    q: torch.Tensor,
    kv: torch.Tensor,
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
):
    k = kv
    q = q.float()
    k = k.float()

    # Each query token m may attend to KV positions in
    # [cu_seqlen_ks[m], cu_seqlen_ke[m]); everything else is masked to -inf.
    seq_len_kv = kv.shape[0]
    mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
               >= cu_seqlen_ks[:, None])
    mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
               < cu_seqlen_ke[:, None])
    mask = mask_lo & mask_hi

    # [m, h, d] x [n, d] -> [h, m, n]; ReLU, then a weighted sum over heads.
    score = torch.einsum("mhd,nd->hmn", q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    cost = mask.sum()
    return logits, cost

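# The einsum above can be hard to read; the loop below is an equivalent
# (much slower) per-head formulation, included purely for illustration:
def _mqa_logits_loop(q: torch.Tensor, k: torch.Tensor,
                     weights: torch.Tensor) -> torch.Tensor:
    m, h, _ = q.shape
    n = k.shape[0]
    logits = torch.zeros(m, n, dtype=torch.float32, device=q.device)
    for head in range(h):
        score = q[:, head, :] @ k.T  # [m, n]
        logits += score.relu() * weights[:, head:head + 1]
    return logits
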
def torch_indexer(seq_len, q, kv, weights, block_size, topk):
    NUM_BLOCKS = 8
    BLOCK_SIZE = 32

    B = seq_len.shape[0]
    concat_q = []
    concat_kv = []
    concat_weights = []
    # Cache capacity; must be >= the total number of tokens below.
    total_slots = NUM_BLOCKS * BLOCK_SIZE
    head_dim = kv.shape[-1]
    # One slot per token, so this "block table" maps token positions
    # directly to cache slots (not consumed further in this PyTorch path).
    max_seq_len = int(seq_len.max())
    block_table = torch.empty((B, max_seq_len),
                              dtype=torch.int32,
                              device="cuda")

    # Flatten the padded [B, S, ...] inputs into varlen layout.
    for i in range(B):
        S = int(seq_len[i])
        concat_q.append(q[i][:S].contiguous())
        concat_kv.append(kv[i][:S].contiguous())
        concat_weights.append(weights[i][:S].contiguous())

    q = torch.cat(concat_q, dim=0)
    kv = torch.cat(concat_kv, dim=0)
    weights = torch.cat(concat_weights, dim=0)

    # Write to the paged KV cache at random slots via the slot mapping.
    entry_size = head_dim * 2
    num_tokens = q.size(0)
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device="cuda")
    kv_cache = torch.zeros(NUM_BLOCKS,
                           BLOCK_SIZE,
                           entry_size,
                           dtype=torch.bfloat16,
                           device="cuda")
    scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    ops.concat_and_cache_mla(kv, kv.clone(), kv_cache, slot_mapping, "auto",
                             scale)

    current_index = 0
    for i in range(B):
        S = int(seq_len[i])
        block_table[i][:S] = slot_mapping[current_index:current_index + S]
        current_index += S

    # Match the 1/sqrt(128) normalization applied in the fp8 reference path.
    weights = weights * (128**(-0.5))
    query_start_loc = torch.empty(B + 1, dtype=torch.long, device="cuda")
    query_start_loc[0] = 0
    query_start_loc[1:] = seq_len.cumsum(dim=0)

    # Gather KV back out of the cache and check the round trip.
    kv_gathered = kv_cache.view(-1, entry_size)[slot_mapping][..., :head_dim]
    torch.testing.assert_close(kv, kv_gathered)

    cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc,
                                                       seq_len)

    logits, _ = ref_fp8_mqa_logits(q, kv_gathered, weights, cu_seqlen_ks,
                                   cu_seqlen_ke)
    # Top-k indices are global KV positions; anything outside the token's
    # valid span [ks, ke) is replaced with -1.
    topk_indices = logits.topk(topk, dim=-1)[1]
    mask_lo = topk_indices >= cu_seqlen_ks[:, None]
    mask_hi = topk_indices < cu_seqlen_ke[:, None]
    mask = mask_lo & mask_hi
    topk_indices = topk_indices.masked_fill(~mask, -1)
    return logits

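# Hedged helper (not used above): torch_indexer's top-k indices are global
# positions into the concatenated KV cache. To recover batch-local key
# positions, subtract each token's span start; masked -1 entries stay -1.
def _topk_global_to_local(topk_indices: torch.Tensor,
                          cu_seqlen_ks: torch.Tensor) -> torch.Tensor:
    local = topk_indices - cu_seqlen_ks[:, None]
    return torch.where(topk_indices >= 0, local, topk_indices)
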
def test_paged_indexer_python():
    B = 2
    S = 128
    SKV = S
    H = 64
    HKV = 1
    D = 128
    block_size = 128
    topk = 64
    device = "cuda"
    seq_len = torch.randint(low=64, high=S, size=(B, ))

    q = torch.randn(B, S, H, D, device=device, dtype=torch.bfloat16)
    kv = torch.randn(B, SKV, D, device=device, dtype=torch.bfloat16)
    weights = torch.randn(B, S, H, device=device,
                          dtype=torch.float32) * H**-0.5

    ref_logits = ref_indexer(seq_len, q, kv, weights, block_size, topk)
    torch_logits = torch_indexer(seq_len, q, kv, weights, block_size, topk)
    print(ref_logits)
    print(torch_logits)


if __name__ == "__main__":
    test_paged_indexer_python()