From 12f85fb15d558ee782fd75e0a57be39ce683f76d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 19 Sep 2025 23:35:07 +0000 Subject: [PATCH 01/82] init dev branch Signed-off-by: Chen Zhang --- vllm/model_executor/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 636554bd648f..bbccbeecd992 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -66,6 +66,7 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +WITH_V32 = True class DeepseekV2MLP(nn.Module): From 14868306a4da6190275c16bde4bad713056794d6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 20 Sep 2025 22:09:10 +0800 Subject: [PATCH 02/82] add indexer module Signed-off-by: youkaichao --- vllm/model_executor/layers/layernorm.py | 18 +++++++++++++ vllm/model_executor/layers/mla.py | 1 + vllm/model_executor/models/deepseek_v2.py | 33 ++++++++++++++++++++++- 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index f875f712ba9c..a44ca5c8939e 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp @@ -379,3 +380,20 @@ def forward_cuda( x: torch.Tensor, ) -> torch.Tensor: return poly_norm(x, self.weight, self.bias, self.variance_epsilon) + + +class LayerNorm(nn.Module): + """ + Layer Normalization. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias, + self.eps).type_as(x) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index a05716190365..c2a48acc2d3c 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -24,6 +24,7 @@ class MLAModules: q_a_layernorm: Optional[torch.nn.Module] q_b_proj: Optional[torch.nn.Module] q_proj: Optional[torch.nn.Module] + indexer: Optional[torch.nn.Module] @CustomOp.register("multi_head_latent_attention") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index bbccbeecd992..1d73d06cf474 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -42,7 +42,7 @@ tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm, LayerNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, ReplicatedLinear, @@ -475,6 +475,29 @@ def forward( return output +class Indexer(nn.Module): + def __init__(self, config: Union[DeepseekV2Config, DeepseekV3Config], + parent_attn_mod: nn.Module, + prefix: str = ""): + super().__init__() + self.config = config + self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] + self.topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] + self.n_heads = self.indexer_cfg["n_heads"] # 64 + 
self.head_dim = self.indexer_cfg["head_dim"] # 128 + self.rope_dim = self.indexer_cfg["rope_dim"] # 64 + self.q_lora_rank = parent_attn_mod.q_lora_rank # 1536 + # no tensor parallel, just replicated + self.wq_b = ReplicatedLinear(self.q_lora_rank, self.head_dim * self.n_heads, prefix=f"{prefix}.wq_b") + self.wk = ReplicatedLinear(parent_attn_mod.hidden_size, self.head_dim, prefix=f"{prefix}.wk") + self.k_norm = LayerNorm(self.head_dim, eps=1e-6) + self.weights_proj = ReplicatedLinear(parent_attn_mod.hidden_size, self.n_heads, prefix=f"{prefix}.weights_proj") + self.softmax_scale = self.head_dim ** -0.5 + + def forward(self, hidden_states: torch.Tensor, q: torch.Tensor) -> torch.Tensor: + return torch.empty_like(hidden_states)[:, :self.topk_tokens].contiguous() + + class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation @@ -579,6 +602,12 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale + if hasattr(config, "attn_module_list_cfg") and "attn_index" in config.attn_module_list_cfg[0]: + # DSv3.2 + self.indexer = Indexer(config, self, f"{prefix}.indexer") + else: + self.indexer = None + mla_modules = MLAModules( kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, @@ -592,7 +621,9 @@ def __init__( if self.q_lora_rank is not None else None, q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else None, + indexer=self.indexer, ) + self.mla_attn = MultiHeadLatentAttention( self.hidden_size, self.num_local_heads, From ee3271edfd58cfddaede9ef573fe92c31de5860a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 20 Sep 2025 22:25:35 +0800 Subject: [PATCH 03/82] fix fp8 weight loading Signed-off-by: youkaichao --- vllm/model_executor/models/deepseek_v2.py | 47 +++++++++++++++-------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1d73d06cf474..25507f7c7c42 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -42,7 +42,7 @@ tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm, LayerNorm +from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, ReplicatedLinear, @@ -476,26 +476,41 @@ def forward( class Indexer(nn.Module): - def __init__(self, config: Union[DeepseekV2Config, DeepseekV3Config], - parent_attn_mod: nn.Module, + + def __init__(self, + config: Union[DeepseekV2Config, DeepseekV3Config], + hidden_size: int, + q_lora_rank: int, + quant_config: Optional[QuantizationConfig], prefix: str = ""): super().__init__() self.config = config self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] self.topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] - self.n_heads = self.indexer_cfg["n_heads"] # 64 - self.head_dim = self.indexer_cfg["head_dim"] # 128 - self.rope_dim = self.indexer_cfg["rope_dim"] # 64 - self.q_lora_rank = parent_attn_mod.q_lora_rank # 1536 + self.n_heads = self.indexer_cfg["n_heads"] # 64 + self.head_dim = self.indexer_cfg["head_dim"] # 128 + self.rope_dim = self.indexer_cfg["rope_dim"] # 64 + self.q_lora_rank = q_lora_rank # 1536 # no tensor parallel, 
just replicated - self.wq_b = ReplicatedLinear(self.q_lora_rank, self.head_dim * self.n_heads, prefix=f"{prefix}.wq_b") - self.wk = ReplicatedLinear(parent_attn_mod.hidden_size, self.head_dim, prefix=f"{prefix}.wk") + self.wq_b = ReplicatedLinear(self.q_lora_rank, + self.head_dim * self.n_heads, + quant_config=quant_config, + prefix=f"{prefix}.wq_b") + self.wk = ReplicatedLinear(hidden_size, + self.head_dim, + quant_config=quant_config, + prefix=f"{prefix}.wk") self.k_norm = LayerNorm(self.head_dim, eps=1e-6) - self.weights_proj = ReplicatedLinear(parent_attn_mod.hidden_size, self.n_heads, prefix=f"{prefix}.weights_proj") - self.softmax_scale = self.head_dim ** -0.5 + self.weights_proj = ReplicatedLinear(hidden_size, + self.n_heads, + quant_config=quant_config, + prefix=f"{prefix}.weights_proj") + self.softmax_scale = self.head_dim**-0.5 - def forward(self, hidden_states: torch.Tensor, q: torch.Tensor) -> torch.Tensor: - return torch.empty_like(hidden_states)[:, :self.topk_tokens].contiguous() + def forward(self, hidden_states: torch.Tensor, + q: torch.Tensor) -> torch.Tensor: + return torch.empty_like( + hidden_states)[:, :self.topk_tokens].contiguous() class DeepseekV2MLAAttention(nn.Module): @@ -602,9 +617,11 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - if hasattr(config, "attn_module_list_cfg") and "attn_index" in config.attn_module_list_cfg[0]: + if hasattr(config, "attn_module_list_cfg" + ) and "attn_index" in config.attn_module_list_cfg[0]: # DSv3.2 - self.indexer = Indexer(config, self, f"{prefix}.indexer") + self.indexer = Indexer(config, hidden_size, q_lora_rank, + quant_config, f"{prefix}.indexer") else: self.indexer = None From 3f4154dfa9cc72881f3493b345897b5d4c6d0dd5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 20 Sep 2025 22:26:47 +0800 Subject: [PATCH 04/82] fix key Signed-off-by: youkaichao --- vllm/model_executor/models/deepseek_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 25507f7c7c42..cb5db90c16de 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -487,13 +487,13 @@ def __init__(self, self.config = config self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] self.topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] - self.n_heads = self.indexer_cfg["n_heads"] # 64 + self.n_head = self.indexer_cfg["n_head"] # 64 self.head_dim = self.indexer_cfg["head_dim"] # 128 self.rope_dim = self.indexer_cfg["rope_dim"] # 64 self.q_lora_rank = q_lora_rank # 1536 # no tensor parallel, just replicated self.wq_b = ReplicatedLinear(self.q_lora_rank, - self.head_dim * self.n_heads, + self.head_dim * self.n_head, quant_config=quant_config, prefix=f"{prefix}.wq_b") self.wk = ReplicatedLinear(hidden_size, @@ -502,7 +502,7 @@ def __init__(self, prefix=f"{prefix}.wk") self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.weights_proj = ReplicatedLinear(hidden_size, - self.n_heads, + self.n_head, quant_config=quant_config, prefix=f"{prefix}.weights_proj") self.softmax_scale = self.head_dim**-0.5 From 991b94f807124f6988942e3db847f7c2f0843394 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 20 Sep 2025 22:43:44 +0800 Subject: [PATCH 05/82] basic test Signed-off-by: youkaichao --- examples/offline_inference/basic/basic.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git 
a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 78bfda9bcf4e..75c364321c98 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -5,18 +5,15 @@ # Sample prompts. prompts = [ - "Hello, my name is", - "The president of the United States is", "The capital of France is", - "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=1024) def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m") + llm = LLM(model="/home/vllm-dsv32/DeepSeek-V3.2-Preview-Fix", tensor_parallel_size=8) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. From 00c455c89e3e2b099f0623087ee1d53c7165b439 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 20 Sep 2025 12:04:27 -0700 Subject: [PATCH 06/82] add indexer cache (#12) Signed-off-by: Chen Zhang --- vllm/model_executor/models/deepseek_v2.py | 60 ++++++++++++++++++++++- vllm/v1/attention/backends/mla/indexer.py | 38 ++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 5 ++ 3 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 vllm/v1/attention/backends/mla/indexer.py diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cb5db90c16de..927adc13a605 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -32,15 +32,19 @@ from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +from vllm.attention.backends.abstract import AttentionBackend +from vllm.logger import init_logger +from vllm.config.compilation import CompilationConfig import vllm.envs as envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -60,14 +64,19 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import cdiv, direct_register_custom_op +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec + +logger = init_logger(__name__) WITH_V32 = True + class DeepseekV2MLP(nn.Module): def __init__( @@ -475,6 +484,40 @@ def forward( return output +class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): + + def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str, + cache_config: CacheConfig): + super().__init__() + self.kv_cache = [torch.tensor([])] + self.head_dim = head_dim + self.prefix 
= prefix + self.cache_config = cache_config + self.dtype = dtype + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + logger.info(f"register indexer cache {prefix}") + + def get_kv_cache_spec(self) -> KVCacheSpec: + return FullAttentionSpec( + block_size=self.cache_config.block_size, + num_kv_heads=1, + head_size=self.head_dim, + dtype=self.dtype, + use_mla=True # Only has one vector instead of K + V + ) + + def forward(self): + logger.info( + f"self.kv_cache {self.prefix} {self.kv_cache[0].shape} {self.kv_cache[0].dtype}" + ) + + def get_attn_backend(self) -> AttentionBackend: + return DeepseekV32IndexerBackend + + class Indexer(nn.Module): def __init__(self, @@ -482,6 +525,7 @@ def __init__(self, hidden_size: int, q_lora_rank: int, quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], prefix: str = ""): super().__init__() self.config = config @@ -507,6 +551,17 @@ def __init__(self, prefix=f"{prefix}.weights_proj") self.softmax_scale = self.head_dim**-0.5 + self.quant_block_size = 128 # TODO: get from config + self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim, + dtype=torch.float8_e4m3fn, + prefix=f"{prefix}.k_cache", + cache_config=cache_config) + self.k_scale_cache = DeepseekV32IndexerCache( + head_dim=self.head_dim // self.quant_block_size, + dtype=torch.float32, + prefix=f"{prefix}.k_scale_cache", + cache_config=cache_config) + def forward(self, hidden_states: torch.Tensor, q: torch.Tensor) -> torch.Tensor: return torch.empty_like( @@ -621,7 +676,8 @@ def __init__( ) and "attn_index" in config.attn_module_list_cfg[0]: # DSv3.2 self.indexer = Indexer(config, hidden_size, q_lora_rank, - quant_config, f"{prefix}.indexer") + quant_config, cache_config, + f"{prefix}.indexer") else: self.indexer = None diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py new file mode 100644 index 000000000000..db474cb92718 --- /dev/null +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from vllm.attention.backends.abstract import AttentionBackend +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) + + +class DeepseekV32IndexerBackend(AttentionBackend): + + @staticmethod + def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, + head_size: int) -> tuple[int, ...]: + assert num_kv_heads == 1 + return (num_blocks, block_size, head_size) + + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + return (0, 1, 2) + + @staticmethod + def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]: + return DeepseekV32IndexerMetadataBuilder + + +@dataclass +class DeepseekV32IndexerMetadata: + pass + + +class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> DeepseekV32IndexerMetadata: + return DeepseekV32IndexerMetadata() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d0946e8c5d7d..e6a492266b53 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -40,6 +40,7 @@ from vllm.model_executor.layers.mamba.abstract import 
MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.interfaces import (is_mixture_of_experts, supports_eagle3, supports_mrope, @@ -4028,6 +4029,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: self.speculative_config.num_speculative_tokens if self.speculative_config else 0), ) + ds_indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache) + for layer_name, ds_indexer_module in ds_indexer_layers.items(): + kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() return kv_cache_spec From ddaf933aec44640a277607b40cf213103a28e8c2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 21 Sep 2025 07:13:13 +0000 Subject: [PATCH 07/82] setup sparse attention backend Signed-off-by: Chen Zhang --- .../vllm_add_dummy_platform/dummy_platform.py | 2 +- vllm/attention/layer.py | 5 +- vllm/attention/selector.py | 5 +- vllm/model_executor/layers/mla.py | 2 + vllm/model_executor/models/deepseek_v2.py | 9 +- vllm/platforms/cpu.py | 5 +- vllm/platforms/cuda.py | 7 +- vllm/platforms/interface.py | 2 +- vllm/platforms/rocm.py | 5 +- vllm/platforms/tpu.py | 5 +- vllm/platforms/xpu.py | 5 +- vllm/v1/attention/backends/mla/common.py | 9 +- .../attention/backends/mla/flashmla_sparse.py | 150 ++++++++++++++++++ 13 files changed, 197 insertions(+), 14 deletions(-) create mode 100644 vllm/v1/attention/backends/mla/flashmla_sparse.py diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py index 8d0687b49bb4..30d721304b5c 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -26,5 +26,5 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def get_attn_backend_cls(self, backend_name, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, - has_sink): + has_sink, use_sparse): return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 544a72052442..2a89db21fb47 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -92,6 +92,7 @@ def __init__( logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, use_mla: bool = False, + use_sparse: bool = False, prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -154,6 +155,7 @@ def __init__( self._o_scale_float: Optional[float] = None self.use_mla = use_mla + self.use_sparse = use_sparse self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads @@ -187,7 +189,8 @@ def __init__( block_size, is_attention_free, use_mla=use_mla, - has_sink=self.has_sink) + has_sink=self.has_sink, + use_sparse=use_sparse) else: self.attn_backend = attn_backend diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 3a235ba6e0b4..e53674494a12 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -145,6 +145,7 @@ def get_attn_backend( is_attention_free: bool = False, use_mla: bool = False, has_sink: bool = False, + use_sparse: bool = False, ) -> type[AttentionBackend]: """Selects which attention 
backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -160,6 +161,7 @@ def get_attn_backend( use_v1=envs.VLLM_USE_V1, use_mla=use_mla, has_sink=has_sink, + use_sparse=use_sparse, ) @@ -173,6 +175,7 @@ def _cached_get_attn_backend( use_v1: bool = False, use_mla: bool = False, has_sink: bool = False, + use_sparse: bool = False, ) -> type[AttentionBackend]: # If there are no attention layers (e.g. we are running Mamba), # use the placeholder NO_ATTENTION @@ -204,7 +207,7 @@ def _cached_get_attn_backend( # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, - use_mla, has_sink) + use_mla, has_sink, use_sparse) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}") diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index c2a48acc2d3c..35680a969a65 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -25,6 +25,7 @@ class MLAModules: q_b_proj: Optional[torch.nn.Module] q_proj: Optional[torch.nn.Module] indexer: Optional[torch.nn.Module] + is_sparse: bool @CustomOp.register("multi_head_latent_attention") @@ -93,6 +94,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", use_mla=True, + use_sparse=mla_modules.is_sparse, # MLA Args q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 927adc13a605..a4ad9c288f12 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -672,9 +672,11 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - if hasattr(config, "attn_module_list_cfg" - ) and "attn_index" in config.attn_module_list_cfg[0]: - # DSv3.2 + self.is_v32 = hasattr( + config, "attn_module_list_cfg" + ) and "attn_index" in config.attn_module_list_cfg[0] + + if self.is_v32: self.indexer = Indexer(config, hidden_size, q_lora_rank, quant_config, cache_config, f"{prefix}.indexer") @@ -695,6 +697,7 @@ def __init__( q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else None, indexer=self.indexer, + is_sparse=self.is_v32, ) self.mla_attn = MultiHeadLatentAttention( diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 544e091491bf..78d2b0f37328 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -93,11 +93,14 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: + has_sink: bool, use_sparse: bool) -> str: if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: raise NotImplementedError("MLA is not supported on CPU.") + if use_sparse: + raise NotImplementedError( + "Sparse Attention is not supported on CPU.") logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 87d8f2b7481b..56ff8bff3ad0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -234,7 +234,7 @@ def 
get_vit_attn_backend(cls, head_size: int, @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, - has_sink) -> str: + has_sink, use_sparse) -> str: if use_mla: # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here @@ -242,6 +242,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla + if use_sparse: + logger.info_once("Using Sparse MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla.flashmla_sparse." + "FlashMLASparseBackend") + use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) and block_size == 128) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 53fc762dce54..0abc536d8a6f 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -200,7 +200,7 @@ def get_vit_attn_backend(cls, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: + has_sink: bool, use_sparse: bool) -> str: """Get the attention backend class of a device.""" return "" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4f540fe965e2..ba97df02e8d2 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -189,7 +189,10 @@ def get_vit_attn_backend(cls, head_size: int, @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, - has_sink) -> str: + has_sink, use_sparse) -> str: + if use_sparse: + raise NotImplementedError( + "Sparse Attention is not supported on ROCm.") if use_mla: from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( is_aiter_mla_enabled) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4e4db116abca..d846eebac136 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -49,7 +49,10 @@ class TpuPlatform(Platform): def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, - has_sink) -> str: + has_sink, use_sparse) -> str: + if use_sparse: + raise NotImplementedError( + "Sparse Attention is not supported on TPU.") if (selected_backend != _Backend.PALLAS and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 67ef058df10f..574576f3e9ed 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -36,7 +36,10 @@ class XPUPlatform(Platform): def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: + has_sink: bool, use_sparse) -> str: + if use_sparse: + raise NotImplementedError( + "Sparse Attention is not supported on XPU.") use_v1 = envs.VLLM_USE_V1 if not use_v1: raise ValueError("XPU backend only supports V1.") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 5b307810de93..33bfb4127be1 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -237,6 +237,11 @@ except ImportError: 
flashinfer_available = False +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.v1.attention.backends.mla.flashmla_sparse import MLASparsePrefillMetadata + def is_rocm_aiter_fp8bmm_enabled() -> bool: return current_platform.is_rocm() \ @@ -398,8 +403,8 @@ class MLACommonMetadata(Generic[D]): decode: Optional[D] = None prefill: Optional[Union[MLACommonPrefillMetadata, - FlashInferPrefillMetadata, - CudnnPrefillMetadata]] = None + FlashInferPrefillMetadata, CudnnPrefillMetadata, + "MLASparsePrefillMetadata"]] = None def __post_init__(self): if self.head_dim is not None: diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py new file mode 100644 index 000000000000..931a1c264ac7 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -0,0 +1,150 @@ +from vllm.attention.backends.abstract import AttentionMetadata, AttentionLayer +import torch +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import MLACommonBackend, MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder +from vllm.v1.attention.backends.utils import CommonAttentionMetadata, split_decodes_and_prefills +from dataclasses import dataclass +from vllm.config import VllmConfig +from vllm.v1.kv_cache_interface import AttentionSpec +from typing import Optional + +logger = init_logger(__name__) + + +class FlashMLASparseBackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA_SPARSE_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> type[AttentionMetadata]: + return FlashMLASparseMetadata + + @staticmethod + def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: + return FlashMLASparseMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLASparseImpl"]: + return FlashMLASparseImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + print("try running get_supported_dtypes") + # TODO: verify this + return [torch.float16, torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # TODO: verify this + return [576] + + +class MLASparsePrefillMetadata: + # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because + # the kernel is not from flashmla + def __init__(self): + pass + + +class FlashMLASparseDecodeMetadata(MLACommonDecodeMetadata): + + def __init__(self): + pass + + +@dataclass +class FlashMLASparseMetadata(MLACommonMetadata[MLASparsePrefillMetadata]): + pass + + +@dataclass +class FlashMLASparseMetadataBuilder( + MLACommonMetadataBuilder[FlashMLASparseMetadata]): + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device, + FlashMLASparseMetadata) + + def _build_prefill( + self, common_attn_metadata: CommonAttentionMetadata + ) -> MLASparsePrefillMetadata: + return MLASparsePrefillMetadata() + + def _build_decode( + self, common_attn_metadata: CommonAttentionMetadata + ) -> FlashMLASparseDecodeMetadata: + return FlashMLASparseDecodeMetadata() + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashMLASparseMetadata: + logger.info(f"build 
FlashMLASparseMetadata") + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) + return FlashMLASparseMetadata( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + prefill=self._build_prefill(common_attn_metadata), + decode=self._build_decode(common_attn_metadata), + ) + + +@dataclass +class FlashMLASparseImpl(MLACommonImpl[FlashMLASparseMetadata]): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + k_scale: torch.Tensor, + ) -> torch.Tensor: + return torch.empty_like(q) + + def _forward_decode( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + return torch.empty_like(q) + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return output.fill_(0) From aff95965c36bb575cd42537fe933fade34f491ee Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 21 Sep 2025 05:00:09 +0000 Subject: [PATCH 08/82] build sparse Signed-off-by: Lucas Wilkinson fix smoke tests Signed-off-by: Lucas Wilkinson moved to FlashMLA repo Signed-off-by: Lucas Wilkinson removed pytorch shim Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 90 ++++++++++++- setup.py | 3 + .../kernels/attention/test_flashmla_sparse.py | 120 ++++++++++++++++++ vllm/attention/ops/flashmla.py | 111 ++++++++++++++++ 4 files changed, 322 insertions(+), 2 deletions(-) create mode 100644 tests/kernels/attention/test_flashmla_sparse.py diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 02224cfe3ee8..abde4118bcc0 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR) else() FetchContent_Declare( flashmla - GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de + GIT_REPOSITORY https://github.com/vllm-model-0920/FlashMLA + GIT_TAG a25b977fae6925c45c3d0404c98c6ce6f4563dac GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -35,6 +35,10 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") # sm90a cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) + ####################################################################### + # FlashMLA Dense -- _flashmla_C + ####################################################################### + set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp 
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu @@ -60,8 +64,90 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} USE_SABI 3 WITH_SOABI) + + ####################################################################### + # FlashMLA Sparse -- _flashmla_sparse_C + ####################################################################### + + # We use seperate libraries to avoid crosss contaminating includes, + # namely kernels/utils.h + + set(DECODE_FOLDER ${flashmla_SOURCE_DIR}/csrc/sparse/decode) + set(PREFILL_FOLDER ${flashmla_SOURCE_DIR}/csrc/sparse/prefill) + + # ---- Decode object library ---- + set(SPARSE_FLASHMLA_DECODE_SOURCES + ${DECODE_FOLDER}/flash_api.cpp + ${DECODE_FOLDER}/kernels/get_mla_metadata.cu + ${DECODE_FOLDER}/kernels/mla_combine.cu + ${DECODE_FOLDER}/kernels/fp8_sparse/splitkv_mla.cu + ) + + add_library(_flashmla_sparse_decode OBJECT ${SPARSE_FLASHMLA_DECODE_SOURCES}) + set_property(TARGET _flashmla_sparse_decode PROPERTY POSITION_INDEPENDENT_CODE ON) + + set_gencode_flags_for_srcs( + SRCS "${SPARSE_FLASHMLA_DECODE_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}" + ) + + # Include paths for decode ONLY (do not leak DECODE_FOLDER to others) + target_include_directories(_flashmla_sparse_decode + PRIVATE + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${TORCH_INCLUDE_DIRS} + ${Python_INCLUDE_DIRS} + ${DECODE_FOLDER} + ) + target_compile_options(_flashmla_sparse_decode PRIVATE + $<$:${VLLM_GPU_FLAGS}>) + + # ---- Prefill object library ---- + set(SPARSE_FLASHMLA_PREFILL_SOURCES + ${PREFILL_FOLDER}/api.cpp + ${PREFILL_FOLDER}/kernels/sm90/fwd/fwd.cu + ) + + add_library(_flashmla_sparse_prefill OBJECT ${SPARSE_FLASHMLA_PREFILL_SOURCES}) + set_property(TARGET _flashmla_sparse_prefill PROPERTY POSITION_INDEPENDENT_CODE ON) + + set_gencode_flags_for_srcs( + SRCS "${SPARSE_FLASHMLA_PREFILL_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}" + ) + + target_include_directories(_flashmla_sparse_prefill + PRIVATE + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${TORCH_INCLUDE_DIRS} + ${Python_INCLUDE_DIRS} + ${PREFILL_FOLDER} + ) + target_compile_options(_flashmla_sparse_prefill PRIVATE + $<$:${VLLM_GPU_FLAGS}>) + + # ---- Final extension target with unified API ---- + define_gpu_extension_target( + _flashmla_sparse_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES + ${flashmla_SOURCE_DIR}/csrc/sparse/api.cpp + $ + $ + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + # Only the common/public includes here; do NOT add decode/prefill folders + INCLUDE_DIRECTORIES + csrc/ + ${CUTLASS_INCLUDE_DIR} + ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 + WITH_SOABI + ) else() # Create an empty target for setup.py when not targeting sm90a systems add_custom_target(_flashmla_C) + add_custom_target(_flashmla_sparse_C) endif() diff --git a/setup.py b/setup.py index e4c40d22b928..ca8fd08a57fb 100644 --- a/setup.py +++ b/setup.py @@ -322,6 +322,7 @@ def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict: "vllm/_C.abi3.so", "vllm/_moe_C.abi3.so", "vllm/_flashmla_C.abi3.so", + "vllm/_sparse_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", "vllm/cumem_allocator.abi3.so", @@ -589,6 +590,8 @@ def _read_requirements(filename: str) -> list[str]: # not targeting a hopper system ext_modules.append( CMakeExtension(name="vllm._flashmla_C", optional=True)) + ext_modules.append( + CMakeExtension(name="vllm._flashmla_sparse_C", optional=True)) 
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): diff --git a/tests/kernels/attention/test_flashmla_sparse.py b/tests/kernels/attention/test_flashmla_sparse.py new file mode 100644 index 000000000000..6488e0c01e0c --- /dev/null +++ b/tests/kernels/attention/test_flashmla_sparse.py @@ -0,0 +1,120 @@ +import pytest +import torch + + +def _cuda_sm90_available() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major == 9 + + +@pytest.mark.cuda +def test_sparse_flashmla_imports_and_flags(): + import vllm.attention.ops.flashmla as fm + # Functions should exist + assert hasattr(fm, "get_sparse_mla_metadata") + assert hasattr(fm, "flash_mla_sparse_with_kvcache") + assert hasattr(fm, "flash_mla_sparse_prefill") + # Support check should return a (bool, reason) + ok, reason = fm.is_flashmla_supported() + assert isinstance(ok, bool) + assert (reason is None) or isinstance(reason, str) + + +def test_sparse_flashmla_metadata_smoke(): + import vllm.attention.ops.flashmla as fm + ok, reason = fm.is_flashmla_supported() + if not ok or not _cuda_sm90_available(): + pytest.skip(reason or "SM90 not available") + + device = torch.device("cuda") + batch_size = 1 + seqlen_q = 1 + num_heads_q = 128 + num_heads_k = 1 + q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k + q_heads_per_hk = num_heads_q // num_heads_k + topk = 128 + + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) + + tile_md, num_splits = fm.get_sparse_mla_metadata(cache_seqlens, + q_seq_per_hk, + num_heads_k, + topk, + q_heads_per_hk) + assert tile_md.dtype == torch.int32 + assert num_splits.dtype == torch.int32 + + +def test_sparse_flashmla_decode_smoke(): + import vllm.attention.ops.flashmla as fm + ok, reason = fm.is_flashmla_supported() + if not ok or not _cuda_sm90_available(): + pytest.skip(reason or "SM90 not available") + + device = torch.device("cuda") + batch_size = 1 + seqlen_q = 1 + num_heads_q = 1 + head_dim_k = 576 + head_dim_v = 512 + num_heads_k = 1 + page_block_size = 64 + bytes_per_token = 656 + topk = 128 + + # Metadata + q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k + q_heads_per_hk = num_heads_q // num_heads_k + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) + tile_md, num_splits = fm.get_sparse_mla_metadata(cache_seqlens, + q_seq_per_hk, + num_heads_k, + topk, + q_heads_per_hk) + + # Inputs + q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k), + dtype=torch.bfloat16, + device=device) + k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token), + dtype=torch.uint8, + device=device) + indices = torch.zeros((batch_size, seqlen_q, topk), + dtype=torch.int32, + device=device) + + out, lse = fm.flash_mla_sparse_with_kvcache(q, k_cache, cache_seqlens, + head_dim_v, tile_md, + num_splits, indices) + assert out.shape[0] == batch_size + assert out.shape[-1] == head_dim_v + assert lse.shape[0] == batch_size + + +def test_sparse_flashmla_prefill_smoke(): + import vllm.attention.ops.flashmla as fm + ok, reason = fm.is_flashmla_supported() + if not ok or not _cuda_sm90_available(): + pytest.skip(reason or "SM90 not available") + + device = torch.device("cuda") + s_q = 1 + s_kv = 1 + h_q = 64 # kernel expects multiple of 64 + h_kv = 1 + d_qk = 576 + d_v = 512 + topk = 128 + + q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device) + kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device) + indices = 
torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device) + + out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v) + assert out.shape == (s_q, h_q, d_v) + assert max_logits.shape == (s_q, h_q) + assert lse.shape == (s_q, h_q) + diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 2c3e8c42400c..d6ef46bd2d4a 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -13,6 +13,7 @@ if current_platform.is_cuda(): try: import vllm._flashmla_C # noqa: F401 + import vllm._flashmla_sparse_C # noqa: F401 _flashmla_C_AVAILABLE = True except ImportError: _flashmla_C_AVAILABLE = False @@ -109,6 +110,116 @@ def flash_mla_with_kvcache( # Note(hc): need revisit when we support DCP with decode query_len > 1. return out.squeeze(1), softmax_lse.squeeze(-1) +# ------------------------ Sparse FlashMLA bindings ------------------------- + +def get_sparse_mla_metadata( + cache_seqlens: torch.Tensor, + q_seq_per_hk: int, + num_heads_k: int, + topk: int, + q_heads_per_hk: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + q_seq_per_hk: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + topk: topk + q_heads_per_hk: equals to num_heads_q // num_heads_k. Only need to be + specified when topk is not None. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._flashmla_sparse_C.get_mla_metadata(cache_seqlens, q_seq_per_hk, + num_heads_k, topk, q_heads_per_hk) + +def flash_mla_sparse_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + indices_in_kvcache: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + torch.int32, returned by get_sparse_mla_metadata. + num_splits: (batch_size + 1), torch.int32, returned by + get_sparse_mla_metadata. + indices_in_kvcache: (batch_size, seq_len_q, topk). KV indices when + sparse attention is enabled. Note that + indices_in_kvcache[i][j][k] = + (the index of the page block where token t resides) * + page_block_size + (the offset of token t within that page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + softmax_scale: float. Scaling of QK^T before softmax. + Defaults to 1 / sqrt(head_dim). + + Explanation of K/V cache layout: + We quantize the NoPE part of each token (in 1x128 granularity), + yielding 512 float8_e4m3 values and 4 float32 scale factors. For the + RoPE part, we keep it as 64 bfloat16. Each token occupies 656 bytes: + - First 512 bytes: quantized NoPE (512 x float8_e4m3) + - Next 16 bytes: scale factors (4 x float32) + - Last 128 bytes: RoPE (64 x bfloat16) + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 
+ """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + # Strict shape checks like the reference implementation + assert k_cache.shape[-1] == 656, ( + "The last dim of k_cache must be 656 (=512+2*16+4*4) when " + "is_fp8_kvcache is True") + assert k_cache.shape[-2] == 1, ( + "The number of K heads must be 1 when is_fp8_kvcache is True") + + out, softmax_lse = torch.ops._flashmla_sparse_C.fwd_kvcache_mla( + q, k_cache, head_dim_v, cache_seqlens, softmax_scale, + tile_scheduler_metadata, num_splits, indices_in_kvcache) + return out, softmax_lse + + +def flash_mla_sparse_prefill( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention forward operator, for prefill + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1, + or to a number >= s_kv + sm_scale: float, scaling factor for the attention scores + d_v: dimension of the value, default (and only supported) is 512 + + Returns: + Returns (output, max_logits, lse) + For definitions of output, max_logits, and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16, the result of attention + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, base-2 + """ + results = torch.ops._flashmla_sparse_C.sparse_topk_attn_fwd( + q, kv, indices, sm_scale, d_v) + return results # # TODO: Add fake functions From 22d0fe52385107024e83d4f2ddb897da819e2dda Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 21 Sep 2025 22:25:59 +0000 Subject: [PATCH 09/82] pass in selected index Signed-off-by: Chen Zhang --- vllm/model_executor/layers/mla.py | 8 +++ vllm/v1/attention/backends/mla/common.py | 16 ++++- .../attention/backends/mla/flashmla_sparse.py | 63 ++++++++++++++----- 3 files changed, 68 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 35680a969a65..0259b63a47c8 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -78,6 +78,8 @@ def __init__( self.kv_b_proj = mla_modules.kv_b_proj self.rotary_emb = mla_modules.rotary_emb self.o_proj = mla_modules.o_proj + self.indexer = mla_modules.indexer + self.topk_tokens = mla_modules.indexer.topk_tokens # In the MLA backend, kv_cache includes both k_c and # pe (i.e. decoupled position embeddings). In particular, @@ -149,6 +151,12 @@ def forward_native( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim:], k_pe) + topk_indices = torch.zeros(q.shape[0], self.topk_tokens) + + # NOTE(Chen): a bit hacky, but need to modify Attention.forward + # otherwise. Try to refactor this later. 
+ self.mla_attn.impl.set_topk_indices(topk_indices) + attn_out = self.mla_attn( q, kv_c_normed, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 33bfb4127be1..ec618b59f06b 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -947,6 +947,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, + use_sparse: bool = False ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -1428,6 +1429,7 @@ def _forward_prefill( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, + topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> torch.Tensor: assert attn_metadata.prefill is not None assert self.dcp_world_size is not None @@ -1482,6 +1484,7 @@ def _forward_decode( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, layer: AttentionLayer, + topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: raise NotImplementedError @@ -1548,6 +1551,14 @@ def forward( kv_cache_dtype=self.kv_cache_dtype, scale=layer._k_scale, ) + + if hasattr(self, "topk_indices"): + topk_indices = self.topk_indices + decode_topk_indices = topk_indices[:num_decode_tokens] + prefill_topk_indices = topk_indices[num_decode_tokens:] + else: + decode_topk_indices = None + prefill_topk_indices = None if fp8_attention: kv_cache = kv_cache.view(current_platform.fp8_dtype()) @@ -1555,7 +1566,7 @@ def forward( if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) + attn_metadata, layer._k_scale, prefill_topk_indices) if has_decode: assert attn_metadata.decode is not None @@ -1601,7 +1612,8 @@ def forward( # call decode attn attn_out, lse = self._forward_decode(decode_q, kv_cache, - attn_metadata, layer) + attn_metadata, layer, + decode_topk_indices) # recorect dcp attn_out with lse. 
if self.dcp_world_size > 1: diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 931a1c264ac7..b6316f5725b5 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -115,36 +115,65 @@ def build(self, @dataclass class FlashMLASparseImpl(MLACommonImpl[FlashMLASparseMetadata]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + # self.sm_scale = + self.topk_indices = None + + + def set_topk_indices(self, topk_indices: torch.Tensor): + self.topk_indices = topk_indices def _forward_prefill( self, q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, k_scale: torch.Tensor, + topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> torch.Tensor: - return torch.empty_like(q) + assert topk_indices is not None + + # # assume indice of shape [num_prefill_tokens, topk] + # block_id_in_req = topk_indices // self.block_size + + logger.info(f"called _forward_prefill") + # NOTE(Chen): shape is unsure + + return torch.zeros((q.shape[0], 2048), dtype=q.dtype, device=q.device) def _forward_decode( self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, - ) -> torch.Tensor: - return torch.empty_like(q) - - def forward( - self, layer: AttentionLayer, - q: torch.Tensor, - k_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> torch.Tensor: - return output.fill_(0) + + assert topk_indices is not None + + # # assume indice of shape [num_decode_tokens, topk] + # block_id_in_req = topk_indices // self.block_size + + logger.info(f"called _forward_decode") + # NOTE(Chen): shape is unsure + return torch.zeros((q[0].shape[0], 16*512), dtype=q[0].dtype, device=q[0].device), torch.zeros((q[0].shape[0], 128), dtype=q[0].dtype, device=q[0].device) From 3b9df196d46719e910b830218ebb18219169ebbc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 21 Sep 2025 22:49:54 +0000 Subject: [PATCH 10/82] make basic.py runable Signed-off-by: Chen Zhang --- vllm/model_executor/layers/mla.py | 10 ++++++---- vllm/platforms/cuda.py | 1 + vllm/v1/attention/backends/mla/common.py | 3 ++- vllm/v1/attention/backends/mla/flashmla.py | 1 + 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 0259b63a47c8..79e4a0e8fe99 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -80,6 +80,7 @@ def __init__( self.o_proj = mla_modules.o_proj self.indexer = mla_modules.indexer self.topk_tokens = mla_modules.indexer.topk_tokens + self.use_sparse = mla_modules.is_sparse 
and False # In the MLA backend, kv_cache includes both k_c and # pe (i.e. decoupled position embeddings). In particular, @@ -151,11 +152,12 @@ def forward_native( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim:], k_pe) - topk_indices = torch.zeros(q.shape[0], self.topk_tokens) + if self.use_sparse: + topk_indices = torch.zeros(q.shape[0], self.topk_tokens) - # NOTE(Chen): a bit hacky, but need to modify Attention.forward - # otherwise. Try to refactor this later. - self.mla_attn.impl.set_topk_indices(topk_indices) + # NOTE(Chen): a bit hacky, but need to modify Attention.forward + # otherwise. Try to refactor this later. + self.mla_attn.impl.set_topk_indices(topk_indices) attn_out = self.mla_attn( q, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 56ff8bff3ad0..ffaac6c7c01b 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -236,6 +236,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink, use_sparse) -> str: if use_mla: + use_sparse = False # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ec618b59f06b..6f1cd7a6351e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1002,6 +1002,7 @@ def __init__( and current_platform.get_device_capability()[0] == 9) self.dcp_world_size: Optional[int] = None + self.use_sparse = use_sparse def _flash_attn_varlen_diff_headdims(self, q, @@ -1552,7 +1553,7 @@ def forward( scale=layer._k_scale, ) - if hasattr(self, "topk_indices"): + if self.use_sparse: topk_indices = self.topk_indices decode_topk_indices = topk_indices[:num_decode_tokens] prefill_topk_indices = topk_indices[num_decode_tokens:] diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 150e38553e4b..c6db469b055b 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -176,6 +176,7 @@ def _forward_decode( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, layer: AttentionLayer, + topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None From f85564f2eebb5c71a3cd531b4ed4231b11093b60 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 21 Sep 2025 23:11:46 +0000 Subject: [PATCH 11/82] small fix Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index b6316f5725b5..5d772b7ee707 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -132,7 +132,7 @@ def __init__( super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + kv_sharing_target_layer_name, use_sparse=True, **mla_args) # self.sm_scale = self.topk_indices = None From fe45b061c7ec4d500fad3945495aa036a3b3cde8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 22 Sep 2025 00:36:10 +0000 Subject: [PATCH 12/82] reduce api change Signed-off-by: Chen Zhang --- 
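Note on the approach taken here: instead of threading topk_indices through the shared
MLACommonImpl method signatures as PATCH 09 did, this patch restores the common API and
lets the sparse implementation carry the indices itself. MultiHeadLatentAttention calls
impl.set_topk_indices() right before invoking the attention op, and FlashMLASparseImpl
slices the stored tensor into its decode and prefill portions. Below is a minimal,
hypothetical sketch of that pattern; the class and method names are simplified stand-ins,
and the real methods take the full argument lists shown in the diff that follows.

    from typing import Optional

    import torch


    class SparseImplSketch:
        """Carries per-step top-k indices instead of widening the common MLA API."""

        def __init__(self) -> None:
            self.topk_indices: Optional[torch.Tensor] = None

        def set_topk_indices(self, topk_indices: torch.Tensor) -> None:
            # Called by the MLA wrapper before the attention op runs. Rows are
            # ordered [decode entries..., prefill entries...], matching the
            # batch reordering done by split_decodes_and_prefills().
            self.topk_indices = topk_indices

        def decode_indices(self, num_decodes: int) -> torch.Tensor:
            # Decode entries come first in the reordered batch.
            assert self.topk_indices is not None
            return self.topk_indices[:num_decodes]

        def prefill_indices(self, num_decodes: int) -> torch.Tensor:
            # Everything after the decode entries belongs to prefill.
            assert self.topk_indices is not None
            return self.topk_indices[num_decodes:]

Keeping the sparse-only state inside the sparse backend avoids touching FlashMLAImpl and
the other MLA backends, which is what lets PATCH 13 revert common.py to its original form.
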
vllm/v1/attention/backends/mla/common.py | 19 +++---------------- vllm/v1/attention/backends/mla/flashmla.py | 1 - .../attention/backends/mla/flashmla_sparse.py | 15 ++++++--------- 3 files changed, 9 insertions(+), 26 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 6f1cd7a6351e..7f8d6ccff3eb 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -946,8 +946,7 @@ def __init__( qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, - kv_b_proj: ColumnParallelLinear, - use_sparse: bool = False + kv_b_proj: ColumnParallelLinear ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -1002,7 +1001,6 @@ def __init__( and current_platform.get_device_capability()[0] == 9) self.dcp_world_size: Optional[int] = None - self.use_sparse = use_sparse def _flash_attn_varlen_diff_headdims(self, q, @@ -1430,7 +1428,6 @@ def _forward_prefill( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, - topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> torch.Tensor: assert attn_metadata.prefill is not None assert self.dcp_world_size is not None @@ -1485,7 +1482,6 @@ def _forward_decode( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, layer: AttentionLayer, - topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: raise NotImplementedError @@ -1552,14 +1548,6 @@ def forward( kv_cache_dtype=self.kv_cache_dtype, scale=layer._k_scale, ) - - if self.use_sparse: - topk_indices = self.topk_indices - decode_topk_indices = topk_indices[:num_decode_tokens] - prefill_topk_indices = topk_indices[num_decode_tokens:] - else: - decode_topk_indices = None - prefill_topk_indices = None if fp8_attention: kv_cache = kv_cache.view(current_platform.fp8_dtype()) @@ -1567,7 +1555,7 @@ def forward( if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale, prefill_topk_indices) + attn_metadata, layer._k_scale) if has_decode: assert attn_metadata.decode is not None @@ -1613,8 +1601,7 @@ def forward( # call decode attn attn_out, lse = self._forward_decode(decode_q, kv_cache, - attn_metadata, layer, - decode_topk_indices) + attn_metadata, layer) # recorect dcp attn_out with lse. 
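        # (decode context parallel: each rank attends over only its shard of
        #  the KV cache, so the per-rank partial attention outputs are merged
        #  using their log-sum-exp weights here)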
if self.dcp_world_size > 1: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index c6db469b055b..150e38553e4b 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -176,7 +176,6 @@ def _forward_decode( kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, layer: AttentionLayer, - topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 5d772b7ee707..c0f8f0f59143 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -132,7 +132,7 @@ def __init__( super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, logits_soft_cap, attn_type, - kv_sharing_target_layer_name, use_sparse=True, **mla_args) + kv_sharing_target_layer_name, **mla_args) # self.sm_scale = self.topk_indices = None @@ -147,15 +147,12 @@ def _forward_prefill( k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, - k_scale: torch.Tensor, - topk_indices: Optional[torch.Tensor] = None, # sparse attn + k_scale: torch.Tensor ) -> torch.Tensor: - assert topk_indices is not None - # # assume indice of shape [num_prefill_tokens, topk] # block_id_in_req = topk_indices // self.block_size - - logger.info(f"called _forward_prefill") + topk_indices = self.topk_indices[attn_metadata.num_decodes:] + logger.info(f"called _forward_prefill with topk_indices shape {topk_indices.shape}") # NOTE(Chen): shape is unsure return torch.zeros((q.shape[0], 2048), dtype=q.dtype, device=q.device) @@ -169,11 +166,11 @@ def _forward_decode( topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> torch.Tensor: - assert topk_indices is not None + topk_indices = self.topk_indices[:attn_metadata.num_decodes] # # assume indice of shape [num_decode_tokens, topk] # block_id_in_req = topk_indices // self.block_size - logger.info(f"called _forward_decode") + logger.info(f"called _forward_decode with topk_indices shape {topk_indices.shape}") # NOTE(Chen): shape is unsure return torch.zeros((q[0].shape[0], 16*512), dtype=q[0].dtype, device=q[0].device), torch.zeros((q[0].shape[0], 128), dtype=q[0].dtype, device=q[0].device) From 216c42fa024c3262167bf2224796ceff8768da76 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 22 Sep 2025 00:36:50 +0000 Subject: [PATCH 13/82] revert Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 7f8d6ccff3eb..33bfb4127be1 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -946,7 +946,7 @@ def __init__( qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, - kv_b_proj: ColumnParallelLinear + kv_b_proj: ColumnParallelLinear, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") From 840f20519fdf6ce2d3c0fdfcf565bf34264606c9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 22 Sep 2025 01:20:55 +0000 Subject: [PATCH 14/82] format Signed-off-by: Lucas Wilkinson --- vllm/attention/ops/flashmla.py | 9 ++++++--- 1 file 
changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index d6ef46bd2d4a..7c7e010e2af2 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -110,8 +110,10 @@ def flash_mla_with_kvcache( # Note(hc): need revisit when we support DCP with decode query_len > 1. return out.squeeze(1), softmax_lse.squeeze(-1) + # ------------------------ Sparse FlashMLA bindings ------------------------- + def get_sparse_mla_metadata( cache_seqlens: torch.Tensor, q_seq_per_hk: int, @@ -133,8 +135,9 @@ def get_sparse_mla_metadata( dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ - return torch.ops._flashmla_sparse_C.get_mla_metadata(cache_seqlens, q_seq_per_hk, - num_heads_k, topk, q_heads_per_hk) + return torch.ops._flashmla_sparse_C.get_mla_metadata( + cache_seqlens, q_seq_per_hk, num_heads_k, topk, q_heads_per_hk) + def flash_mla_sparse_with_kvcache( q: torch.Tensor, @@ -212,7 +215,6 @@ def flash_mla_sparse_prefill( Returns: Returns (output, max_logits, lse) - For definitions of output, max_logits, and lse, please refer to README.md - output: [s_q, h_q, d_v], bfloat16, the result of attention - max_logits: [s_q, h_q], float - lse: [s_q, h_q], float, base-2 @@ -221,6 +223,7 @@ def flash_mla_sparse_prefill( q, kv, indices, sm_scale, d_v) return results + # # TODO: Add fake functions # From 0f54ca63a611b1aadf84f028a305f0cf0b2a457d Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Mon, 22 Sep 2025 21:32:54 +0000 Subject: [PATCH 15/82] deepgemm integration --- .../attention/test_deepgemm_attention.py | 298 ++++++++++++++++++ vllm/utils/deep_gemm.py | 127 +++++++- 2 files changed, 423 insertions(+), 2 deletions(-) create mode 100644 tests/kernels/attention/test_deepgemm_attention.py diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py new file mode 100644 index 000000000000..d6c7c4368de9 --- /dev/null +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -0,0 +1,298 @@ +import random +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.utils import has_deep_gemm, cdiv +from vllm.utils.deep_gemm import ( + _ceil_to_ue8m0, + fp8_mqa_logits, + calc_diff, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, +) + + +def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: + # x: (num_blocks, block_size, 1, head_dim) + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + x_fp8 = torch.empty( + (num_blocks, block_size * (head_dim + 4)), + device=x.device, + dtype=torch.uint8, + ) + x_fp8[:, : block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim + ).view(dtype=torch.uint8) + x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view( + dtype=torch.uint8 + ) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + + +def per_custom_dims_cast_to_fp8( + x: torch.Tensor, dims: tuple, use_ue8m0: bool +) -> tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _generate_cp_test_data(seq_len: int, 
seq_len_kv: int): + assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0 + chunk_size = seq_len // 2 + cp_size = seq_len_kv // seq_len + cp_id = cp_size // 3 + ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") + ke = torch.zeros(seq_len, dtype=torch.int, device="cuda") + for i in range(chunk_size): + ke[i] = cp_id * chunk_size + i + ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i + return ks, ke + + +def _ref_fp8_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + seq_len_kv = kv.shape[0] + + k = kv + q = q.float() + k = k.float() + + mask_lo = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] + >= cu_seqlen_ks[:, None] + ) + mask_hi = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] + < cu_seqlen_ke[:, None] + ) + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + return logits + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +def test_deepgemm_fp8_mqa_logits(): + torch.manual_seed(0) + random.seed(0) + num_heads, head_dim = 32, 128 + for seq_len in (512,): + for seq_len_kv in (1024,): + for disable_cp in (False, True): + q = torch.randn( + seq_len, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + kv = torch.randn( + seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16 + ) + weights = torch.randn( + seq_len, num_heads, device="cuda", dtype=torch.float32 + ) + + if disable_cp: + ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") + ke = torch.arange( + seq_len, dtype=torch.int, device="cuda" + ) + (seq_len_kv - seq_len) + else: + ks, ke = _generate_cp_test_data(seq_len, seq_len_kv) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) + logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) + + ref_logits = _ref_fp8_mqa_logits( + q=q, + kv=kv, + weights=weights, + cu_seqlen_ks=ks, + cu_seqlen_ke=ke, + ) + + ref_neginf_mask = ref_logits == float("-inf") + neginf_mask = logits == float("-inf") + assert torch.equal(neginf_mask, ref_neginf_mask) + + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + logits = logits.masked_fill(neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + + +def _ref_fp8_paged_mqa_logits( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +): + batch_size, next_n, _, _ = q.size() + _, block_size, _, _ = kv_cache.size() + logits = torch.full( + [batch_size * next_n, max_model_len], + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + context_lens_list = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens_list[i] + q_offsets = torch.arange( + context_len - next_n, context_len, device="cuda" + ) + weight_slice = ( + weights[i * next_n : (i + 1) * next_n, :] + .transpose(0, 1) + .contiguous() + ) + for block_rk in range(cdiv(context_len, block_size)): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange( + block_rk * block_size, + (block_rk + 1) * block_size, + device="cuda", + ) + mask = (k_offsets[None, :] < context_len) & ( + k_offsets[None, :] <= q_offsets[:, None] + ) + s = torch.where( 
+ mask[None, :, :], + (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( + logits.dtype + ), + float("-inf"), + ) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[ + i * next_n : (i + 1) * next_n, + block_rk * block_size : (block_rk + 1) * block_size, + ] = torch.where( + k_offsets[None, :] <= q_offsets[:, None], s, float("-inf") + ) + return logits + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +def test_deepgemm_fp8_paged_mqa_logits(): + torch.manual_seed(0) + random.seed(0) + + max_model_len = 4096 + for batch_size, next_n in [(4, 1), (2, 2)]: + for heads, index_dim in [(16, 128)]: + for avg_kv in (2048,): + num_blocks, blocksize = max_model_len * 2, 64 + + q = torch.randn( + (batch_size, next_n, heads, index_dim), + device="cuda", + dtype=torch.bfloat16, + ) + kv_cache = torch.randn( + (num_blocks, blocksize, 1, index_dim), + device="cuda", + dtype=torch.bfloat16, + ) + weights = torch.randn( + (batch_size * next_n, heads), + device="cuda", + dtype=torch.float32, + ) + + context_lens = ( + torch.randint( + int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,) + ) + .cuda() + .to(torch.int32) + ) + max_block_len = ( + (context_lens.max().item() + blocksize - 1) + // blocksize + * blocksize + ) + block_tables = torch.zeros( + (batch_size, max_block_len), + device="cuda", + dtype=torch.int32, + ) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(batch_size): + ctx_len = int(context_lens[i].item()) + for j in range((ctx_len + blocksize - 1) // blocksize): + block_tables[i][j] = block_idx_pool[counter] + counter += 1 + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + + schedule_metadata = get_paged_mqa_logits_metadata( + context_lens, blocksize, 132 + ) + logits = fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + ) + + ref_logits = _ref_fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + ) + + positions = ( + torch.arange(max_model_len, device="cuda") + .unsqueeze(0) + .expand(batch_size * next_n, -1) + ) + row_indices = ( + torch.arange(batch_size * next_n, device="cuda") // next_n + ) + next_n_offset = ( + torch.arange(batch_size * next_n, device="cuda") % next_n + ) + mask = positions <= ( + context_lens[row_indices] - next_n + next_n_offset + ).unsqueeze(1) + + logits = logits.masked_fill(~mask, 0) + ref_logits = ref_logits.masked_fill(~mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 38d92f01192b..10ba6aa94add 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -70,15 +70,26 @@ def _missing(*_: Any, **__: Any) -> NoReturn: _fp8_gemm_nt_impl: Callable[..., Any] | None = None _grouped_impl: Callable[..., Any] | None = None _grouped_masked_impl: Callable[..., Any] | None = None +_fp8_mqa_logits_impl: Callable[..., Any] | None = None +_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None +_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None def _lazy_init() -> None: """Import deep_gemm and resolve symbols on first use.""" global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl + global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl + global _get_paged_mqa_logits_metadata_impl # fast 
path - if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None - or _grouped_masked_impl is not None): + if ( + _fp8_gemm_nt_impl is not None + or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or _get_paged_mqa_logits_metadata_impl is not None + ): return if not has_deep_gemm(): @@ -95,6 +106,11 @@ def _lazy_init() -> None: _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None) _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None) _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None) + _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) + _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) + _get_paged_mqa_logits_metadata_impl = getattr( + _dg, "get_paged_mqa_logits_metadata", None + ) def fp8_gemm_nt(*args, **kwargs): @@ -123,6 +139,110 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) +def fp8_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool = True, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + clean_logits: If True, fill logits outside [ks, ke) with -inf on the + backend. Defaults to True. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + _lazy_init() + if _fp8_mqa_logits_impl is None: + return _missing() + return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, + clean_logits) + + + +def get_paged_mqa_logits_metadata( + context_lens: torch.Tensor, block_size: int, num_sms: int +) -> torch.Tensor: + """Build scheduling metadata for paged MQA logits. + + Args: + context_lens: Tensor of shape [B], dtype int32; effective context length + per batch element. + block_size: KV-cache block size in tokens (e.g., 64). + num_sms: Number of SMs available. 132 for Hopper + + Returns: + Backend-specific tensor consumed by `fp8_paged_mqa_logits` to + schedule work across SMs. + """ + _lazy_init() + if _get_paged_mqa_logits_metadata_impl is None: + return _missing() + return _get_paged_mqa_logits_metadata_impl( + context_lens, block_size, num_sms + ) + + +def fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute FP8 MQA logits using paged KV-cache. + + Args: + q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last + 4 bytes per (block,pos) store the `float` dequant scale. + weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. 
+ context_lens: Tensor of shape [B], dtype int32; effective context length + for each batch element. + block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical + block indices to physical blocks in the paged cache. + schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; + used to distribute work across SMs. + max_model_len: Maximum sequence length used to size the logits output. + + Returns: + Logits tensor of shape [B * next_n, max_model_len], dtype + `torch.float32`. + """ + _lazy_init() + if _fp8_paged_mqa_logits_impl is None: + return _missing() + return _fp8_paged_mqa_logits_impl( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=False + ) + + + def _ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) @@ -183,6 +303,9 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", + "fp8_mqa_logits", + "fp8_paged_mqa_logits", + "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", From 93eade097f49af4d9fe5fca47eda22b2138b80c6 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Mon, 22 Sep 2025 21:35:10 +0000 Subject: [PATCH 16/82] fix clean logic --- vllm/utils/deep_gemm.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 10ba6aa94add..2f34d93e49fc 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -145,7 +145,6 @@ def fp8_mqa_logits( weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, - clean_logits: bool = True, ) -> torch.Tensor: """Compute FP8 MQA logits for a single sequence without KV paging. @@ -160,8 +159,6 @@ def fp8_mqa_logits( shape [M], dtype int32. cu_seqlen_ke: End indices (exclusive) for valid K per query position, shape [M], dtype int32. - clean_logits: If True, fill logits outside [ks, ke) with -inf on the - backend. Defaults to True. Returns: Logits tensor of shape [M, N], dtype `torch.float32`. 
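# Illustrative only: a dense PyTorch sketch of the semantics the wrapper above
# exposes (mirroring the reference helper in
# tests/kernels/attention/test_deepgemm_attention.py), not the DeepGEMM kernel
# itself. For query token m, only keys n with ks[m] <= n < ke[m] are valid;
# per-head scores are ReLU'd, weighted, and summed over heads:
# q [M, H, D], k [N, D], weights [M, H] -> logits [M, N].
import torch


def mqa_logits_reference(q: torch.Tensor, k: torch.Tensor,
                         weights: torch.Tensor, ks: torch.Tensor,
                         ke: torch.Tensor) -> torch.Tensor:
    # [H, M, N] raw scores in fp32
    score = torch.einsum("mhd,nd->hmn", q.float(), k.float())
    # ReLU, per-head weighting, head reduction -> [M, N]
    logits = (score.relu() * weights.t().unsqueeze(-1)).sum(dim=0)
    # mask keys outside each token's [ks, ke) window
    n_idx = torch.arange(k.shape[0], device=k.device)
    valid = (n_idx[None, :] >= ks[:, None]) & (n_idx[None, :] < ke[:, None])
    return logits.masked_fill(~valid, float("-inf"))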
@@ -169,8 +166,7 @@ def fp8_mqa_logits( _lazy_init() if _fp8_mqa_logits_impl is None: return _missing() - return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, - clean_logits) + return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) From 0eba9f1166fa9a95a76e9341ad43a0bb863d45a2 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 22 Sep 2025 21:11:16 -0400 Subject: [PATCH 17/82] sparse decode and make prefill and decode both use MQA (#16) * and env and MQA path for both prefill and decode Signed-off-by: Lucas Wilkinson * fix shapes Signed-off-by: Lucas Wilkinson --------- Signed-off-by: Lucas Wilkinson --- vllm/model_executor/layers/mla.py | 6 +- vllm/platforms/cuda.py | 3 +- .../attention/backends/mla/flashmla_sparse.py | 171 +++++++++++++++--- 3 files changed, 151 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 79e4a0e8fe99..34919b9f6384 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import dataclass from typing import Optional @@ -80,7 +81,8 @@ def __init__( self.o_proj = mla_modules.o_proj self.indexer = mla_modules.indexer self.topk_tokens = mla_modules.indexer.topk_tokens - self.use_sparse = mla_modules.is_sparse and False + self.use_sparse = mla_modules.is_sparse and os.getenv( + "VLLM_MLA_SPARSE_ENABLED") == "1" # In the MLA backend, kv_cache includes both k_c and # pe (i.e. decoupled position embeddings). In particular, @@ -155,7 +157,7 @@ def forward_native( if self.use_sparse: topk_indices = torch.zeros(q.shape[0], self.topk_tokens) - # NOTE(Chen): a bit hacky, but need to modify Attention.forward + # NOTE(Chen): a bit hacky, but need to modify Attention.forward # otherwise. Try to refactor this later. 
self.mla_attn.impl.set_topk_indices(topk_indices) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ffaac6c7c01b..b26e1d100dd8 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -236,7 +236,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink, use_sparse) -> str: if use_mla: - use_sparse = False + use_sparse = os.getenv( + "VLLM_MLA_SPARSE_ENABLED") == "1" and use_sparse # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index c0f8f0f59143..ba9cfd1bbc32 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -1,12 +1,23 @@ -from vllm.attention.backends.abstract import AttentionMetadata, AttentionLayer -import torch -from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import MLACommonBackend, MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder -from vllm.v1.attention.backends.utils import CommonAttentionMetadata, split_decodes_and_prefills +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -from typing import Optional logger = init_logger(__name__) @@ -65,7 +76,9 @@ def __init__(self): @dataclass class FlashMLASparseMetadata(MLACommonMetadata[MLASparsePrefillMetadata]): - pass + # For now just create topk_indices that just attend to the first topk tokens + # always to enable development + debug_topk_indices: Optional[torch.Tensor] = None @dataclass @@ -76,6 +89,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashMLASparseMetadata) + self.topk_tokens = vllm_config.model_config.hf_config\ + .attn_module_list_cfg[0]["topk_tokens"] def _build_prefill( self, common_attn_metadata: CommonAttentionMetadata @@ -91,12 +106,23 @@ def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> FlashMLASparseMetadata: - logger.info(f"build FlashMLASparseMetadata") - num_reqs = common_attn_metadata.num_reqs + logger.info("build FlashMLASparseMetadata") num_actual_tokens = common_attn_metadata.num_actual_tokens num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.reorder_batch_threshold) + + starts = np.asarray(common_attn_metadata.query_start_loc_cpu) + pos = np.arange(starts[-1]) - np.repeat(starts[:-1], np.diff(starts)) + pos_gpu = torch.as_tensor(pos, device=self.device, dtype=torch.long) + + row = torch.arange(self.topk_tokens, + device=self.device, + dtype=torch.int64) + 
debug_topk_indices = row.repeat(num_actual_tokens, 1) + mask = debug_topk_indices < pos_gpu.unsqueeze(1) + debug_topk_indices = debug_topk_indices.masked_fill(~mask, -1) + return FlashMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, @@ -107,6 +133,7 @@ def build(self, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, + debug_topk_indices=debug_topk_indices, prefill=self._build_prefill(common_attn_metadata), decode=self._build_decode(common_attn_metadata), ) @@ -133,37 +160,120 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - # self.sm_scale = + # self.sm_scale = self.topk_indices = None - def set_topk_indices(self, topk_indices: torch.Tensor): self.topk_indices = topk_indices - def _forward_prefill( + def forward( self, + layer: AttentionLayer, q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, - k_scale: torch.Tensor + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use + # MQA 576/512 approach for both prefill and decode (see: + # https://vllm-dev.slack.com/archives/C09GKA1D4LR/p1758506094148479) + + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLACommonImpl") + + if attn_metadata is None: + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_toks, ...] + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
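        # The remainder mirrors MLACommonImpl.forward, specialized to the
        # MQA 576/512 layout used by the sparse kernels: split the batch into
        # decode/prefill tokens, project q_nope into the 512-dim latent space
        # via W_UK_T, write k_c / k_pe into the KV cache, run the (currently
        # placeholder) sparse prefill/decode paths, and apply _v_up_proj to
        # their outputs.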
+ + assert attn_metadata.num_decodes is not None and \ + attn_metadata.num_prefills is not None and \ + attn_metadata.num_decode_tokens is not None + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + ql_nope = ql_nope.transpose(0, 1) + + decode_ql_nope = ql_nope[:num_decode_tokens] + decode_q_pe = q_pe[:num_decode_tokens] + + prefill_ql_nope = ql_nope[num_decode_tokens:] + prefill_q_pe = q_pe[num_decode_tokens:] + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if has_prefill: + attn_out = self._forward_prefill(prefill_ql_nope, prefill_q_pe, + kv_cache, attn_metadata, + layer._k_scale) + # v_up projection + output[num_decode_tokens:] = self._v_up_proj(attn_out) + if has_decode: + # call decode attn + attn_out, lse = self._forward_decode( + (decode_ql_nope, decode_q_pe), kv_cache, attn_metadata, layer) + # v_up projection + output[:num_decode_tokens] = self._v_up_proj(attn_out) + return output_padded + + def _forward_prefill(self, ql_nope: torch.Tensor, q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + k_scale: torch.Tensor) -> torch.Tensor: # # assume indice of shape [num_prefill_tokens, topk] # block_id_in_req = topk_indices // self.block_size topk_indices = self.topk_indices[attn_metadata.num_decodes:] - logger.info(f"called _forward_prefill with topk_indices shape {topk_indices.shape}") + logger.info("called _forward_prefill with topk_indices shape %s", + topk_indices.shape) # NOTE(Chen): shape is unsure - return torch.zeros((q.shape[0], 2048), dtype=q.dtype, device=q.device) + return torch.zeros((ql_nope.shape[0], ql_nope.shape[1], 512), + dtype=ql_nope.dtype, + device=ql_nope.device) def _forward_decode( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, - layer: AttentionLayer, - topk_indices: Optional[torch.Tensor] = None, # sparse attn + self, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + layer: AttentionLayer, + topk_indices: Optional[torch.Tensor] = None, # sparse attn ) -> torch.Tensor: topk_indices = self.topk_indices[:attn_metadata.num_decodes] @@ -171,6 +281,15 @@ def _forward_decode( # # assume indice of shape [num_decode_tokens, topk] # block_id_in_req = topk_indices // self.block_size - logger.info(f"called _forward_decode with topk_indices shape {topk_indices.shape}") + logger.info("called _forward_decode with topk_indices shape %s", + topk_indices.shape) + + ql_nope, q_pe = q + + attn_out = torch.zeros((ql_nope.shape[0], ql_nope.shape[1], 512), + dtype=ql_nope.dtype, + device=ql_nope.device) + lse = None #TODO + # NOTE(Chen): shape is unsure - return torch.zeros((q[0].shape[0], 16*512), dtype=q[0].dtype, device=q[0].device), torch.zeros((q[0].shape[0], 128), dtype=q[0].dtype, device=q[0].device) + return attn_out, lse From c0c0624a5ab6e23352e6c04776af014da0104fe9 Mon Sep 17 00:00:00 2001 From: 
Yongye Zhu Date: Mon, 22 Sep 2025 21:16:39 -0400 Subject: [PATCH 18/82] Adding pytorch impl for Paged Indexer (#9) * code from ds Signed-off-by: youkaichao * doc from ds Signed-off-by: youkaichao * Fixes for support_materials/2-tilelang/ Signed-off-by: mgoin * Fix example 1 Signed-off-by: mgoin * Fix Einsum in deepgemm * Fix `libc10.so` unimported error * fix reference code Signed-off-by: youkaichao * adding missing indexer args * passing index args into the module * init Signed-off-by: Chen Zhang * build indexer k cache medadata * prefill indexer, but weight_proj will output -inf * unqiantized paged indexer, still have -inf issue * remove support material * adding topk_indices mask * add weight scale * unittest infrastructure and fix weight_proj, numeric error due to quantization * varlen prefill passed * paged prefill * add indices mask --------- Signed-off-by: youkaichao Signed-off-by: mgoin Signed-off-by: Chen Zhang Co-authored-by: youkaichao Co-authored-by: mgoin Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Chen Zhang --- tests/kernels/attention/test_indexer.py | 227 +++++++++++++++++ vllm/model_executor/layers/mla.py | 6 + vllm/model_executor/models/deepseek_v2.py | 202 +++++++++++++-- vllm/utils/tile_lang_kernels.py | 282 +++++++++++++++++++++ vllm/v1/attention/backends/mla/common.py | 3 + vllm/v1/attention/backends/mla/flashmla.py | 1 + vllm/v1/attention/backends/mla/indexer.py | 23 +- 7 files changed, 727 insertions(+), 17 deletions(-) create mode 100644 tests/kernels/attention/test_indexer.py create mode 100644 vllm/utils/tile_lang_kernels.py diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py new file mode 100644 index 000000000000..f0ea88a71f53 --- /dev/null +++ b/tests/kernels/attention/test_indexer.py @@ -0,0 +1,227 @@ +import random + +import torch + +from vllm.utils.tile_lang_kernels import act_quant, fp8_index +from vllm import _custom_ops as ops + + +def ref_compute_logits_fp8(q, kv, weights, mask, block_size): + q_fp8, q_scale = act_quant(q, block_size, "ue8m0") + k_fp8, k_scale = act_quant(kv, block_size, "ue8m0") + + weights = weights.unsqueeze(-1) * q_scale + weights = weights * (128**(-0.5)) + index_score = fp8_index( + q_fp8.contiguous(), weights, + k_fp8.contiguous(), + k_scale.contiguous()) + if mask is not None: + index_score += mask + return index_score + +def ref_indexer(seq_len, q, kv, weights, block_size, topk): + B = seq_len.shape[0] + varlen_logits = [] + + for i in range(B): + S = seq_len[i] + q_s = q[i][:S].contiguous().unsqueeze(0) + kv_s = kv[i][:S].contiguous().unsqueeze(0) + weights_s = weights[i][:S].contiguous().unsqueeze(0) + mask = torch.full( + (S, S), float("-inf"), + device="cuda").triu_(1) + logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask, block_size) + varlen_logits.append(logits) + # topk_indices = index_score.topk(topk, + # dim=-1)[1] + return varlen_logits + +def kv_spans_from_batches(start_seq_loc: torch.Tensor, + seq_len_per_batch: torch.Tensor): + """ + Args: + start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. + Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. + seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. + Example: [5, 9, 4]. + + Returns: + start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. + end_location: 1D long tensor [N], **exclusive** end = start + token's local position. 
+ (So the attended KV slice is kv[start:end].) + + Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the KV cache, and + the selected tokens within a batch are the **last** `counts[i]` positions of that sequence. + """ + q = start_seq_loc.to(dtype=torch.long) + L = seq_len_per_batch.to(dtype=torch.long, device=q.device) + assert q.dim() == 1 and L.dim() == 1 + assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" + + # Selected tokens per batch and totals + counts = q[1:] - q[:-1] # [B] + N = int(q[-1].item()) # total selected tokens + B = L.numel() + device = L.device + + if N == 0: + return (torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.long, device=device)) + + # KV start offsets per batch in the concatenated KV cache + kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] + + # For each selected token, which batch does it belong to? + batch_id = torch.repeat_interleave(torch.arange(B, device=device), counts) # [N] + + # Map batch KV start to each token + start_tensor = kv_starts_per_batch[batch_id] # [N] + + # End-align local positions inside each batch: + # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b + L_expand = torch.repeat_interleave(L, counts) # [N] + m_expand = torch.repeat_interleave(counts, counts) # [N] + # position within the selected block: 1..counts[b] + pos_within = (torch.arange(N, device=device, dtype=torch.long) + - torch.repeat_interleave(q[:-1], counts) + 1) + + local_pos = L_expand - m_expand + pos_within # [N], 1-based + end_location = start_tensor + local_pos # exclusive end + + return start_tensor, end_location + +def ref_fp8_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + k = kv + q = q.float() + k = k.float() + + seq_len_kv = kv.shape[0] + mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] + >= cu_seqlen_ks[:, None]) + mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] + < cu_seqlen_ke[:, None]) + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + cost = mask.sum() + return logits, cost + +def torch_indexer(seq_len, q, kv, weights, block_size, topk): + NUM_BLOCKS = 8 + BLOCK_SIZE = 32 + + B = seq_len.shape[0] + concat_q = [] + concat_kv = [] + concat_weights = [] + total_slots = NUM_BLOCKS * BLOCK_SIZE + head_dim = kv.shape[-1] + max_num_block_per_batch = torch.max(seq_len) + block_table = torch.empty((B, max_num_block_per_batch), + dtype=torch.int32, + device="cuda") + + for i in range(B): + S = seq_len[i] + q_s = q[i][:S].contiguous() + kv_s = kv[i][:S].contiguous() + weight_s = weights[i][:S].contiguous() + concat_q.append(q_s) + concat_kv.append(kv_s) + concat_weights.append(weight_s) + + q = torch.cat(concat_q, dim=0) + kv = torch.cat(concat_kv, dim=0) + weights = torch.cat(concat_weights, dim=0) + + # write to kv cache based on slot mapping + entry_size = head_dim * 2 + num_tokens = q.size(0) + slot_mapping_lst = random.sample(range(total_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, + dtype=torch.long, + device="cuda") + kv_cache = torch.zeros( + NUM_BLOCKS, + BLOCK_SIZE, + entry_size, + dtype=torch.bfloat16, + device="cuda" + ) + scale = torch.tensor(1, dtype=torch.float32, device="cuda") + ops.concat_and_cache_mla( + kv, + kv.clone(), + kv_cache, + slot_mapping, + "auto", + 
scale + ) + + current_index = 0 + for i in range(B): + S = seq_len[i] + block_table[i][:S] = slot_mapping[current_index: current_index + S] + current_index += S + + weights = weights * (128**(-0.5)) + query_start_loc = torch.empty((B + 1), device="cuda") + query_start_loc[0] = 0 + query_start_loc[1:] = seq_len.cumsum(dim=0).to(dtype=torch.int32) + + kv_gathered = kv_cache.view(-1, entry_size)[slot_mapping][..., :head_dim] + torch.testing.assert_close(kv, kv_gathered) + + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_len) + + logits, _ = ref_fp8_mqa_logits( + q, + kv_gathered, + weights, + cu_seqlen_ks, + cu_seqlen_ke + ) + topk_indices = logits.topk(topk, dim=-1)[1] + mask_lo = topk_indices >= cu_seqlen_ks[:, None] + mask_hi = topk_indices < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + return logits + +def test_paged_indexer_python(): + B = 2 + S = 128 + SKV = S + H = 64 + HKV = 1 + D = 128 + block_size = 128 + topk = 64 + device = "cuda" + seq_len = torch.randint(low=64, high=S, size=(B,)) + + q = torch.randn(B, S, H, D, device="cuda", + dtype=torch.bfloat16) + kv = torch.randn(B, SKV, D, device="cuda", + dtype=torch.bfloat16) + weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5 + + ref_indices = ref_indexer(seq_len, q, kv, weights, block_size, topk) + torch_indices = torch_indexer(seq_len, q, kv, weights, block_size, topk) + import pdb; pdb.set_trace() + print(ref_indices) + + +if __name__ == "__main__": + test_paged_indexer_python() diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 34919b9f6384..13b879107c8c 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -108,6 +108,7 @@ def __init__( qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, kv_b_proj=self.kv_b_proj, + indexer=self.indexer, ) self.prefix = prefix @@ -153,6 +154,11 @@ def forward_native( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim:], k_pe) + + if self.indexer: + topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb) + # if topk_indices is not None: + # print(topk_indices) if self.use_sparse: topk_indices = torch.zeros(q.shape[0], self.topk_tokens) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a4ad9c288f12..b782a6990abb 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -43,6 +43,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import FusedMoE @@ -72,6 +73,11 @@ maybe_prefix) from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + logger = init_logger(__name__) WITH_V32 = True @@ -498,7 +504,6 @@ def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str, if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - logger.info(f"register indexer cache {prefix}") def 
get_kv_cache_spec(self) -> KVCacheSpec: return FullAttentionSpec( @@ -510,13 +515,102 @@ def get_kv_cache_spec(self) -> KVCacheSpec: ) def forward(self): - logger.info( - f"self.kv_cache {self.prefix} {self.kv_cache[0].shape} {self.kv_cache[0].dtype}" - ) + attn_metadata = get_forward_context().attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.prefix] + logger.info(f"attn_metadata {attn_metadata}") def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend +# ignore or replace with pytorch +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from fast_hadamard_transform import hadamard_transform + hidden_size = x.size(-1) + # make sure the hidden_size is expontial of 2 + return hadamard_transform(x, scale=hidden_size**-0.5) + +def ref_fp8_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + # print(f"q_shape: {q.shape}, v_shape: {kv.shape}, weights.shape: {weights.shape}") + k = kv + q = q.float() + k = k.float() + + seq_len_kv = kv.shape[0] + mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] + >= cu_seqlen_ks[:, None]) + mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] + < cu_seqlen_ke[:, None]) + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + cost = mask.sum() + return logits, cost + +# TODO (zyongye) optimize this, this is now vibe coded +def kv_spans_from_batches(start_seq_loc: torch.Tensor, + seq_len_per_batch: torch.Tensor): + """ + Args: + start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. + Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. + seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. + Example: [5, 9, 4]. + + Returns: + start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. + end_location: 1D long tensor [N], **exclusive** end = start + token's local position. + (So the attended KV slice is kv[start:end].) + + Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the KV cache, and + the selected tokens within a batch are the **last** `counts[i]` positions of that sequence. + """ + q = start_seq_loc.to(dtype=torch.long) + L = seq_len_per_batch.to(dtype=torch.long, device=q.device) + assert q.dim() == 1 and L.dim() == 1 + assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" + + # Selected tokens per batch and totals + counts = q[1:] - q[:-1] # [B] + N = int(q[-1].item()) # total selected tokens + B = L.numel() + device = L.device + + if N == 0: + return (torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.long, device=device)) + + # KV start offsets per batch in the concatenated KV cache + kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] + + # For each selected token, which batch does it belong to? 
+ batch_id = torch.repeat_interleave(torch.arange(B, device=device), counts) # [N] + + # Map batch KV start to each token + start_tensor = kv_starts_per_batch[batch_id] # [N] + + # End-align local positions inside each batch: + # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b + L_expand = torch.repeat_interleave(L, counts) # [N] + m_expand = torch.repeat_interleave(counts, counts) # [N] + # position within the selected block: 1..counts[b] + pos_within = (torch.arange(N, device=device, dtype=torch.long) + - torch.repeat_interleave(q[:-1], counts) + 1) + + local_pos = L_expand - m_expand + pos_within # [N], 1-based + end_location = start_tensor + local_pos # exclusive end + + return start_tensor, end_location class Indexer(nn.Module): @@ -547,25 +641,102 @@ def __init__(self, self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.weights_proj = ReplicatedLinear(hidden_size, self.n_head, - quant_config=quant_config, + quant_config=None, prefix=f"{prefix}.weights_proj") self.softmax_scale = self.head_dim**-0.5 + self.scale_fmt = "ue8m0" self.quant_block_size = 128 # TODO: get from config - self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim, - dtype=torch.float8_e4m3fn, + + #TODO (zyongye) change dim to fp8 later to (self.head_dim + 4) + self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim * 2, + dtype=torch.bfloat16, prefix=f"{prefix}.k_cache", cache_config=cache_config) - self.k_scale_cache = DeepseekV32IndexerCache( - head_dim=self.head_dim // self.quant_block_size, - dtype=torch.float32, - prefix=f"{prefix}.k_scale_cache", - cache_config=cache_config) def forward(self, hidden_states: torch.Tensor, - q: torch.Tensor) -> torch.Tensor: - return torch.empty_like( - hidden_states)[:, :self.topk_tokens].contiguous() + qr: torch.Tensor, positions, rotary_emb) -> torch.Tensor: + # print(f"hidden_states: {torch.isinf(hidden_states).any()}, qr: {torch.isinf(qr).any()}") + # print(f"weight_proj: {torch.isneginf(self.weights_proj.weight.to(torch.float32)).any()}") + q, _= self.wq_b(qr) + q = q.view(-1, self.n_head, self.head_dim) + q_pe, q_nope = torch.split( + q, [self.rope_dim, self.head_dim - self.rope_dim], + dim=-1 + ) + + k, _ = self.wk(hidden_states) + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_dim, self.head_dim - self.rope_dim], + dim=-1) + + #FIXME (zyongye) this will cause OOM when using full sequence forward on 8xH200 + q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) + q = torch.cat([q_pe, q_nope], dim=-1) + k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) + q = rotate_activation(q) + k = rotate_activation(k) + + from vllm.utils.tile_lang_kernels import act_quant + q_fp8, q_scale = act_quant(q, self.quant_block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.quant_block_size, self.scale_fmt) + # k_cache_entry = torch.cat([k_fp8, k_scale], dim=-1) + weights, _ = self.weights_proj(hidden_states) + weights = weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5 + + # careful! 
this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + if isinstance(attn_metadata, dict): + k_cache_attn_metadata = attn_metadata[self.k_cache.prefix] + slot_mapping = k_cache_attn_metadata.slot_mapping + + query_start_loc = k_cache_attn_metadata.query_start_loc + seq_lens = k_cache_attn_metadata.seq_lens + batch_size = seq_lens.size(0) + cu_seq_lens = torch.empty((batch_size + 1), + dtype=torch.int32, + device=q.device) + cu_seq_lens[0] = 0 + cu_seq_lens[1:] = seq_lens.cumsum(dim=0).to(dtype=torch.int32) + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_lens) + #TODO (zyongye) use quant type + kv_cache = self.k_cache.kv_cache[0] + + # FIXME (zyongye) right now k_pe cache is a dummy tensor, + # we need to change kv cache to only store k cache + scale = torch.tensor(1, dtype=torch.float32, device=k.device) + ops.concat_and_cache_mla( + k, + k.clone(), + kv_cache, + slot_mapping, + "auto", + scale, + ) + + flattened_kv = torch.empty([cu_seqlen_ks.size(-1), self.head_dim * 2], device=k.device, dtype=torch.bfloat16) + ops.cp_gather_cache( + kv_cache, + flattened_kv, + k_cache_attn_metadata.block_table, + cu_seq_lens, + batch_size, + ) + logits, _ = ref_fp8_mqa_logits( + q, + flattened_kv[..., :self.head_dim], + weights, + cu_seqlen_ks, + cu_seqlen_ke, + ) + topk_indices = logits.topk(min(self.topk_tokens, logits.shape[-1]), + dim=-1)[1] + mask_lo = topk_indices >= cu_seqlen_ks[:, None] + mask_hi = topk_indices < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + return topk_indices class DeepseekV2MLAAttention(nn.Module): @@ -720,6 +891,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: + # self.indexer(torch.tensor([]), torch.tensor([])) return self.mla_attn(positions, hidden_states) diff --git a/vllm/utils/tile_lang_kernels.py b/vllm/utils/tile_lang_kernels.py new file mode 100644 index 000000000000..5e4576fea45e --- /dev/null +++ b/vllm/utils/tile_lang_kernels.py @@ -0,0 +1,282 @@ +from typing import Optional, Tuple + +import tilelang +import tilelang.language as T +import torch + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, +} + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" + + +def fast_log2_ceil(x): + bits_x = T.reinterpret("uint32", x) + exp_x = (bits_x >> 23) & 0xFF + man_bits = bits_x & ((1 << 23) - 1) + return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + + +def fast_pow2(x): + bits_x = (x + 127) << 23 + return T.reinterpret("float32", bits_x) + + +def fast_round_scale(amax, fp8_max_inv): + return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) + + +@tilelang.jit(pass_configs=pass_configs) +def act_quant_kernel(N, + in_dtype=BF16, + out_dtype=FP8, + scale_dtype=FP32, + round_scale=False): + M = T.symbolic("M") + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1 / fp8_max + num_stages = 0 if round_scale else 2 + blk_m = 32 + group_size = 128 + + @T.prim_func + def act_quant_kernel_( + X: T.Tensor[(M, N), in_dtype], + Y: T.Tensor[(M, N), out_dtype], + S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype], + ): + with T.Kernel(T.ceildiv(M, blk_m), + T.ceildiv(N, group_size), + threads=128) as ( + pid_m, + pid_n, + ): + x_shared = T.alloc_shared((blk_m, group_size), in_dtype) + x_local = T.alloc_fragment((blk_m, group_size), 
in_dtype) + amax_local = T.alloc_fragment((blk_m, ), scale_dtype) + s_local = T.alloc_fragment((blk_m, ), scale_dtype) + y_local = T.alloc_fragment((blk_m, group_size), out_dtype) + y_shared = T.alloc_shared((blk_m, group_size), out_dtype) + + for _ in T.Pipelined(1, num_stages=num_stages): + T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) + T.copy(x_shared, x_local) + T.reduce_absmax(x_local, amax_local, dim=1) + for i in T.Parallel(blk_m): + amax_local[i] = T.max(amax_local[i], 1e-4) + if round_scale: + s_local[i] = fast_round_scale(amax_local[i], + fp8_max_inv) + else: + s_local[i] = amax_local[i] * fp8_max_inv + for i, j in T.Parallel(blk_m, group_size): + y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], + fp8_min, fp8_max) + for i in T.Parallel(blk_m): + S[pid_m * blk_m + i, pid_n] = s_local[i] + T.copy(y_local, y_shared) + T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) + + return act_quant_kernel_ + + +def act_quant( + x: torch.Tensor, + block_size: int = 128, + scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension size must be divisible by block_size (block_size={block_size})" + ) + N = x.size(-1) + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) + kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) + return y, s + + +@tilelang.jit(pass_configs=pass_configs) +def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): + assert out_dtype in [BF16, "float32"] + + M = T.symbolic("M") + group_size = 128 + block_M = 32 + block_N = 128 + block_K = 128 + + @T.prim_func + def fp8_gemm_kernel_( + A: T.Tensor[(M, K), FP8], + B: T.Tensor[(N, K), FP8], + C: T.Tensor[(M, N), out_dtype], + scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32], + scales_b: T.Tensor[(T.ceildiv(N, group_size), + T.ceildiv(K, group_size)), FP32], + ): + with T.Kernel(T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=128) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), FP8) + B_shared = T.alloc_shared((block_N, block_K), FP8) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + Scale_C_shared = T.alloc_shared((block_M), FP32) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + # Load A into shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # Load B into shared memory + T.copy(B[bx * block_N, k * block_K], B_shared) + # Load scale into shared memory + Scale_B = scales_b[bx * block_N // group_size, k] + for i in 
T.Parallel(block_M): + Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + T.clear(C_local) + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return fp8_gemm_kernel_ + + +def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, + b_s: torch.Tensor) -> torch.Tensor: + """ + Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous( + ), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), ( + "Scaling factor tensors must be contiguous") + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + kernel = fp8_gemm_kernel(N, K) + kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) + return c + + +@tilelang.jit(out_idx=[4], pass_configs=pass_configs) +def fp8_index_kernel(h: int, d: int): + b = T.symbolic("b") + m = T.symbolic("m") + n = T.symbolic("n") + + blk_n1 = 512 + blk_n2 = 128 + + @T.prim_func + def fp8_index_kernel_( + q: T.Tensor[(b, m, h, d), FP8], + q_s: T.Tensor[(b, m, h), FP32], + k: T.Tensor[(b, n, d), FP8], + k_s: T.Tensor[(b, n), FP32], + o: T.Tensor[(b, m, n), FP32], + ) -> None: + with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): + q_smem = T.alloc_shared((h, d), FP8) + T.copy(q[i_b, i_m, 0, 0], q_smem) + + q_s_frag = T.alloc_fragment(h, FP32) + T.copy(q_s[i_b, i_m, 0], q_s_frag) + + for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): + k_smem = T.alloc_shared((blk_n2, d), FP8) + T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) + + k_s_frag = T.alloc_fragment(blk_n2, FP32) + T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) + + logits = T.alloc_fragment((blk_n2, h), FP32) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + + for i_h, i3_n in T.Parallel(h, blk_n2): + logits[i3_n, + i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] + + logits_sum = T.alloc_fragment(blk_n2, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + + for i3_n in T.Parallel(blk_n2): + logits_sum[i3_n] *= k_s_frag[i3_n] + + T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) + + return fp8_index_kernel_ + + +def fp8_index( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """ + Perform index score using FP8 precision. + + Args: + q (torch.Tensor): The Q tensor, must be contiguous. + q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous. + k (torch.Tensor): The K tensor, must be contiguous. + k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous. 
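+
+    Returns:
+        torch.Tensor: The FP32 index scores with shape [b, m, n], matching
+        the output tensor `o` declared in the kernel signature above.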
+ + fp8 q @ fp8 k -> fp32 logits + relu(fp32 logits) * q_s (weights) -> fp32 logits + fp32 logits -> fp32 logits_sum + fp32 logits_sum * k_s (e8m0) -> fp32 index_score + """ + return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 33bfb4127be1..c3de294d947e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -947,6 +947,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, + indexer = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -964,6 +965,7 @@ def __init__( self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj + self.indexer = indexer if use_flashinfer_prefill(): logger.debug_once("Using FlashInfer prefill for MLA") @@ -1429,6 +1431,7 @@ def _forward_prefill( attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, ) -> torch.Tensor: + # TODO (zyongye): Prefill function here assert attn_metadata.prefill is not None assert self.dcp_world_size is not None diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 150e38553e4b..bb145e1f4a29 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -177,6 +177,7 @@ def _forward_decode( attn_metadata: FlashMLAMetadata, layer: AttentionLayer, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + # TODO: (zyongye) decode function for mla here assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index db474cb92718..c2e478e382b8 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -2,6 +2,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) +import torch class DeepseekV32IndexerBackend(AttentionBackend): @@ -23,7 +24,17 @@ def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]: @dataclass class DeepseekV32IndexerMetadata: - pass + + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + + num_reqs: int + num_actual_tokens: int # Number of tokens excluding padding. 
+ max_query_len: int + max_seq_len: int + + block_table: torch.Tensor # [num_req, (max_req_len + block_size - 1) // block_size] + slot_mapping: torch.Tensor class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): @@ -35,4 +46,12 @@ def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> DeepseekV32IndexerMetadata: - return DeepseekV32IndexerMetadata() + return DeepseekV32IndexerMetadata( + query_start_loc = common_attn_metadata.query_start_loc, + seq_lens=common_attn_metadata.seq_lens, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + block_table=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping) From 6a29a01ddbc285783604705c49baa4dff6c38b28 Mon Sep 17 00:00:00 2001 From: Lucia Fang <116399278+luccafong@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:44:50 -0700 Subject: [PATCH 19/82] support mtp with indexer kv (#21) Co-authored-by: Lucia Fang --- vllm/v1/spec_decode/eagle.py | 41 ++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 2a178ddf4877..f7a5dd20df97 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -16,6 +16,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.platforms import current_platform @@ -62,6 +63,7 @@ def __init__( self.method = self.speculative_config.method self.runner = runner + self.device = device self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size @@ -197,12 +199,26 @@ def propose( self.runner.attn_groups[0][0].metadata_builders[ubatch_id] attn_metadata = attn_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0) - + # FIXME: support hybrid kv for draft model (remove separate indexer) + if self.draft_indexer_metadata_builder: + draft_indexer_metadata = ( + self.draft_indexer_metadata_builder + .build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=0, + ) + ) + else: + draft_indexer_metadata = None # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. 
per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata + for layer_name in self.indexer_layer_names: + assert draft_indexer_metadata is not None + per_layer_attn_metadata[layer_name] = draft_indexer_metadata + if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) @@ -806,6 +822,10 @@ def load_model(self, target_model: nn.Module) -> None: self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + # FIXME: support hybrid kv for draft model + target_indexer_layer_names = set( + get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache).keys()) from vllm.compilation.backends import set_model_tag with set_model_tag("eagle_head"): @@ -815,8 +835,25 @@ def load_model(self, target_model: nn.Module) -> None: draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) - + indexer_layers = get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache) + draft_indexer_layer_names = (indexer_layers.keys() - target_indexer_layer_names) self.attn_layer_names = list(draft_attn_layer_names) + self.indexer_layer_names = list(draft_indexer_layer_names) + + if self.indexer_layer_names: + first_layer = self.indexer_layer_names[0] + self.draft_indexer_metadata_builder = ( + indexer_layers[first_layer] + .get_attn_backend() + .get_builder_cls()( + indexer_layers[first_layer].get_kv_cache_spec(), + self.indexer_layer_names, + self.vllm_config, + self.device, + ) + ) + else: + self.draft_indexer_metadata_builder = None if supports_multimodal(target_model): # handle multimodality From 9ca6434b749b1b7c32ac13180a9edaf4290a5845 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 22 Sep 2025 22:51:28 -0700 Subject: [PATCH 20/82] FlashMLA prefill kernel integration (#17) * prefill mla Signed-off-by: Chen Zhang * can run now Signed-off-by: Chen Zhang * tmp Signed-off-by: Chen Zhang * can output the first token Signed-off-by: Chen Zhang * fix bug Signed-off-by: Chen Zhang * remove some debug Signed-off-by: Chen Zhang * update Signed-off-by: Chen Zhang * hack through cu_seqlen_ks exploding issue * update basic.py Signed-off-by: Chen Zhang * remove some unnecessary changes Signed-off-by: Chen Zhang * clean up Signed-off-by: Chen Zhang --------- Signed-off-by: Chen Zhang Co-authored-by: Yongye Zhu --- examples/offline_inference/basic/basic.py | 36 ++- vllm/model_executor/layers/mla.py | 11 +- vllm/model_executor/models/deepseek_v2.py | 95 +++--- vllm/platforms/cuda.py | 7 + .../attention/backends/mla/flashmla_sparse.py | 284 +++++++++++++----- 5 files changed, 310 insertions(+), 123 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 75c364321c98..63e6045cd3c7 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -2,22 +2,50 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import LLM, SamplingParams +from vllm.inputs.data import TokensPrompt # Sample prompts. prompts = [ - "The capital of France is", + "hello, can you tell me the answer of 1 + 1?", + ] # Create a sampling params object. 
-sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=1024) +sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=50) +prompt_token_ids = [ + TokensPrompt( + prompt_token_ids=[0, 128803, 33310, 14, 588, 440, 4575, 678, 270, 3287, 294, 223, 19, 940, 223, 19, 33, 128804, 128799], + ), # hello, can you tell me the answer of 1 + 1? + TokensPrompt( + prompt_token_ids=[0, 128803, 33310, 14, 1205, 344, 223, 20, 940, 223, 20, 33, 128804, 128799], + ), # hello, what is 2 + 2? + TokensPrompt( + prompt_token_ids=[0, 128803, 9602, 344, 223, 21, 940, 223, 21, 33, 8033, 1801, 678, 16, 128804, 128799], + ), # what is 3 + 3? please show me. +] + +""" +Prompt: hello, can you tell me the answer of 1 + 1? +Output: Hello! The answer to 1 + 1 is **2**. \n\nIf you have any more questions, feel free to ask! 😊 +""" + +""" +Prompt: hello, what is 2 + 2? +Output: Hello! 2 + 2 equals 4. 😊 +""" + +""" +Prompt: what is 3 + 3? please show me. +Output: Let's add 3 and 3 together:\n\n3 + 3 = 6\n\nSo, 3 plus 3 equals 6." +""" def main(): # Create an LLM. - llm = LLM(model="/home/vllm-dsv32/DeepSeek-V3.2-Preview-Fix", tensor_parallel_size=8) + llm = LLM(model="/home/vllm-dsv32/DeepSeek-V3.2-Preview-Fix", tensor_parallel_size=8, enforce_eager=True) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) + outputs = llm.generate(prompt_token_ids, sampling_params) # Print the outputs. print("\nGenerated Outputs:\n" + "-" * 60) for output in outputs: diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 13b879107c8c..89d62a1b2f23 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -154,15 +154,10 @@ def forward_native( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim:], k_pe) - - if self.indexer: - topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb) - # if topk_indices is not None: - # print(topk_indices) - - if self.use_sparse: - topk_indices = torch.zeros(q.shape[0], self.topk_tokens) + if self.indexer and self.use_sparse: + topk_indices = self.indexer(hidden_states, q_c, positions, + self.rotary_emb) # NOTE(Chen): a bit hacky, but need to modify Attention.forward # otherwise. Try to refactor this later. 
self.mla_attn.impl.set_topk_indices(topk_indices) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b782a6990abb..a910fbddcc1e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -31,6 +31,7 @@ import torch from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +import torch.distributed as dist from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger @@ -68,7 +69,8 @@ from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec @@ -80,8 +82,6 @@ logger = init_logger(__name__) -WITH_V32 = True - class DeepseekV2MLP(nn.Module): @@ -523,6 +523,7 @@ def forward(self): def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend + # ignore or replace with pytorch def rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 @@ -531,6 +532,7 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor: # make sure the hidden_size is expontial of 2 return hadamard_transform(x, scale=hidden_size**-0.5) + def ref_fp8_mqa_logits( q: torch.Tensor, kv: torch.Tensor, @@ -557,6 +559,7 @@ def ref_fp8_mqa_logits( cost = mask.sum() return logits, cost + # TODO (zyongye) optimize this, this is now vibe coded def kv_spans_from_batches(start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor): @@ -581,8 +584,8 @@ def kv_spans_from_batches(start_seq_loc: torch.Tensor, assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" # Selected tokens per batch and totals - counts = q[1:] - q[:-1] # [B] - N = int(q[-1].item()) # total selected tokens + counts = q[1:] - q[:-1] # [B] + N = int(q[-1].item()) # total selected tokens B = L.numel() device = L.device @@ -591,26 +594,28 @@ def kv_spans_from_batches(start_seq_loc: torch.Tensor, torch.empty(0, dtype=torch.long, device=device)) # KV start offsets per batch in the concatenated KV cache - kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] + kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] # For each selected token, which batch does it belong to? 
- batch_id = torch.repeat_interleave(torch.arange(B, device=device), counts) # [N] + batch_id = torch.repeat_interleave(torch.arange(B, device=device), + counts) # [N] # Map batch KV start to each token - start_tensor = kv_starts_per_batch[batch_id] # [N] + start_tensor = kv_starts_per_batch[batch_id] # [N] # End-align local positions inside each batch: # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b - L_expand = torch.repeat_interleave(L, counts) # [N] - m_expand = torch.repeat_interleave(counts, counts) # [N] + L_expand = torch.repeat_interleave(L, counts) # [N] + m_expand = torch.repeat_interleave(counts, counts) # [N] # position within the selected block: 1..counts[b] - pos_within = (torch.arange(N, device=device, dtype=torch.long) - - torch.repeat_interleave(q[:-1], counts) + 1) + pos_within = (torch.arange(N, device=device, dtype=torch.long) - + torch.repeat_interleave(q[:-1], counts) + 1) + + local_pos = L_expand - m_expand + pos_within # [N], 1-based + end_location = start_tensor + local_pos # exclusive end - local_pos = L_expand - m_expand + pos_within # [N], 1-based - end_location = start_tensor + local_pos # exclusive end + return start_tensor.int(), end_location.int() - return start_tensor, end_location class Indexer(nn.Module): @@ -647,29 +652,26 @@ def __init__(self, self.scale_fmt = "ue8m0" self.quant_block_size = 128 # TODO: get from config - + #TODO (zyongye) change dim to fp8 later to (self.head_dim + 4) self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim * 2, dtype=torch.bfloat16, prefix=f"{prefix}.k_cache", cache_config=cache_config) - def forward(self, hidden_states: torch.Tensor, - qr: torch.Tensor, positions, rotary_emb) -> torch.Tensor: - # print(f"hidden_states: {torch.isinf(hidden_states).any()}, qr: {torch.isinf(qr).any()}") + def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, + rotary_emb) -> torch.Tensor: + # print(f"hidden_states: {torch.isinf(hidden_states).any()}, qr: {torch.isinf(qr).any()}") # print(f"weight_proj: {torch.isneginf(self.weights_proj.weight.to(torch.float32)).any()}") - q, _= self.wq_b(qr) + q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) q_pe, q_nope = torch.split( - q, [self.rope_dim, self.head_dim - self.rope_dim], - dim=-1 - ) - + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + k, _ = self.wk(hidden_states) k = self.k_norm(k) k_pe, k_nope = torch.split( - k, [self.rope_dim, self.head_dim - self.rope_dim], - dim=-1) + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) #FIXME (zyongye) this will cause OOM when using full sequence forward on 8xH200 q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) @@ -683,10 +685,11 @@ def forward(self, hidden_states: torch.Tensor, k_fp8, k_scale = act_quant(k, self.quant_block_size, self.scale_fmt) # k_cache_entry = torch.cat([k_fp8, k_scale], dim=-1) weights, _ = self.weights_proj(hidden_states) - weights = weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5 + weights = weights.unsqueeze( + -1) * self.softmax_scale * self.n_head**-0.5 # careful! 
this will be None in dummy run - attn_metadata = get_forward_context().attn_metadata + attn_metadata = get_forward_context().attn_metadata if isinstance(attn_metadata, dict): k_cache_attn_metadata = attn_metadata[self.k_cache.prefix] slot_mapping = k_cache_attn_metadata.slot_mapping @@ -695,11 +698,10 @@ def forward(self, hidden_states: torch.Tensor, seq_lens = k_cache_attn_metadata.seq_lens batch_size = seq_lens.size(0) cu_seq_lens = torch.empty((batch_size + 1), - dtype=torch.int32, - device=q.device) + dtype=torch.int32, + device=q.device) cu_seq_lens[0] = 0 cu_seq_lens[1:] = seq_lens.cumsum(dim=0).to(dtype=torch.int32) - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_lens) #TODO (zyongye) use quant type kv_cache = self.k_cache.kv_cache[0] @@ -707,22 +709,28 @@ def forward(self, hidden_states: torch.Tensor, # we need to change kv cache to only store k cache scale = torch.tensor(1, dtype=torch.float32, device=k.device) ops.concat_and_cache_mla( - k, - k.clone(), - kv_cache, + k, + k.clone(), + kv_cache, slot_mapping, - "auto", + "auto", scale, - ) + ) - flattened_kv = torch.empty([cu_seqlen_ks.size(-1), self.head_dim * 2], device=k.device, dtype=torch.bfloat16) + flattened_kv = torch.empty((cu_seq_lens[-1], self.head_dim * 2), + device=k.device, + dtype=torch.bfloat16) ops.cp_gather_cache( kv_cache, - flattened_kv, + flattened_kv, k_cache_attn_metadata.block_table, cu_seq_lens, batch_size, ) + # FIXME (zyongye) put this function before ops.cp_gather_cache would + # cause cu_seqlen_ks been changed, this need investigation later + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( + query_start_loc, seq_lens) logits, _ = ref_fp8_mqa_logits( q, flattened_kv[..., :self.head_dim], @@ -731,12 +739,19 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlen_ke, ) topk_indices = logits.topk(min(self.topk_tokens, logits.shape[-1]), - dim=-1)[1] + dim=-1)[1] mask_lo = topk_indices >= cu_seqlen_ks[:, None] mask_hi = topk_indices < cu_seqlen_ke[:, None] mask = mask_lo & mask_hi topk_indices = topk_indices.masked_fill(~mask, -1) - return topk_indices + topk_indices_buffer = torch.full( + (logits.shape[0], self.topk_tokens), + -1, + dtype=torch.int32, + device=logits.device) + topk_indices_buffer[:, :topk_indices.shape[-1]] = topk_indices.to( + dtype=torch.int32) + return topk_indices_buffer class DeepseekV2MLAAttention(nn.Module): diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b26e1d100dd8..34ad6e3b2bb7 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -139,6 +139,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing if model_config is not None and model_config.use_mla: + use_sparse = os.getenv("VLLM_MLA_SPARSE_ENABLED") == "1" # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. 
For each case, we force the @@ -185,6 +186,12 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "Forcing kv cache block size to 64 for FlashInferMLA " "backend.") + # TODO(Chen): remove this hacky code + if use_sparse and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashMLASparse " + "backend.") # lazy import to avoid circular import from vllm.config import CUDAGraphMode diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index ba9cfd1bbc32..bbdcc5a77784 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -9,6 +9,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, @@ -18,6 +19,8 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec +import triton +import triton.language as tl logger = init_logger(__name__) @@ -64,7 +67,8 @@ def get_supported_head_sizes(cls) -> list[int]: class MLASparsePrefillMetadata: # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because # the kernel is not from flashmla - def __init__(self): + def __init__(self, block_table: torch.Tensor, + req_id_per_token: torch.Tensor): pass @@ -75,12 +79,141 @@ def __init__(self): @dataclass -class FlashMLASparseMetadata(MLACommonMetadata[MLASparsePrefillMetadata]): +class FlashMLASparseMetadata: + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. 
+ query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + block_table: torch.Tensor + req_id_per_token: torch.Tensor + block_size: int = 64 + topk_tokens: int = 2048 + # For now just create topk_indices that just attend to the first topk tokens # always to enable development debug_topk_indices: Optional[torch.Tensor] = None +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = block_id < max_num_blocks_per_req + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + out_val = tl.where(is_invalid_tok | (~valid_block), -1, + base * BLOCK_SIZE + inblock_off) + + # Store results + out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 + tl.store(out_ptr_ij, out_val) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch. + Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be out-of-bounds. 
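+
+    Illustrative example (hypothetical values, for documentation only): with
+    BLOCK_SIZE=64, req_id[t] == 2 and block_table[2] == [7, 9, ...], a
+    per-request token index of 70 yields block_id 70 // 64 == 1 and in-block
+    offset 70 % 64 == 6, so the output is block_table[2][1] * 64 + 6 == 582.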
+ """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, \ + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + + num_tokens = req_id.shape[0] + num_requests, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + out = torch.empty_like(token_indices_c) + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + out_stride0, out_stride1 = out.stride() + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + out, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, + ) + return out + + @dataclass class FlashMLASparseMetadataBuilder( MLACommonMetadataBuilder[FlashMLASparseMetadata]): @@ -106,23 +239,28 @@ def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> FlashMLASparseMetadata: - logger.info("build FlashMLASparseMetadata") num_actual_tokens = common_attn_metadata.num_actual_tokens - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) - - starts = np.asarray(common_attn_metadata.query_start_loc_cpu) - pos = np.arange(starts[-1]) - np.repeat(starts[:-1], np.diff(starts)) - pos_gpu = torch.as_tensor(pos, device=self.device, dtype=torch.long) - - row = torch.arange(self.topk_tokens, - device=self.device, - dtype=torch.int64) - debug_topk_indices = row.repeat(num_actual_tokens, 1) - mask = debug_topk_indices < pos_gpu.unsqueeze(1) - debug_topk_indices = debug_topk_indices.masked_fill(~mask, -1) + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, + dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) + + # pos = np.arange(starts[-1]) - np.repeat(starts[:-1], np.diff(starts)) + # seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, + # dtype=np.int32) + # prefix_length = seq_lengths - seg_lengths + # prefix_length_per_token = np.repeat(prefix_length, seg_lengths) + # pos = pos + prefix_length_per_token + # pos_gpu = torch.as_tensor(pos, device=self.device, dtype=torch.long) + # row = torch.arange(self.topk_tokens, + # device=self.device, + # dtype=torch.int32) + # debug_topk_indices = row.repeat(num_actual_tokens, 1) + # mask = debug_topk_indices <= pos_gpu.unsqueeze(1) + # debug_topk_indices = debug_topk_indices.masked_fill(~mask, -1) + debug_topk_indices = None return FlashMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, @@ -130,12 +268,15 @@ def build(self, num_actual_tokens=common_attn_metadata.num_actual_tokens, query_start_loc=common_attn_metadata.query_start_loc, slot_mapping=common_attn_metadata.slot_mapping, - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, + 
block_table=common_attn_metadata.block_table_tensor, + req_id_per_token=torch.from_numpy(req_id_per_token).to( + device='cuda'), + # num_decodes=num_decodes, + # num_decode_tokens=num_decode_tokens, + # num_prefills=num_prefills, + block_size=self.kv_cache_spec.block_size, + topk_tokens=self.topk_tokens, debug_topk_indices=debug_topk_indices, - prefill=self._build_prefill(common_attn_metadata), - decode=self._build_decode(common_attn_metadata), ) @@ -160,8 +301,14 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - # self.sm_scale = self.topk_indices = None + # TODO(Chen): use the correct way to import. Now I just pip install + # the reference repo and use it. + try: + from sparse_topk_attn import sparse_topk_attn_fwd + except ImportError: + raise ImportError("sparse_topk_attn_fwd is not found") + self.sparse_topk_attn_fwd = sparse_topk_attn_fwd def set_topk_indices(self, topk_indices: torch.Tensor): self.topk_indices = topk_indices @@ -195,22 +342,14 @@ def forward( # same expert outputs. return output.fill_(0) - num_actual_toks = attn_metadata.num_actual_tokens + num_actual_tokens = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs output_padded = output - output = output[:num_actual_toks, ...] - q = q[:num_actual_toks, ...] - k_c_normed = k_c_normed[:num_actual_toks, ...] - k_pe = k_pe[:num_actual_toks, ...] - - assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None - - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens + output = output[:num_actual_tokens, ...] + q = q[:num_actual_tokens, ...] + k_c_normed = k_c_normed[:num_actual_tokens, ...] + k_pe = k_pe[:num_actual_tokens, ...] 
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) @@ -221,12 +360,6 @@ def forward( # Convert from (N, B, L) to (B, N, L) ql_nope = ql_nope.transpose(0, 1) - decode_ql_nope = ql_nope[:num_decode_tokens] - decode_q_pe = q_pe[:num_decode_tokens] - - prefill_ql_nope = ql_nope[num_decode_tokens:] - prefill_q_pe = q_pe[num_decode_tokens:] - # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -238,34 +371,43 @@ def forward( scale=layer._k_scale, ) - if has_prefill: - attn_out = self._forward_prefill(prefill_ql_nope, prefill_q_pe, - kv_cache, attn_metadata, - layer._k_scale) - # v_up projection - output[num_decode_tokens:] = self._v_up_proj(attn_out) - if has_decode: - # call decode attn - attn_out, lse = self._forward_decode( - (decode_ql_nope, decode_q_pe), kv_cache, attn_metadata, layer) - # v_up projection - output[:num_decode_tokens] = self._v_up_proj(attn_out) + attn_out = self._forward_bf16_kv(ql_nope, q_pe, kv_cache, + attn_metadata, self.scale) + + output[:num_actual_tokens] = self._v_up_proj(attn_out) return output_padded - def _forward_prefill(self, ql_nope: torch.Tensor, q_pe: torch.Tensor, + def _forward_bf16_kv(self, ql_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, k_scale: torch.Tensor) -> torch.Tensor: - # # assume indice of shape [num_prefill_tokens, topk] - # block_id_in_req = topk_indices // self.block_size - topk_indices = self.topk_indices[attn_metadata.num_decodes:] - logger.info("called _forward_prefill with topk_indices shape %s", - topk_indices.shape) - # NOTE(Chen): shape is unsure - - return torch.zeros((ql_nope.shape[0], ql_nope.shape[1], 512), - dtype=ql_nope.dtype, - device=ql_nope.device) + topk_indices = self.topk_indices[:attn_metadata.num_actual_tokens] + num_tokens = ql_nope.shape[0] + q = torch.cat([ql_nope, q_pe], dim=-1) + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( + -1, 1, kv_c_and_k_pe_cache.shape[-1]) + # NOTE(Chen): kernel requires num_local_head to be a multiple of 64. 
+ if self.num_heads % 64 != 0: + assert 64 % self.num_heads == 0 + logger.warning_once( + f"padding num_heads to 64 due to sparse attn kernel requirement" + ) + q_padded = q.new_empty((q.shape[0], 64, q.shape[2])) + q_padded[:, :self.num_heads, :] = q + q = q_padded + # TODO: handle index / kv_cache correctly + topk_indices_global = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=attn_metadata.topk_tokens, + ) + topk_indices_global = topk_indices_global.view(num_tokens, 1, -1) + output = self.sparse_topk_attn_fwd(q, kv_c_and_k_pe_cache, + topk_indices_global, k_scale)[0] + output = output[:, :self.num_heads, :] + return output def _forward_decode( self, @@ -283,13 +425,13 @@ def _forward_decode( logger.info("called _forward_decode with topk_indices shape %s", topk_indices.shape) - + ql_nope, q_pe = q - + attn_out = torch.zeros((ql_nope.shape[0], ql_nope.shape[1], 512), - dtype=ql_nope.dtype, - device=ql_nope.device) - lse = None #TODO - + dtype=ql_nope.dtype, + device=ql_nope.device) + lse = None #TODO + # NOTE(Chen): shape is unsure return attn_out, lse From 75d382efb1df4259798e83728f85f8512c8e7dbd Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 22 Sep 2025 23:24:47 -0700 Subject: [PATCH 21/82] fix indexer bs>1 (#23) Signed-off-by: Chen Zhang --- vllm/model_executor/models/deepseek_v2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a910fbddcc1e..20cffd278eb8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -740,8 +740,10 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, ) topk_indices = logits.topk(min(self.topk_tokens, logits.shape[-1]), dim=-1)[1] - mask_lo = topk_indices >= cu_seqlen_ks[:, None] - mask_hi = topk_indices < cu_seqlen_ke[:, None] + topk_indices -= cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices < cu_seqlen_ke[:, None] - cu_seqlen_ks[:, + None] mask = mask_lo & mask_hi topk_indices = topk_indices.masked_fill(~mask, -1) topk_indices_buffer = torch.full( From 9905f9d910fe1c81f49dfe181f3690a1eb4eb293 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 23 Sep 2025 15:35:17 +0000 Subject: [PATCH 22/82] fix build Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index abde4118bcc0..b8a0b0394771 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -139,7 +139,7 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) ARCHITECTURES ${VLLM_GPU_ARCHES} # Only the common/public includes here; do NOT add decode/prefill folders INCLUDE_DIRECTORIES - csrc/ + ${flashmla_SOURCE_DIR}/csrc/ ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 From 23e809c6d23a3499ea6bd37adacd494d9477ac96 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 23 Sep 2025 15:54:51 -0700 Subject: [PATCH 23/82] fix import (#24) Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 
bbdcc5a77784..6495458c265e 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -8,6 +8,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata +from vllm.attention.ops.flashmla import flash_mla_sparse_prefill from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.logger import init_logger @@ -302,13 +303,6 @@ def __init__( logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) self.topk_indices = None - # TODO(Chen): use the correct way to import. Now I just pip install - # the reference repo and use it. - try: - from sparse_topk_attn import sparse_topk_attn_fwd - except ImportError: - raise ImportError("sparse_topk_attn_fwd is not found") - self.sparse_topk_attn_fwd = sparse_topk_attn_fwd def set_topk_indices(self, topk_indices: torch.Tensor): self.topk_indices = topk_indices @@ -404,8 +398,8 @@ def _forward_bf16_kv(self, ql_nope: torch.Tensor, q_pe: torch.Tensor, NUM_TOPK_TOKENS=attn_metadata.topk_tokens, ) topk_indices_global = topk_indices_global.view(num_tokens, 1, -1) - output = self.sparse_topk_attn_fwd(q, kv_c_and_k_pe_cache, - topk_indices_global, k_scale)[0] + output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, + topk_indices_global, k_scale)[0] output = output[:, :self.num_heads, :] return output From e19d0c9941cfac0a3bc05ea828b3b2bfd6eb819c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 23 Sep 2025 23:13:02 +0000 Subject: [PATCH 24/82] enable sparse by default Signed-off-by: Chen Zhang --- vllm/model_executor/layers/mla.py | 2 +- vllm/platforms/cuda.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 89d62a1b2f23..f3134f6e8caa 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -82,7 +82,7 @@ def __init__( self.indexer = mla_modules.indexer self.topk_tokens = mla_modules.indexer.topk_tokens self.use_sparse = mla_modules.is_sparse and os.getenv( - "VLLM_MLA_SPARSE_ENABLED") == "1" + "VLLM_MLA_SPARSE_DISABLED") != "1" # In the MLA backend, kv_cache includes both k_c and # pe (i.e. decoupled position embeddings). In particular, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 34ad6e3b2bb7..1575b78d1a06 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -139,7 +139,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing if model_config is not None and model_config.use_mla: - use_sparse = os.getenv("VLLM_MLA_SPARSE_ENABLED") == "1" + use_sparse = os.getenv("VLLM_MLA_SPARSE_DISABLED") != "1" # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. 
For each case, we force the @@ -244,7 +244,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, has_sink, use_sparse) -> str: if use_mla: use_sparse = os.getenv( - "VLLM_MLA_SPARSE_ENABLED") == "1" and use_sparse + "VLLM_MLA_SPARSE_DISABLED") != "1" and use_sparse # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here From bff59447c059f8204fed914375650b7abafc9b63 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 24 Sep 2025 13:59:32 +0000 Subject: [PATCH 25/82] fix mla Signed-off-by: NickLucche --- vllm/model_executor/layers/mla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 89d62a1b2f23..00192c2628c0 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -80,7 +80,8 @@ def __init__( self.rotary_emb = mla_modules.rotary_emb self.o_proj = mla_modules.o_proj self.indexer = mla_modules.indexer - self.topk_tokens = mla_modules.indexer.topk_tokens + self.topk_tokens = mla_modules.indexer.topk_tokens \ + if self.indexer else None self.use_sparse = mla_modules.is_sparse and os.getenv( "VLLM_MLA_SPARSE_ENABLED") == "1" From b3a44bd30aac2edc0c7dfa0603acc8fba4bc0b73 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 25 Sep 2025 00:28:17 +0000 Subject: [PATCH 26/82] fix unify kv cache spec Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 47a41322c423..55cc7ea5a265 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): kv_cache_spec: The kv cache spec of each attention layer in the model """ - if is_kv_cache_spec_uniform(kv_cache_spec): + if is_kv_cache_spec_uniform( + kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type( + kv_cache_spec): return logger.warning( @@ -1141,7 +1143,8 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): attention_chunk_size=spec.attention_chunk_size, ) - if not is_kv_cache_spec_uniform(kv_cache_spec): + if not (is_kv_cache_spec_uniform(kv_cache_spec) + or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)): raise ValueError("Hybrid KV cache manager is disabled but failed to " "convert the KV cache specs to one unified type.") From c81e3f7a753d9a87b767f4007ec02467a6cb2e0b Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Thu, 25 Sep 2025 14:43:25 +0000 Subject: [PATCH 27/82] Fix paged_mqa_logits clear True --- vllm/utils/deep_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 2f34d93e49fc..c264a814bebb 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -234,7 +234,7 @@ def fp8_paged_mqa_logits( block_tables, schedule_metadata, max_model_len, - clean_logits=False + clean_logits=True ) From 87104b57afb3e1bef0c9948305bc2ff0252787d0 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Thu, 25 Sep 2025 15:39:03 -0400 Subject: [PATCH 28/82] Separate indexer prefill and decode and use different kernel (#26) * indexer medatata to separate prefill and decode * deep_gemm prefill kernel * decode kernel, can run for single batch * bug fixing insert decode k into kv before gemm * don't use tilelang quant function * faster non-looping torch for kv cache insertion * add chunked prefill impl * change 
quant kernel back to tilelang for promotion * fix format (#31) Signed-off-by: Chen Zhang * update unit tests * Fp8 indexer prefill (#33) * init Signed-off-by: Chen Zhang * can run --------- Signed-off-by: Chen Zhang * remove debug comment Signed-off-by: Chen Zhang * cleanup * further cleanup --------- Signed-off-by: Chen Zhang Co-authored-by: mgoin Co-authored-by: Chen Zhang --- tests/kernels/attention/test_indexer.py | 120 ++------- vllm/model_executor/models/deepseek_v2.py | 284 +++++++++++----------- vllm/v1/attention/backends/mla/indexer.py | 216 ++++++++++++++-- 3 files changed, 362 insertions(+), 258 deletions(-) diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py index f0ea88a71f53..a7b3e0c93572 100644 --- a/tests/kernels/attention/test_indexer.py +++ b/tests/kernels/attention/test_indexer.py @@ -4,7 +4,13 @@ from vllm.utils.tile_lang_kernels import act_quant, fp8_index from vllm import _custom_ops as ops - +from vllm.model_executor.models.deepseek_v2 import kv_spans_from_batches, indexer_k_quant_and_cache +from vllm.utils.deep_gemm import ( + fp8_mqa_logits, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, +) +from vllm.utils.tile_lang_kernels import act_quant def ref_compute_logits_fp8(q, kv, weights, mask, block_size): q_fp8, q_scale = act_quant(q, block_size, "ue8m0") @@ -38,86 +44,7 @@ def ref_indexer(seq_len, q, kv, weights, block_size, topk): # dim=-1)[1] return varlen_logits -def kv_spans_from_batches(start_seq_loc: torch.Tensor, - seq_len_per_batch: torch.Tensor): - """ - Args: - start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. - Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. - seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. - Example: [5, 9, 4]. - - Returns: - start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. - end_location: 1D long tensor [N], **exclusive** end = start + token's local position. - (So the attended KV slice is kv[start:end].) - - Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the KV cache, and - the selected tokens within a batch are the **last** `counts[i]` positions of that sequence. - """ - q = start_seq_loc.to(dtype=torch.long) - L = seq_len_per_batch.to(dtype=torch.long, device=q.device) - assert q.dim() == 1 and L.dim() == 1 - assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" - - # Selected tokens per batch and totals - counts = q[1:] - q[:-1] # [B] - N = int(q[-1].item()) # total selected tokens - B = L.numel() - device = L.device - - if N == 0: - return (torch.empty(0, dtype=torch.long, device=device), - torch.empty(0, dtype=torch.long, device=device)) - - # KV start offsets per batch in the concatenated KV cache - kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] - - # For each selected token, which batch does it belong to? 
- batch_id = torch.repeat_interleave(torch.arange(B, device=device), counts) # [N] - - # Map batch KV start to each token - start_tensor = kv_starts_per_batch[batch_id] # [N] - - # End-align local positions inside each batch: - # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b - L_expand = torch.repeat_interleave(L, counts) # [N] - m_expand = torch.repeat_interleave(counts, counts) # [N] - # position within the selected block: 1..counts[b] - pos_within = (torch.arange(N, device=device, dtype=torch.long) - - torch.repeat_interleave(q[:-1], counts) + 1) - - local_pos = L_expand - m_expand + pos_within # [N], 1-based - end_location = start_tensor + local_pos # exclusive end - - return start_tensor, end_location - -def ref_fp8_mqa_logits( - q: torch.Tensor, - kv: torch.Tensor, - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, -): - k = kv - q = q.float() - k = k.float() - - seq_len_kv = kv.shape[0] - mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - >= cu_seqlen_ks[:, None]) - mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - < cu_seqlen_ke[:, None]) - mask = mask_lo & mask_hi - - score = torch.einsum("mhd,nd->hmn", q, k) - logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) - - cost = mask.sum() - return logits, cost - -def torch_indexer(seq_len, q, kv, weights, block_size, topk): +def deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk): NUM_BLOCKS = 8 BLOCK_SIZE = 32 @@ -144,9 +71,11 @@ def torch_indexer(seq_len, q, kv, weights, block_size, topk): q = torch.cat(concat_q, dim=0) kv = torch.cat(concat_kv, dim=0) weights = torch.cat(concat_weights, dim=0) + q_fp8, q_scale = act_quant(q, 128, "ue8m0") + kv_fp8, kv_scale = act_quant(kv, 128, "ue8m0") # write to kv cache based on slot mapping - entry_size = head_dim * 2 + entry_size = head_dim + 4 num_tokens = q.size(0) slot_mapping_lst = random.sample(range(total_slots), num_tokens) slot_mapping = torch.tensor(slot_mapping_lst, @@ -156,18 +85,9 @@ def torch_indexer(seq_len, q, kv, weights, block_size, topk): NUM_BLOCKS, BLOCK_SIZE, entry_size, - dtype=torch.bfloat16, + dtype=torch.uint8, device="cuda" ) - scale = torch.tensor(1, dtype=torch.float32, device="cuda") - ops.concat_and_cache_mla( - kv, - kv.clone(), - kv_cache, - slot_mapping, - "auto", - scale - ) current_index = 0 for i in range(B): @@ -175,19 +95,17 @@ def torch_indexer(seq_len, q, kv, weights, block_size, topk): block_table[i][:S] = slot_mapping[current_index: current_index + S] current_index += S - weights = weights * (128**(-0.5)) + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale + weights = weights.squeeze(-1) query_start_loc = torch.empty((B + 1), device="cuda") query_start_loc[0] = 0 query_start_loc[1:] = seq_len.cumsum(dim=0).to(dtype=torch.int32) - kv_gathered = kv_cache.view(-1, entry_size)[slot_mapping][..., :head_dim] - torch.testing.assert_close(kv, kv_gathered) - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_len) - logits, _ = ref_fp8_mqa_logits( - q, - kv_gathered, + logits = fp8_mqa_logits( + q_fp8, + (kv_fp8, kv_scale), weights, cu_seqlen_ks, cu_seqlen_ke @@ -200,7 +118,7 @@ def torch_indexer(seq_len, q, kv, weights, block_size, topk): return logits def test_paged_indexer_python(): - B = 2 + B = S = 128 SKV = S H = 64 @@ -218,7 +136,7 @@ def test_paged_indexer_python(): weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5 ref_indices = 
ref_indexer(seq_len, q, kv, weights, block_size, topk) - torch_indices = torch_indexer(seq_len, q, kv, weights, block_size, topk) + torch_indices = deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk) import pdb; pdb.set_trace() print(ref_indices) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 20cffd278eb8..27403d5d3813 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -34,6 +34,7 @@ import torch.distributed as dist from vllm.attention.backends.abstract import AttentionBackend +from vllm.distributed.parallel_state import get_tp_group from vllm.logger import init_logger from vllm.config.compilation import CompilationConfig import vllm.envs as envs @@ -66,7 +67,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import cdiv, direct_register_custom_op -from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, extract_layer_index, @@ -74,6 +75,12 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec +from vllm.utils.deep_gemm import ( + fp8_mqa_logits, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops @@ -348,6 +355,7 @@ class DeepseekV2Attention(nn.Module): def __init__( self, + vllm_config: VllmConfig, config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, @@ -532,94 +540,35 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor: # make sure the hidden_size is expontial of 2 return hadamard_transform(x, scale=hidden_size**-0.5) - -def ref_fp8_mqa_logits( - q: torch.Tensor, - kv: torch.Tensor, - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, +@torch.inference_mode() +def indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, ): - # print(f"q_shape: {q.shape}, v_shape: {kv.shape}, weights.shape: {weights.shape}") - k = kv - q = q.float() - k = k.float() - - seq_len_kv = kv.shape[0] - mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - >= cu_seqlen_ks[:, None]) - mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - < cu_seqlen_ke[:, None]) - mask = mask_lo & mask_hi - - score = torch.einsum("mhd,nd->hmn", q, k) - logits = (score.relu() * weights.transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) - - cost = mask.sum() - return logits, cost - - -# TODO (zyongye) optimize this, this is now vibe coded -def kv_spans_from_batches(start_seq_loc: torch.Tensor, - seq_len_per_batch: torch.Tensor): - """ - Args: - start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. - Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. - seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. - Example: [5, 9, 4]. - - Returns: - start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. - end_location: 1D long tensor [N], **exclusive** end = start + token's local position. 
- (So the attended KV slice is kv[start:end].) - - Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the KV cache, and - the selected tokens within a batch are the **last** `counts[i]` positions of that sequence. - """ - q = start_seq_loc.to(dtype=torch.long) - L = seq_len_per_batch.to(dtype=torch.long, device=q.device) - assert q.dim() == 1 and L.dim() == 1 - assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" - - # Selected tokens per batch and totals - counts = q[1:] - q[:-1] # [B] - N = int(q[-1].item()) # total selected tokens - B = L.numel() - device = L.device - - if N == 0: - return (torch.empty(0, dtype=torch.long, device=device), - torch.empty(0, dtype=torch.long, device=device)) - - # KV start offsets per batch in the concatenated KV cache - kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] + from vllm.utils.tile_lang_kernels import act_quant + _, block_size, head_dim = kv_cache.shape + k_fp8, k_scale = act_quant(k, quant_block_size, scale_fmt) + k_bytes = k_fp8.view(torch.uint8) + s_bytes = k_scale.view(torch.uint8) - # For each selected token, which batch does it belong to? - batch_id = torch.repeat_interleave(torch.arange(B, device=device), - counts) # [N] + packed = torch.cat([k_bytes, s_bytes], dim=-1) - # Map batch KV start to each token - start_tensor = kv_starts_per_batch[batch_id] # [N] + block_idx = torch.div(slot_mapping, block_size, rounding_mode='floor') + inblock = slot_mapping - block_idx * block_size + linear = block_idx * block_size + inblock - # End-align local positions inside each batch: - # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b - L_expand = torch.repeat_interleave(L, counts) # [N] - m_expand = torch.repeat_interleave(counts, counts) # [N] - # position within the selected block: 1..counts[b] - pos_within = (torch.arange(N, device=device, dtype=torch.long) - - torch.repeat_interleave(q[:-1], counts) + 1) + kv_cache_flat = kv_cache.view(-1, head_dim) - local_pos = L_expand - m_expand + pos_within # [N], 1-based - end_location = start_tensor + local_pos # exclusive end - - return start_tensor.int(), end_location.int() + kv_cache_flat.index_copy_(0, linear, packed) class Indexer(nn.Module): def __init__(self, + vllm_config: VllmConfig, config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, q_lora_rank: int, @@ -627,6 +576,7 @@ def __init__(self, cache_config: Optional[CacheConfig], prefix: str = ""): super().__init__() + self.vllm_config = vllm_config self.config = config self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] self.topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] @@ -654,10 +604,12 @@ def __init__(self, self.quant_block_size = 128 # TODO: get from config #TODO (zyongye) change dim to fp8 later to (self.head_dim + 4) - self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim * 2, - dtype=torch.bfloat16, + self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim + 4, + dtype=torch.uint8, prefix=f"{prefix}.k_cache", cache_config=cache_config) + self.max_model_len = vllm_config.model_config.max_model_len + self.prefix = prefix def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb) -> torch.Tensor: @@ -680,80 +632,130 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, q = rotate_activation(q) k = rotate_activation(k) + # we only quant q here since k quant is fused with cache insertion from vllm.utils.tile_lang_kernels import act_quant q_fp8, q_scale = act_quant(q, 
self.quant_block_size, self.scale_fmt) - k_fp8, k_scale = act_quant(k, self.quant_block_size, self.scale_fmt) - # k_cache_entry = torch.cat([k_fp8, k_scale], dim=-1) + weights, _ = self.weights_proj(hidden_states) weights = weights.unsqueeze( - -1) * self.softmax_scale * self.n_head**-0.5 + -1) * q_scale * self.softmax_scale * self.n_head**-0.5 + weights = weights.squeeze(-1) # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata if isinstance(attn_metadata, dict): - k_cache_attn_metadata = attn_metadata[self.k_cache.prefix] - slot_mapping = k_cache_attn_metadata.slot_mapping - - query_start_loc = k_cache_attn_metadata.query_start_loc - seq_lens = k_cache_attn_metadata.seq_lens - batch_size = seq_lens.size(0) - cu_seq_lens = torch.empty((batch_size + 1), - dtype=torch.int32, - device=q.device) - cu_seq_lens[0] = 0 - cu_seq_lens[1:] = seq_lens.cumsum(dim=0).to(dtype=torch.int32) - #TODO (zyongye) use quant type - kv_cache = self.k_cache.kv_cache[0] + attn_metadata = attn_metadata[self.k_cache.prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens - # FIXME (zyongye) right now k_pe cache is a dummy tensor, - # we need to change kv cache to only store k cache - scale = torch.tensor(1, dtype=torch.float32, device=k.device) - ops.concat_and_cache_mla( + kv_cache = self.k_cache.kv_cache[0] + indexer_k_quant_and_cache( k, - k.clone(), kv_cache, slot_mapping, - "auto", - scale, + self.quant_block_size, + self.scale_fmt, ) - flattened_kv = torch.empty((cu_seq_lens[-1], self.head_dim * 2), - device=k.device, - dtype=torch.bfloat16) - ops.cp_gather_cache( - kv_cache, - flattened_kv, - k_cache_attn_metadata.block_table, - cu_seq_lens, - batch_size, - ) - # FIXME (zyongye) put this function before ops.cp_gather_cache would - # cause cu_seqlen_ks been changed, this need investigation later - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( - query_start_loc, seq_lens) - logits, _ = ref_fp8_mqa_logits( - q, - flattened_kv[..., :self.head_dim], - weights, - cu_seqlen_ks, - cu_seqlen_ke, - ) - topk_indices = logits.topk(min(self.topk_tokens, logits.shape[-1]), - dim=-1)[1] - topk_indices -= cu_seqlen_ks[:, None] - mask_lo = topk_indices >= 0 - mask_hi = topk_indices < cu_seqlen_ke[:, None] - cu_seqlen_ks[:, - None] - mask = mask_lo & mask_hi - topk_indices = topk_indices.masked_fill(~mask, -1) topk_indices_buffer = torch.full( - (logits.shape[0], self.topk_tokens), + (hidden_states.shape[0], self.topk_tokens), + -1, + dtype=torch.int32, + device=hidden_states.device) + if has_prefill: + prefill_metadata = attn_metadata.prefill + num_prefills = attn_metadata.num_prefills + flattened_kv = torch.empty( + [prefill_metadata.total_seq_lens, self.head_dim + 4], + device=k.device, + dtype=torch.uint8) + ops.cp_gather_cache( + kv_cache, + flattened_kv, + prefill_metadata.block_table, + prefill_metadata.cu_seq_lens, + num_prefills, + ) + # TODO: the memory footprint here can be optimized + k_fp8 = flattened_kv[..., :self.head_dim].view( + torch.float8_e4m3fn).contiguous() + k_scale = flattened_kv[..., self.head_dim:].view( + torch.float32).contiguous() + cu_seqlen_ks = prefill_metadata.cu_seqlen_ks + cu_seqlen_ke = prefill_metadata.cu_seqlen_ke + logits = fp8_mqa_logits( + q_fp8[num_decode_tokens:], + (k_fp8, k_scale), + weights[num_decode_tokens:], + cu_seqlen_ks, + 
cu_seqlen_ke, + ) + topk_indices = logits.topk(min(self.topk_tokens, + logits.shape[-1]), + dim=-1)[1] + topk_indices -= cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices < cu_seqlen_ke[:, + None] - cu_seqlen_ks[:, + None] + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + topk_indices_buffer[num_decode_tokens:, :topk_indices. + shape[-1]] = topk_indices.to( + dtype=torch.int32) + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + logits = fp8_paged_mqa_logits( + q_fp8[:num_decode_tokens].unsqueeze(1), + kv_cache, + weights[:num_decode_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=self.max_model_len, + ) + positions = torch.arange(self.max_model_len, + device="cuda").unsqueeze( + 0) # [1, max_model_len] + next_n_offset = torch.arange(num_decode_tokens, device="cuda") + # NOTE(Chen): not true for spec decode + # [1, max_model_len] < [num_decode_tokens, 1] -> [num_decode_tokens, max_model_len] + mask = positions <= (decode_metadata.seq_lens - 1 + + next_n_offset).unsqueeze(1) + logits = logits.masked_fill(~mask, float("-inf")) + topk_indices = logits.topk( + min(self.topk_tokens, logits.shape[-1]), dim=-1)[1].to( + torch.int32) # [num_decode_tokens, topk_tokens] + topk_indices[topk_indices >= + decode_metadata.seq_lens[:, None]] = -1 + topk_indices_buffer[:num_decode_tokens, :topk_indices. + shape[-1]] = topk_indices.to( + dtype=torch.int32) + else: + # profile run + from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size + total_seq_lens = get_max_prefill_buffer_size(self.vllm_config) + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. 
+ _flattened_kv = torch.empty([total_seq_lens, self.head_dim + 4], + device=k.device, + dtype=torch.uint8) + _k_fp8 = _flattened_kv[..., :self.head_dim].view( + torch.float8_e4m3fn).contiguous() + _k_scale = _flattened_kv[..., self.head_dim:].view( + torch.float32).contiguous() + topk_indices_buffer = torch.full( + (hidden_states.shape[0], self.topk_tokens), -1, dtype=torch.int32, - device=logits.device) - topk_indices_buffer[:, :topk_indices.shape[-1]] = topk_indices.to( - dtype=torch.int32) - return topk_indices_buffer + device=hidden_states.device) + return topk_indices_buffer class DeepseekV2MLAAttention(nn.Module): @@ -766,6 +768,7 @@ class DeepseekV2MLAAttention(nn.Module): def __init__( self, + vllm_config: VllmConfig, config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, @@ -865,8 +868,8 @@ def __init__( ) and "attn_index" in config.attn_module_list_cfg[0] if self.is_v32: - self.indexer = Indexer(config, hidden_size, q_lora_rank, - quant_config, cache_config, + self.indexer = Indexer(vllm_config, config, hidden_size, + q_lora_rank, quant_config, cache_config, f"{prefix}.indexer") else: self.indexer = None @@ -937,6 +940,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( + vllm_config=vllm_config, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index c2e478e382b8..daf9b507f215 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,15 +1,40 @@ from dataclasses import dataclass -from vllm.attention.backends.abstract import AttentionBackend +from typing import Optional + +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata +from vllm.config import VllmConfig +from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + split_decodes_and_prefills) import torch +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.logger import init_logger + +logger = init_logger(__name__) class DeepseekV32IndexerBackend(AttentionBackend): @staticmethod - def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, - head_size: int) -> tuple[int, ...]: + def get_metadata_cls() -> type["AttentionMetadata"]: + return DeepseekV32IndexerMetadata + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 128] + + @staticmethod + def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]: + return DeepseekV32IndexerMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: assert num_kv_heads == 1 return (num_blocks, block_size, head_size) @@ -17,41 +42,198 @@ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, def get_kv_cache_stride_order() -> tuple[int, ...]: return (0, 1, 2) - @staticmethod - def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]: - return DeepseekV32IndexerMetadataBuilder + +@dataclass +class DeepseekV32IndexerPrefillMetadata: + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + cu_seqlen_ks: torch.Tensor + cu_seqlen_ke: torch.Tensor + cu_seq_lens: torch.Tensor + total_seq_lens: int + + 
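The cu_seqlen_ks / cu_seqlen_ke fields of the prefill metadata above are produced by the kv_spans_from_batches helper that this patch moves into indexer.py (added a few hunks below). As a minimal sketch outside the patch itself, assuming the import path introduced by this series (the same one the later unit tests use), the docstring's own example can be checked directly; the expected spans follow from the end-aligned, causal construction:

    import torch
    from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches

    # Three requests with 2, 2 and 3 selected query tokens and
    # full KV lengths 5, 9 and 4.
    query_start_loc = torch.tensor([0, 2, 4, 7])
    seq_lens = torch.tensor([5, 9, 4])
    ks, ke = kv_spans_from_batches(query_start_loc, seq_lens)
    # Query token i attends to the flattened KV slice kv[ks[i]:ke[i]]:
    #   ks = [0, 0, 5, 5, 14, 14, 14]
    #   ke = [4, 5, 13, 14, 16, 17, 18]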
+@dataclass +class DeepSeekV32IndexerDecodeMetadata: + block_table: torch.Tensor + seq_lens: torch.Tensor + schedule_metadata: torch.Tensor @dataclass class DeepseekV32IndexerMetadata: - - query_start_loc: torch.Tensor + + #FIXME (zyongye) hacky way to access the data now, need to be in chunked meta seq_lens: torch.Tensor num_reqs: int - num_actual_tokens: int # Number of tokens excluding padding. max_query_len: int max_seq_len: int - - block_table: torch.Tensor # [num_req, (max_req_len + block_size - 1) // block_size] + + num_actual_tokens: int # Number of tokens excluding padding. + query_start_loc: torch.Tensor slot_mapping: torch.Tensor + # The dimension of the attention heads + head_dim: int + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + + decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None + prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None + + +# TODO (zyongye) optimize this, this is now vibe coded +def kv_spans_from_batches(start_seq_loc: torch.Tensor, + seq_len_per_batch: torch.Tensor): + """ + Args: + start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. + Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. + seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. + Example: [5, 9, 4]. + + Returns: + start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. + end_location: 1D long tensor [N], **exclusive** end = start + token's local position. + (So the attended KV slice is kv[start:end].) + + Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the KV cache, and + the selected tokens within a batch are the **last** `counts[i]` positions of that sequence. + """ + q = start_seq_loc.to(dtype=torch.long) + L = seq_len_per_batch.to(dtype=torch.long, device=q.device) + assert q.dim() == 1 and L.dim() == 1 + assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" + + # Selected tokens per batch and totals + counts = q[1:] - q[:-1] # [B] + N = int(q[-1].item()) # total selected tokens + B = L.numel() + device = L.device + + if N == 0: + return (torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.long, device=device)) + + # KV start offsets per batch in the concatenated KV cache + kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] + + # For each selected token, which batch does it belong to? 
+ batch_id = torch.repeat_interleave(torch.arange(B, device=device), + counts) # [N] + + # Map batch KV start to each token + start_tensor = kv_starts_per_batch[batch_id] # [N] + + # End-align local positions inside each batch: + # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b + L_expand = torch.repeat_interleave(L, counts) # [N] + m_expand = torch.repeat_interleave(counts, counts) # [N] + # position within the selected block: 1..counts[b] + pos_within = (torch.arange(N, device=device, dtype=torch.long) - + torch.repeat_interleave(q[:-1], counts) + 1) + + local_pos = L_expand - m_expand + pos_within # [N], 1-based + end_location = start_tensor + local_pos # exclusive end + + return start_tensor.int(), end_location.int() + + +def get_max_prefill_buffer_size(vllm_config: VllmConfig): + max_model_len = vllm_config.model_config.max_model_len + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. + return max_model_len + max_num_batched_tokens class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + max_model_len = self.vllm_config.model_config.max_model_len + max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens + # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. + self.max_prefill_buffer_size = get_max_prefill_buffer_size( + self.vllm_config) def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> DeepseekV32IndexerMetadata: - return DeepseekV32IndexerMetadata( - query_start_loc = common_attn_metadata.query_start_loc, + + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + + device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + + query_start_loc = common_attn_metadata.query_start_loc + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + + prefill_metadata = None + if num_prefills > 0: + reqs_start = num_decodes + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( + prefill_query_start_loc, + common_attn_metadata.seq_lens[reqs_start:]) + total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum() + assert total_seq_lens < self.max_prefill_buffer_size + cu_seq_lens = torch.cat([ + torch.zeros(1, dtype=torch.int32, device=device), + common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0) + ]).to(torch.int32).cuda() + prefill_metadata = DeepseekV32IndexerPrefillMetadata( + block_table=block_table_tensor[reqs_start:, ...], + query_start_loc=prefill_query_start_loc, + max_query_len=common_attn_metadata.max_query_len, + cu_seqlen_ks=cu_seqlen_ks, + cu_seqlen_ke=cu_seqlen_ke, + cu_seq_lens=cu_seq_lens, + total_seq_lens=total_seq_lens, + ) + + decode_metadata = None + if num_decodes > 0: + seq_lens = common_attn_metadata.seq_lens[:num_decodes] + schedule_metadata = get_paged_mqa_logits_metadata( + seq_lens, self.kv_cache_spec.block_size, 132) + decode_metadata = DeepSeekV32IndexerDecodeMetadata( + block_table=common_attn_metadata. 
+ block_table_tensor[:num_decodes, ...], + seq_lens=common_attn_metadata.seq_lens[:num_decodes], + schedule_metadata=schedule_metadata, + ) + + attn_metadata = DeepseekV32IndexerMetadata( seq_lens=common_attn_metadata.seq_lens, num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=common_attn_metadata.num_actual_tokens, max_query_len=common_attn_metadata.max_query_len, max_seq_len=common_attn_metadata.max_seq_len, - block_table=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping) + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + head_dim=128, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + # if get_tensor_model_parallel_rank() == 0: + # logger.info(f"attn_metadata: {attn_metadata}") + return attn_metadata From e2dcd85b40fab13dca5c62cdf0933967d7aabcdd Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Fri, 26 Sep 2025 02:01:26 +0000 Subject: [PATCH 29/82] update prefill indexer unittest --- tests/kernels/attention/test_indexer.py | 61 +++++++------------------ 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py index a7b3e0c93572..10f0b92e60e4 100644 --- a/tests/kernels/attention/test_indexer.py +++ b/tests/kernels/attention/test_indexer.py @@ -2,9 +2,10 @@ import torch +from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches from vllm.utils.tile_lang_kernels import act_quant, fp8_index from vllm import _custom_ops as ops -from vllm.model_executor.models.deepseek_v2 import kv_spans_from_batches, indexer_k_quant_and_cache +from vllm.model_executor.models.deepseek_v2 import indexer_k_quant_and_cache from vllm.utils.deep_gemm import ( fp8_mqa_logits, get_paged_mqa_logits_metadata, @@ -28,8 +29,10 @@ def ref_compute_logits_fp8(q, kv, weights, mask, block_size): def ref_indexer(seq_len, q, kv, weights, block_size, topk): B = seq_len.shape[0] - varlen_logits = [] + total_seqlen = torch.sum(seq_len) + varlen_logits = torch.full((total_seqlen, total_seqlen), float("-inf"), device="cuda") + current_context_ptr = 0 for i in range(B): S = seq_len[i] q_s = q[i][:S].contiguous().unsqueeze(0) @@ -39,25 +42,17 @@ def ref_indexer(seq_len, q, kv, weights, block_size, topk): (S, S), float("-inf"), device="cuda").triu_(1) logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask, block_size) - varlen_logits.append(logits) - # topk_indices = index_score.topk(topk, - # dim=-1)[1] + logits = logits.squeeze(0) + + varlen_logits[current_context_ptr:current_context_ptr + S, current_context_ptr: current_context_ptr + S] = logits + current_context_ptr += S return varlen_logits def deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk): - NUM_BLOCKS = 8 - BLOCK_SIZE = 32 - B = seq_len.shape[0] concat_q = [] concat_kv = [] concat_weights = [] - total_slots = NUM_BLOCKS * BLOCK_SIZE - head_dim = kv.shape[-1] - max_num_block_per_batch = torch.max(seq_len) - block_table = torch.empty((B, max_num_block_per_batch), - dtype=torch.int32, - device="cuda") for i in range(B): S = seq_len[i] @@ -71,29 +66,8 @@ def deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk): q = torch.cat(concat_q, dim=0) kv = torch.cat(concat_kv, dim=0) weights = torch.cat(concat_weights, dim=0) - q_fp8, q_scale = 
act_quant(q, 128, "ue8m0") - kv_fp8, kv_scale = act_quant(kv, 128, "ue8m0") - - # write to kv cache based on slot mapping - entry_size = head_dim + 4 - num_tokens = q.size(0) - slot_mapping_lst = random.sample(range(total_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device="cuda") - kv_cache = torch.zeros( - NUM_BLOCKS, - BLOCK_SIZE, - entry_size, - dtype=torch.uint8, - device="cuda" - ) - - current_index = 0 - for i in range(B): - S = seq_len[i] - block_table[i][:S] = slot_mapping[current_index: current_index + S] - current_index += S + q_fp8, q_scale = act_quant(q, block_size, "ue8m0") + kv_fp8, kv_scale = act_quant(kv, block_size, "ue8m0") weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale weights = weights.squeeze(-1) @@ -117,8 +91,8 @@ def deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk): topk_indices = topk_indices.masked_fill(~mask, -1) return logits -def test_paged_indexer_python(): - B = +def test_prefill_indexer(): + B = 3 S = 128 SKV = S H = 64 @@ -135,11 +109,10 @@ def test_paged_indexer_python(): dtype=torch.bfloat16) weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5 - ref_indices = ref_indexer(seq_len, q, kv, weights, block_size, topk) - torch_indices = deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk) - import pdb; pdb.set_trace() - print(ref_indices) + ref_logits = ref_indexer(seq_len, q, kv, weights, block_size, topk) + deepgemm_logits = deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk) + torch.testing.assert_close(ref_logits, deepgemm_logits) if __name__ == "__main__": - test_paged_indexer_python() + test_prefill_indexer() \ No newline at end of file From d7f80ed2b9fe9efde991b533e32ef211a81c4507 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Fri, 26 Sep 2025 04:02:53 +0000 Subject: [PATCH 30/82] paged_indexer_unit test --- tests/kernels/attention/test_indexer.py | 129 +++++++++++++++++++++--- 1 file changed, 117 insertions(+), 12 deletions(-) diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py index 10f0b92e60e4..1121892f5172 100644 --- a/tests/kernels/attention/test_indexer.py +++ b/tests/kernels/attention/test_indexer.py @@ -2,16 +2,33 @@ import torch +from vllm.utils import cdiv from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches from vllm.utils.tile_lang_kernels import act_quant, fp8_index from vllm import _custom_ops as ops from vllm.model_executor.models.deepseek_v2 import indexer_k_quant_and_cache from vllm.utils.deep_gemm import ( fp8_mqa_logits, + calc_diff, get_paged_mqa_logits_metadata, fp8_paged_mqa_logits, ) -from vllm.utils.tile_lang_kernels import act_quant + +def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), + device=x.device, + dtype=torch.uint8) + x_fp8[:, :block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim).view(dtype=torch.uint8) + x_fp8[:, + block_size * head_dim:] = sf.view(num_blocks, + block_size).view(dtype=torch.uint8) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) def ref_compute_logits_fp8(q, kv, weights, mask, block_size): q_fp8, q_scale = act_quant(q, block_size, "ue8m0") @@ -48,24 +65,27 @@ def ref_indexer(seq_len, q, kv, 
weights, block_size, topk): current_context_ptr += S return varlen_logits -def deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk): +def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, topk, is_kv_batched=True): B = seq_len.shape[0] concat_q = [] concat_kv = [] concat_weights = [] - + for i in range(B): S = seq_len[i] q_s = q[i][:S].contiguous() - kv_s = kv[i][:S].contiguous() - weight_s = weights[i][:S].contiguous() + if is_kv_batched: + kv_s = kv[i][:S].contiguous() + weight_s = weights[i][:S].contiguous() concat_q.append(q_s) - concat_kv.append(kv_s) - concat_weights.append(weight_s) + if is_kv_batched: + concat_kv.append(kv_s) + concat_weights.append(weight_s) q = torch.cat(concat_q, dim=0) - kv = torch.cat(concat_kv, dim=0) - weights = torch.cat(concat_weights, dim=0) + if is_kv_batched: + kv = torch.cat(concat_kv, dim=0) + weights = torch.cat(concat_weights, dim=0) q_fp8, q_scale = act_quant(q, block_size, "ue8m0") kv_fp8, kv_scale = act_quant(kv, block_size, "ue8m0") @@ -73,7 +93,7 @@ def deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk): weights = weights.squeeze(-1) query_start_loc = torch.empty((B + 1), device="cuda") query_start_loc[0] = 0 - query_start_loc[1:] = seq_len.cumsum(dim=0).to(dtype=torch.int32) + query_start_loc[1:] = query_seq_len.cumsum(dim=0).to(dtype=torch.int32) cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_len) @@ -110,9 +130,94 @@ def test_prefill_indexer(): weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5 ref_logits = ref_indexer(seq_len, q, kv, weights, block_size, topk) - deepgemm_logits = deepgemm_mqa_indexer(seq_len, q, kv, weights, block_size, topk) + deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, block_size, topk) torch.testing.assert_close(ref_logits, deepgemm_logits) +def test_decode_paged_indexer(): + num_blocks, blocksize = 111 * 3000, 64 + B = 3 + S = 128 + SKV = S + H = 64 + HKV = 1 + D = 128 + block_size = 128 + topk = 64 + device = "cuda" + seq_len = torch.randint(low=64, high=S, size=(B,), device="cuda") + + query_seq_len = torch.ones(B, device="cuda") + + q = torch.randn((B, 1, H, D), + device='cuda', + dtype=torch.bfloat16) + kv_cache = torch.randn((num_blocks, blocksize, 1, D), + device='cuda', + dtype=torch.bfloat16) + weights = torch.randn((B * 1, H), + device='cuda', + dtype=torch.float32) * H**-0.5 + max_block_len = (seq_len.max().item() + blocksize - + 1) // blocksize * blocksize + + block_tables = torch.zeros((B, max_block_len), + device='cuda', + dtype=torch.int32) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(B): + ctx_len = seq_len[i].item() + for j in range(cdiv(ctx_len, blocksize)): + block_tables[i][j] = block_idx_pool[counter] + counter += 1 + + flatten_kv = torch.empty( + [seq_len.sum(), D], device="cuda", dtype=torch.bfloat16 + ) + cu_seq_lens = torch.cat([ + torch.zeros(1, dtype=torch.int32, device=device), + seq_len.cumsum(dim=0) + ]).to(torch.int32).cuda() + + ops.cp_gather_cache( + kv_cache, + flatten_kv, + block_tables, + cu_seq_lens, + B, + ) + + ref_logits = deepgemm_mqa_indexer(seq_len, query_seq_len, q, flatten_kv, weights, block_size, topk, is_kv_batched=False) + + q_fp8, q_scale = act_quant(q, block_size, "ue8m0") + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + + schedule_metadata = get_paged_mqa_logits_metadata( + seq_len.int(), blocksize, 132) + + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale.squeeze(1) + 
weights = weights.squeeze(-1) + + logits = fp8_paged_mqa_logits( + q_fp8, kv_cache_fp8, weights, seq_len.int(), block_tables, + schedule_metadata, 4096) + + concat_logit = [] + context = 0 + for i in range(B): + per_seq_logits = torch.zeros(4096, device="cuda") + S = seq_len[i] + per_seq_logits[:S] = ref_logits[i][context: context + S] + concat_logit.append(per_seq_logits) + context += S + ref_logits = torch.stack(concat_logit, dim=0) + logits[logits == float("-inf")] = 0 + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + if __name__ == "__main__": - test_prefill_indexer() \ No newline at end of file + test_prefill_indexer() + test_decode_paged_indexer() \ No newline at end of file From b7e4b60247ed5a35aab7d9761e0cc5c6db076b6b Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Sat, 27 Sep 2025 06:14:20 +0000 Subject: [PATCH 31/82] remove unnecessary bias in wq_b and wk layer, accuracy is greatly improved --- vllm/model_executor/models/deepseek_v2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 27403d5d3813..e0eb0966333e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -587,10 +587,12 @@ def __init__(self, # no tensor parallel, just replicated self.wq_b = ReplicatedLinear(self.q_lora_rank, self.head_dim * self.n_head, + bias=False, quant_config=quant_config, prefix=f"{prefix}.wq_b") self.wk = ReplicatedLinear(hidden_size, self.head_dim, + bias=False, quant_config=quant_config, prefix=f"{prefix}.wk") self.k_norm = LayerNorm(self.head_dim, eps=1e-6) From 9bb302e3a4e1d1e8bbdeb8748511f97417219611 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 27 Sep 2025 00:15:22 -0700 Subject: [PATCH 32/82] Support piecewise cuda graph (#42) * replace hadmard * revert hardmard transform for k * non-eager mode * add wrappers for rotate_activation and act quant * fix calculation * refactor indexer. use custom op * revert changes in mask * add topk_indices guard back * minor fix * fix bug Signed-off-by: Chen Zhang * fix piecewise cuda graph Signed-off-by: Chen Zhang --------- Signed-off-by: Chen Zhang Co-authored-by: Siyuan Fu --- examples/offline_inference/basic/basic.py | 2 +- vllm/config/compilation.py | 1 + vllm/model_executor/layers/mla.py | 7 +- vllm/model_executor/models/deepseek_v2.py | 403 ++++++++++++------ .../attention/backends/mla/flashmla_sparse.py | 42 +- 5 files changed, 308 insertions(+), 147 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 63e6045cd3c7..cd73022ba6ee 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -41,7 +41,7 @@ """ def main(): # Create an LLM. - llm = LLM(model="/home/vllm-dsv32/DeepSeek-V3.2-Preview-Fix", tensor_parallel_size=8, enforce_eager=True) + llm = LLM(model="/home/vllm-dsv32/DeepSeek-V3.2-Preview-Fix", tensor_parallel_size=8) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
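Patch 32 makes the indexer compatible with piecewise CUDA graphs by wrapping its data-dependent steps in vLLM custom ops and listing "vllm.sparse_attn_indexer" among the compilation splitting ops in the next hunk, so torch.compile captures graphs around the op rather than tracing through it. A minimal sketch of that registration pattern, with hypothetical op names (only direct_register_custom_op, current_platform.dispatch_key, and the torch.ops.vllm namespace are taken from the patch):

    import torch
    from vllm.platforms import current_platform
    from vllm.utils import direct_register_custom_op

    def toy_sparse_select(x: torch.Tensor, out: torch.Tensor) -> None:
        # Real implementation: data-dependent work that should stay outside
        # the captured graph, e.g. writing top-k indices into a buffer.
        out.copy_(x.topk(out.shape[-1], dim=-1).indices.to(out.dtype))

    def toy_sparse_select_fake(x: torch.Tensor, out: torch.Tensor) -> None:
        # Shape-only fake impl so torch.compile can trace the call.
        return None

    direct_register_custom_op(
        op_name="toy_sparse_select",
        op_func=toy_sparse_select,
        mutates_args=["out"],
        fake_impl=toy_sparse_select_fake,
        dispatch_key=current_platform.dispatch_key,
    )
    # Call sites then use torch.ops.vllm.toy_sparse_select(x, out), mirroring
    # the torch.ops.vllm.sparse_attn_indexer call added later in this patch.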
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 22b38daf46c3..4dad3b668285 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -362,6 +362,7 @@ class CompilationConfig: "vllm.linear_attention", "vllm.plamo2_mamba_mixer", "vllm.gdn_attention", + "vllm.sparse_attn_indexer", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index af5766bc814d..892bf5e09c8f 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -27,6 +27,7 @@ class MLAModules: q_proj: Optional[torch.nn.Module] indexer: Optional[torch.nn.Module] is_sparse: bool + topk_indices_buffer: Optional[torch.Tensor] @CustomOp.register("multi_head_latent_attention") @@ -84,6 +85,7 @@ def __init__( if self.indexer else None self.use_sparse = mla_modules.is_sparse and os.getenv( "VLLM_MLA_SPARSE_DISABLED") != "1" + self.topk_indices_buffer = mla_modules.topk_indices_buffer # In the MLA backend, kv_cache includes both k_c and # pe (i.e. decoupled position embeddings). In particular, @@ -157,11 +159,8 @@ def forward_native( positions, q[..., self.qk_nope_head_dim:], k_pe) if self.indexer and self.use_sparse: - topk_indices = self.indexer(hidden_states, q_c, positions, + _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb) - # NOTE(Chen): a bit hacky, but need to modify Attention.forward - # otherwise. Try to refactor this later. - self.mla_attn.impl.set_topk_indices(topk_indices) attn_out = self.mla_attn( q, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index e0eb0966333e..bb7c6ec96696 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -34,7 +34,6 @@ import torch.distributed as dist from vllm.attention.backends.abstract import AttentionBackend -from vllm.distributed.parallel_state import get_tp_group from vllm.logger import init_logger from vllm.config.compilation import CompilationConfig import vllm.envs as envs @@ -532,7 +531,6 @@ def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend -# ignore or replace with pytorch def rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 from fast_hadamard_transform import hadamard_transform @@ -540,6 +538,81 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor: # make sure the hidden_size is expontial of 2 return hadamard_transform(x, scale=hidden_size**-0.5) + +def hadacore_transform(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + return ops.hadacore_transform(x, inplace=inplace) + + +def rotate_activation_fake(x: torch.Tensor, ) -> torch.Tensor: + return torch.empty_like(x) + + +direct_register_custom_op( + op_name="rotate_activation", + op_func=rotate_activation, + mutates_args=["x"], + fake_impl=rotate_activation_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def tilelang_act_quant( + x: torch.Tensor, + block_size: int, + scale_fmt: Optional[str], +) -> tuple[torch.Tensor, torch.Tensor]: + from vllm.utils.tile_lang_kernels import act_quant + return act_quant(x, block_size, scale_fmt) + + +def tilelang_act_quant_fake( + x: torch.Tensor, + block_size: int, + scale_fmt: Optional[str], +) -> tuple[torch.Tensor, torch.Tensor]: + return per_token_group_quant_fp8(x, + block_size, + column_major_scales=False, + use_ue8m0=scale_fmt is not None) + + +direct_register_custom_op( + op_name="tilelang_act_quant", + 
op_func=tilelang_act_quant, + mutates_args=[], + fake_impl=tilelang_act_quant_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def ref_fp8_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + # print(f"q_shape: {q.shape}, v_shape: {kv.shape}, weights.shape: {weights.shape}") + k = kv + q = q.float() + k = k.float() + + seq_len_kv = kv.shape[0] + mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] + >= cu_seqlen_ks[:, None]) + mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] + < cu_seqlen_ke[:, None]) + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + cost = mask.sum() + return logits, cost + + @torch.inference_mode() def indexer_k_quant_and_cache( k, @@ -548,9 +621,9 @@ def indexer_k_quant_and_cache( quant_block_size, scale_fmt, ): - from vllm.utils.tile_lang_kernels import act_quant _, block_size, head_dim = kv_cache.shape - k_fp8, k_scale = act_quant(k, quant_block_size, scale_fmt) + k_fp8, k_scale = torch.ops.vllm.tilelang_act_quant(k, quant_block_size, + scale_fmt) k_bytes = k_fp8.view(torch.uint8) s_bytes = k_scale.view(torch.uint8) @@ -561,8 +634,169 @@ def indexer_k_quant_and_cache( linear = block_idx * block_size + inblock kv_cache_flat = kv_cache.view(-1, head_dim) + # kv_cache_flat.shape: torch.Size([22326528, 132]), packed.shape: torch.Size([96, 132]), kv_cache.shape: torch.Size([348852, 64, 132]), linear.shape: torch.Size([91]) + + kv_cache_flat.index_copy_(0, linear, packed[:len(linear)]) + + +def sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: Optional[torch.Tensor], +) -> torch.Tensor: + + # careful! 
this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[:hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + num_prefills = attn_metadata.num_prefills + flattened_kv = torch.empty( + [prefill_metadata.total_seq_lens, head_dim + 4], + device=k.device, + dtype=torch.uint8) + ops.cp_gather_cache( + kv_cache, + flattened_kv, + prefill_metadata.block_table, + prefill_metadata.cu_seq_lens, + num_prefills, + ) + # TODO: the memory footprint here can be optimized + k_fp8 = flattened_kv[..., :head_dim].view( + torch.float8_e4m3fn).contiguous() + k_scale = flattened_kv[..., head_dim:].view(torch.float32).contiguous() + cu_seqlen_ks = prefill_metadata.cu_seqlen_ks + cu_seqlen_ke = prefill_metadata.cu_seqlen_ke + num_tokens = attn_metadata.num_actual_tokens + logits = fp8_mqa_logits( + q_fp8[num_decode_tokens:num_tokens], + (k_fp8, k_scale), + weights[num_decode_tokens:num_tokens], + cu_seqlen_ks, + cu_seqlen_ke, + ) + topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), + dim=-1)[1] + topk_indices -= cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0 + mask = torch.full_like(topk_indices, + False, + dtype=torch.bool, + device=topk_indices.device) + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices. + shape[-1]] = topk_indices.to(dtype=torch.int32) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + logits = fp8_paged_mqa_logits( + q_fp8[:num_decode_tokens].unsqueeze(1), + kv_cache, + weights[:num_decode_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + positions = torch.arange(max_model_len, device="cuda").unsqueeze( + 0) # [1, max_model_len] + next_n_offset = torch.arange(num_decode_tokens, device="cuda") + # NOTE(Chen): not true for spec decode + # [1, max_model_len] < [num_decode_tokens, 1] -> [num_decode_tokens, max_model_len] + mask = positions <= (decode_metadata.seq_lens - 1 + + next_n_offset).unsqueeze(1) + logits = logits.masked_fill(~mask, float("-inf")) + topk_indices = logits.topk( + min(topk_tokens, logits.shape[-1]), + dim=-1)[1].to(torch.int32) # [num_decode_tokens, topk_tokens] + topk_indices[topk_indices >= decode_metadata.seq_lens[:, None]] = -1 + topk_indices_buffer[:num_decode_tokens, :topk_indices. 
+ shape[-1]] = topk_indices.to(dtype=torch.int32) + return topk_indices_buffer + + +def sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: Optional[torch.Tensor], +) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. + _flattened_kv = torch.empty([total_seq_lens, head_dim + 4], + device=k.device, + dtype=torch.uint8) + _k_fp8 = _flattened_kv[..., :head_dim].view( + torch.float8_e4m3fn).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + return topk_indices_buffer - kv_cache_flat.index_copy_(0, linear, packed) + +direct_register_custom_op( + op_name="sparse_attn_indexer", + op_func=sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, +) class Indexer(nn.Module): @@ -574,6 +808,7 @@ def __init__(self, q_lora_rank: int, quant_config: Optional[QuantizationConfig], cache_config: Optional[CacheConfig], + topk_indices_buffer: Optional[torch.Tensor], prefix: str = ""): super().__init__() self.vllm_config = vllm_config @@ -604,6 +839,7 @@ def __init__(self, self.scale_fmt = "ue8m0" self.quant_block_size = 128 # TODO: get from config + self.topk_indices_buffer = topk_indices_buffer #TODO (zyongye) change dim to fp8 later to (self.head_dim + 4) self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim + 4, @@ -612,9 +848,13 @@ def __init__(self, cache_config=cache_config) self.max_model_len = vllm_config.model_config.max_model_len self.prefix = prefix + from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size + self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb) -> torch.Tensor: + # hidden_states.shape: torch.Size([16, 7168]), qr.shape: torch.Size([16, 1536]), positions.shape: torch.Size([16]) + # print(f"hidden_states: {torch.isinf(hidden_states).any()}, qr: {torch.isinf(qr).any()}") # print(f"weight_proj: {torch.isneginf(self.weights_proj.weight.to(torch.float32)).any()}") q, _ = self.wq_b(qr) @@ -631,133 +871,26 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) q = torch.cat([q_pe, q_nope], dim=-1) k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) - q = rotate_activation(q) - k = rotate_activation(k) + # logger.info_once(f'q.shape: {q.shape}, k.shape: {k.shape}') + q = torch.ops.vllm.rotate_activation(q) + k = torch.ops.vllm.rotate_activation( + k + ) #FIXME (siyuanf) hadacore_transform causes illegal memory access when applying to k # we only quant q here since k quant is fused with cache insertion - from vllm.utils.tile_lang_kernels import act_quant - q_fp8, q_scale = act_quant(q, self.quant_block_size, self.scale_fmt) + q_fp8, q_scale = torch.ops.vllm.tilelang_act_quant( + q, self.quant_block_size, self.scale_fmt) weights, _ = self.weights_proj(hidden_states) weights = weights.unsqueeze( -1) * q_scale * self.softmax_scale * self.n_head**-0.5 weights = weights.squeeze(-1) - # careful! 
this will be None in dummy run - attn_metadata = get_forward_context().attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.k_cache.prefix] - assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) - slot_mapping = attn_metadata.slot_mapping - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens - - kv_cache = self.k_cache.kv_cache[0] - indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - self.quant_block_size, - self.scale_fmt, - ) - - topk_indices_buffer = torch.full( - (hidden_states.shape[0], self.topk_tokens), - -1, - dtype=torch.int32, - device=hidden_states.device) - if has_prefill: - prefill_metadata = attn_metadata.prefill - num_prefills = attn_metadata.num_prefills - flattened_kv = torch.empty( - [prefill_metadata.total_seq_lens, self.head_dim + 4], - device=k.device, - dtype=torch.uint8) - ops.cp_gather_cache( - kv_cache, - flattened_kv, - prefill_metadata.block_table, - prefill_metadata.cu_seq_lens, - num_prefills, - ) - # TODO: the memory footprint here can be optimized - k_fp8 = flattened_kv[..., :self.head_dim].view( - torch.float8_e4m3fn).contiguous() - k_scale = flattened_kv[..., self.head_dim:].view( - torch.float32).contiguous() - cu_seqlen_ks = prefill_metadata.cu_seqlen_ks - cu_seqlen_ke = prefill_metadata.cu_seqlen_ke - logits = fp8_mqa_logits( - q_fp8[num_decode_tokens:], - (k_fp8, k_scale), - weights[num_decode_tokens:], - cu_seqlen_ks, - cu_seqlen_ke, - ) - topk_indices = logits.topk(min(self.topk_tokens, - logits.shape[-1]), - dim=-1)[1] - topk_indices -= cu_seqlen_ks[:, None] - mask_lo = topk_indices >= 0 - mask_hi = topk_indices < cu_seqlen_ke[:, - None] - cu_seqlen_ks[:, - None] - mask = mask_lo & mask_hi - topk_indices = topk_indices.masked_fill(~mask, -1) - topk_indices_buffer[num_decode_tokens:, :topk_indices. - shape[-1]] = topk_indices.to( - dtype=torch.int32) - if has_decode: - decode_metadata = attn_metadata.decode - # kv_cache size requirement [num_block, block_size, n_head, head_dim], - # we only have [num_block, block_size, head_dim], - kv_cache = kv_cache.unsqueeze(-2) - logits = fp8_paged_mqa_logits( - q_fp8[:num_decode_tokens].unsqueeze(1), - kv_cache, - weights[:num_decode_tokens], - decode_metadata.seq_lens, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=self.max_model_len, - ) - positions = torch.arange(self.max_model_len, - device="cuda").unsqueeze( - 0) # [1, max_model_len] - next_n_offset = torch.arange(num_decode_tokens, device="cuda") - # NOTE(Chen): not true for spec decode - # [1, max_model_len] < [num_decode_tokens, 1] -> [num_decode_tokens, max_model_len] - mask = positions <= (decode_metadata.seq_lens - 1 + - next_n_offset).unsqueeze(1) - logits = logits.masked_fill(~mask, float("-inf")) - topk_indices = logits.topk( - min(self.topk_tokens, logits.shape[-1]), dim=-1)[1].to( - torch.int32) # [num_decode_tokens, topk_tokens] - topk_indices[topk_indices >= - decode_metadata.seq_lens[:, None]] = -1 - topk_indices_buffer[:num_decode_tokens, :topk_indices. - shape[-1]] = topk_indices.to( - dtype=torch.int32) - else: - # profile run - from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size - total_seq_lens = get_max_prefill_buffer_size(self.vllm_config) - # NOTE(Chen): create the max possible flattened_kv. So that - # profile_run can get correct memory usage. 
- _flattened_kv = torch.empty([total_seq_lens, self.head_dim + 4], - device=k.device, - dtype=torch.uint8) - _k_fp8 = _flattened_kv[..., :self.head_dim].view( - torch.float8_e4m3fn).contiguous() - _k_scale = _flattened_kv[..., self.head_dim:].view( - torch.float32).contiguous() - topk_indices_buffer = torch.full( - (hidden_states.shape[0], self.topk_tokens), - -1, - dtype=torch.int32, - device=hidden_states.device) - return topk_indices_buffer + return torch.ops.vllm.sparse_attn_indexer( + hidden_states, self.k_cache.prefix, self.k_cache.kv_cache[0], + q_fp8, k, weights, self.quant_block_size, self.scale_fmt, + self.topk_tokens, self.head_dim, self.max_model_len, + self.max_total_seq_len, self.topk_indices_buffer) class DeepseekV2MLAAttention(nn.Module): @@ -785,6 +918,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + topk_indices_buffer: Optional[torch.Tensor] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -872,7 +1006,7 @@ def __init__( if self.is_v32: self.indexer = Indexer(vllm_config, config, hidden_size, q_lora_rank, quant_config, cache_config, - f"{prefix}.indexer") + topk_indices_buffer, f"{prefix}.indexer") else: self.indexer = None @@ -891,6 +1025,7 @@ def __init__( q_proj=self.q_proj if self.q_lora_rank is None else None, indexer=self.indexer, is_sparse=self.is_v32, + topk_indices_buffer=topk_indices_buffer, ) self.mla_attn = MultiHeadLatentAttention( @@ -919,7 +1054,8 @@ def forward( class DeepseekV2DecoderLayer(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str, + topk_indices_buffer: Optional[torch.Tensor]) -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -958,6 +1094,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + topk_indices_buffer=topk_indices_buffer, ) if (config.n_routed_experts is not None @@ -1041,6 +1178,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.vocab_size = config.vocab_size + self.is_v32 = hasattr( + config, "attn_module_list_cfg" + ) and "attn_index" in config.attn_module_list_cfg[0] + if self.is_v32: + # TODO(Chen): remove this hardcode + topk_indices_buffer = torch.empty(1000, + 2048, + dtype=torch.int32, + device="cuda") + else: + topk_indices_buffer = None if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -1053,7 +1201,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix), + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer), prefix=f"{prefix}.layers") if get_pp_group().is_last_rank: diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 6495458c265e..e1da4264913a 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -22,6 +22,9 @@ from vllm.v1.kv_cache_interface import AttentionSpec import triton import triton.language as tl +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from vllm.model_executor.models.deepseek_v2 import Indexer logger = init_logger(__name__) @@ -262,7 +265,7 @@ def build(self, # mask = 
debug_topk_indices <= pos_gpu.unsqueeze(1) # debug_topk_indices = debug_topk_indices.masked_fill(~mask, -1) debug_topk_indices = None - return FlashMLASparseMetadata( + metadata = FlashMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, max_seq_len=common_attn_metadata.max_seq_len, @@ -279,6 +282,7 @@ def build(self, topk_tokens=self.topk_tokens, debug_topk_indices=debug_topk_indices, ) + return metadata @dataclass @@ -297,15 +301,22 @@ def __init__( attn_type: str, kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments + topk_indice_buffer: Optional[torch.Tensor] = None, + indexer: Optional["Indexer"] = None, **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - self.topk_indices = None - - def set_topk_indices(self, topk_indices: torch.Tensor): - self.topk_indices = topk_indices + super().__init__(num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + indexer=indexer, + **mla_args) + self.topk_indice_buffer = indexer.topk_indices_buffer def forward( self, @@ -368,15 +379,16 @@ def forward( attn_out = self._forward_bf16_kv(ql_nope, q_pe, kv_cache, attn_metadata, self.scale) - output[:num_actual_tokens] = self._v_up_proj(attn_out) + output[:num_actual_tokens] = self._v_up_proj( + attn_out[:num_actual_tokens]) return output_padded def _forward_bf16_kv(self, ql_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, k_scale: torch.Tensor) -> torch.Tensor: - topk_indices = self.topk_indices[:attn_metadata.num_actual_tokens] - num_tokens = ql_nope.shape[0] + topk_indices = self.topk_indice_buffer + num_tokens = attn_metadata.num_actual_tokens q = torch.cat([ql_nope, q_pe], dim=-1) kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( -1, 1, kv_c_and_k_pe_cache.shape[-1]) @@ -391,14 +403,14 @@ def _forward_bf16_kv(self, ql_nope: torch.Tensor, q_pe: torch.Tensor, q = q_padded # TODO: handle index / kv_cache correctly topk_indices_global = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token, + attn_metadata.req_id_per_token[:num_tokens], attn_metadata.block_table, - topk_indices, + topk_indices[:num_tokens], BLOCK_SIZE=attn_metadata.block_size, NUM_TOPK_TOKENS=attn_metadata.topk_tokens, ) topk_indices_global = topk_indices_global.view(num_tokens, 1, -1) - output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, + output = flash_mla_sparse_prefill(q[:num_tokens], kv_c_and_k_pe_cache, topk_indices_global, k_scale)[0] output = output[:, :self.num_heads, :] return output From 065e9c49f6f687298a25810fe43f4e941efe06ab Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Sat, 27 Sep 2025 11:54:39 -0400 Subject: [PATCH 33/82] set max buffer size (#45) --------- Signed-off-by: Yongye Zhu --- vllm/model_executor/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index bb7c6ec96696..45c619c6f7a7 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1183,7 +1183,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) and "attn_index" in config.attn_module_list_cfg[0] if self.is_v32: # TODO(Chen): remove this hardcode - 
topk_indices_buffer = torch.empty(1000, + topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, 2048, dtype=torch.int32, device="cuda") From 5fc357152af308140a1127cee2a4af511ef5107a Mon Sep 17 00:00:00 2001 From: Lucia Fang <116399278+luccafong@users.noreply.github.com> Date: Sat, 27 Sep 2025 13:48:39 -0700 Subject: [PATCH 34/82] fix indexer + mtp (#43) Co-authored-by: Lucia Fang --- .../attention/test_pack_unpack_triton.py | 396 ++++++++++++++++++ vllm/attention/ops/common.py | 176 ++++++++ vllm/model_executor/models/deepseek_mtp.py | 14 +- vllm/model_executor/models/deepseek_v2.py | 62 ++- .../attention/backends/mla/flashmla_sparse.py | 6 +- vllm/v1/attention/backends/mla/indexer.py | 9 +- 6 files changed, 639 insertions(+), 24 deletions(-) create mode 100644 tests/kernels/attention/test_pack_unpack_triton.py diff --git a/tests/kernels/attention/test_pack_unpack_triton.py b/tests/kernels/attention/test_pack_unpack_triton.py new file mode 100644 index 000000000000..a44c49829612 --- /dev/null +++ b/tests/kernels/attention/test_pack_unpack_triton.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +from torch.testing import assert_close + +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton + + +def test_pack_decode_query_basic_fp8(): + """Test basic functionality of pack_seq_triton with fp8 and 3D tensors.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test cases with 3D tensors (N, H, D) + test_cases = [ + (6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4) + (10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8) + (20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32) + ] + + for N, H, D, B, lengths_list in test_cases: + # Create input tensor with small values for fp8 + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor(lengths_list, device=device) + + # Pack the data + packed = pack_seq_triton(x, lengths) + + # Check output shape and properties + expected_shape = (B, max(lengths_list), H, D) + assert packed.shape == expected_shape + assert packed.dtype == dtype + assert packed.device == x.device + + # Check that valid data is preserved (within fp8 precision) + for b in range(B): + start_idx = sum(lengths_list[:b]) + seq_len = lengths_list[b] + + expected_data = x[start_idx:start_idx + seq_len].to(torch.float32) + actual_data = packed[b, :seq_len].to(torch.float32) + + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + +def test_pack_decode_query_custom_padding_fp8(): + """Test pack_seq_triton with custom padding values for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 20, 8, 16, 2 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + # Test with different padding values + for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]: + result = pack_seq_triton(x, lengths, pad_value=pad_value) + + # Check valid data + for b in range(B): + start_idx = b * 10 + expected_data = x[start_idx:start_idx + 10].to(torch.float32) + actual_data = result[b, :10].to(torch.float32) + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + # Check padding (fp8 has limited range, so check for large values) + padded_data = result[:, 10:].to(torch.float32) + if pad_value < 0: + assert torch.all(padded_data < -50) # Large negative 
values + elif pad_value > 0: + assert torch.all(padded_data > 50) # Large positive values + else: + assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2) + + +def test_pack_decode_query_default_negative_inf_padding_fp8(): + """Test that pack_seq_triton uses -inf padding by default for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 20, 8, 16, 2 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + result = pack_seq_triton(x, lengths) + + # Check that padding is large negative values (fp8 representation of -inf) + padded_data = result[:, 10:].to(torch.float32) + assert torch.all(padded_data < -100) # fp8 -inf is represented as large negative number + + +def test_pack_decode_query_edge_cases_fp8(): + """Test pack_seq_triton with edge cases for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test with single batch element + x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([10], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (1, 10, 8, 16) + + # Test with very short sequences + x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([1, 1, 1], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (3, 1, 4, 8) + + # Test with different sequence lengths + x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([5, 7, 3], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (3, 7, 8, 16) + + +def test_pack_decode_query_different_block_sizes_fp8(): + """Test pack_seq_triton with different block sizes for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 100, 16, 32, 4 + lengths = torch.tensor([25, 25, 25, 25], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + # Test different block sizes + for block_t, block_d in [(32, 32), (64, 64), (128, 128)]: + result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d) + + assert result.shape == (B, 25, H, D) + + # Check that valid data is preserved (within fp8 precision) + for b in range(B): + start_idx = b * 25 + expected_data = x[start_idx:start_idx + 25].to(torch.float32) + actual_data = result[b, :25].to(torch.float32) + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + +def test_pack_decode_query_shape_consistency(): + """Test that pack_seq_triton maintains shape consistency.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 20, 8, 16, 2 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + result = pack_seq_triton(x, lengths) + + # Check shape consistency + assert result.shape[0] == B # Batch dimension + assert result.shape[1] == lengths.max().item() # Max sequence length + assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved + + +def test_pack_unpack_roundtrip_fp8(): + """Test that pack -> unpack gives us back the original data for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test cases with 3D tensors + test_cases = [ + (6, 8, 4, 2, [3, 3]), + (10, 4, 8, 3, [2, 4, 4]), + (20, 16, 32, 4, [5, 5, 5, 5]), + (15, 8, 16, 3, [7, 5, 3]), + ] + + for N, H, D, B, lengths_list in 
test_cases: + # Create input tensor with small values for fp8 + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor(lengths_list, device=device) + + # Pack the data + packed = pack_seq_triton(x, lengths) + + # Unpack the data + unpacked = unpack_seq_triton(packed, lengths) + + # Check that we get back the original data (within fp8 precision) + assert unpacked.shape == x.shape + x_f32 = x.to(torch.float32) + unpacked_f32 = unpacked.to(torch.float32) + assert_close(x_f32, unpacked_f32, rtol=1e-1, atol=1e-2) + + # Test with query_start_loc + query_start_loc = torch.cat([torch.zeros(1, device=device, dtype=lengths.dtype), + lengths.cumsum(0)[:-1]]) + unpacked_with_loc = unpack_seq_triton(packed, lengths, query_start_loc) + assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-1, atol=1e-2) + + +def test_unpack_seq_triton_edge_cases_fp8(): + """Test unpack function with edge cases for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test with single batch element + x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([10], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths) + assert unpacked.shape == x.shape + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) + + # Test with very short sequences + x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([1, 1, 1], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths) + # Only compare the first 3 elements that were actually packed + assert_close(x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) + + # Test with query_start_loc + x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([5, 7, 3], device=device) + query_start_loc = torch.tensor([0, 5, 12], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths, query_start_loc) + assert unpacked.shape == x.shape + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) + + +def test_masked_topk_basic(): + """Test basic functionality of masked_topk function.""" + device = "cuda" + + # Test case 1: Simple example + seq_lens = torch.tensor([2, 1], device=device) # 2 batches: lengths 2,1 + starting_pos = torch.tensor([3, 7], device=device) # starting positions + N = seq_lens.sum().item() # 3 total positions + vocab_size, k = 20, 2 + + scores = torch.randn(N, vocab_size, device=device) + + indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k) + + # Check output shapes + assert indices.shape == (N, k) + assert top_scores.shape == (N, k) + + # Verify masking constraints + # Positions 0,1 (batch 0): should only use indices < 3 + assert torch.all(indices[0] < 3) + assert torch.all(indices[1] < 3) + # Position 2 (batch 1): should only use indices < 7 + assert torch.all(indices[2] < 7) + + +def test_masked_topk_complex(): + """Test masked_topk with more complex sequences.""" + device = "cuda" + + # Test case: 4 batches with different lengths + seq_lens = torch.tensor([3, 1, 1, 1], device=device) # lengths: 3,1,1,1 + starting_pos = torch.tensor([4, 12, 33, 50], device=device) # starting positions + N = seq_lens.sum().item() # 6 total positions + vocab_size, k = 100, 3 + + scores = torch.randn(N, 
vocab_size, device=device) + + indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k) + + # Check output shapes + assert indices.shape == (N, k) + assert top_scores.shape == (N, k) + + # Verify masking constraints for each batch + pos_idx = 0 + for b in range(len(seq_lens)): + seq_len = seq_lens[b].item() + start_pos = starting_pos[b].item() + + # Check all positions in this batch + for i in range(seq_len): + assert torch.all(indices[pos_idx] < start_pos), f"Position {pos_idx} should only use indices < {start_pos}" + pos_idx += 1 + + +def test_masked_topk_edge_cases(): + """Test masked_topk with edge cases.""" + device = "cuda" + + # Test case 1: Single batch + seq_lens = torch.tensor([5], device=device) + starting_pos = torch.tensor([10], device=device) + scores = torch.randn(5, 50, device=device) + + indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=3) + assert indices.shape == (5, 3) + assert torch.all(indices < 10) # All positions should use indices < 10 + + # Test case 2: Very small starting positions + seq_lens = torch.tensor([2, 1], device=device) + starting_pos = torch.tensor([1, 2], device=device) + scores = torch.randn(3, 20, device=device) + + indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=1) + assert indices.shape == (3, 1) + assert torch.all(indices[0] < 1) # First position can only use index 0 + assert torch.all(indices[1] < 1) # Second position can only use index 0 + assert torch.all(indices[2] < 2) # Third position can use indices 0,1 + + # Test case 3: Large starting positions + seq_lens = torch.tensor([2], device=device) + starting_pos = torch.tensor([95], device=device) + scores = torch.randn(2, 100, device=device) + + indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=5) + assert indices.shape == (2, 5) + assert torch.all(indices < 95) + + +def test_masked_topk_different_k_values(): + """Test masked_topk with different k values.""" + device = "cuda" + + seq_lens = torch.tensor([2, 1], device=device) + starting_pos = torch.tensor([5, 10], device=device) + scores = torch.randn(3, 20, device=device) + + # Test different k values + for k in [1, 3, 5, 10]: + indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k) + + assert indices.shape == (3, k) + assert top_scores.shape == (3, k) + + # Verify masking constraints + assert torch.all(indices[0] < 5) + assert torch.all(indices[1] < 5) + assert torch.all(indices[2] < 10) + + +def test_masked_topk_fp8(): + """Test masked_topk with fp8 dtype.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + seq_lens = torch.tensor([2, 1], device=device) + starting_pos = torch.tensor([5, 10], device=device) + + # Create fp8 scores + scores_f32 = torch.randn(3, 20, device=device) * 0.1 + scores = scores_f32.to(dtype) + + indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=3) + + # Check output shapes + assert indices.shape == (3, 3) + assert top_scores.shape == (3, 3) + assert top_scores.dtype == dtype + + # Verify masking constraints + assert torch.all(indices[0] < 5) + assert torch.all(indices[1] < 5) + assert torch.all(indices[2] < 10) + + # Check that top scores are reasonable (not all -inf) + assert not torch.all(torch.isinf(top_scores.to(torch.float32))) + + +def test_masked_topk_consistency(): + """Test that masked_topk produces consistent results.""" + device = "cuda" + + seq_lens = torch.tensor([2, 1], device=device) + starting_pos = torch.tensor([5, 10], device=device) + + # Use deterministic scores for consistency testing + 
torch.manual_seed(42) + scores = torch.randn(3, 20, device=device) + + # Run multiple times + results = [] + for _ in range(3): + indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=3) + results.append((indices.clone(), top_scores.clone())) + + # Check that all runs produce identical results + for i in range(1, len(results)): + assert torch.equal(results[0][0], results[i][0]), "Indices should be consistent" + assert_close(results[0][1], results[i][1], rtol=1e-5, atol=1e-5), "Scores should be consistent" diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 189b57e8e8b8..cfd5c54c711e 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -137,3 +137,179 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, assert out.is_contiguous() out = cp_group.reduce_scatter(out, dim=1) return out + +@triton.jit +def _pack_seq_kernel( + x_ptr, # *fp8, [N, D] + out_ptr, # *fp8, [B, Lmax, D] + starts_ptr, # *i32, [B] + N: tl.constexpr, D: tl.constexpr, Lmax: tl.constexpr, + PAD_VALUE: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr # features per program +): + pid_b = tl.program_id(0) # batch id + pid_t = tl.program_id(1) # block over time dimension + pid_d = tl.program_id(2) # block over feature dimension + off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] + off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] + + # bounds + in_start = tl.load(starts_ptr + pid_b) + + # Calculate sequence length from starts + if pid_b < tl.num_programs(0) - 1: + next_start = tl.load(starts_ptr + pid_b + 1) + seq_len = next_start - in_start + else: + seq_len = N - in_start + + # valid time positions for this block + t_mask = off_t < Lmax + + # compute input row indices for valid (b, t) + in_row = in_start + off_t + valid_row = (off_t < seq_len) & t_mask + + # Pointers + # x_ptr: row-major [N, D] + x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :] + + # out_ptr: row-major [B, Lmax, D] + out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] + + # Initialize with PAD + # (write pad for all t in this block) + d_mask = off_d[None, :] < D + tl.store(out_row_ptr, tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32), mask=t_mask[:, None] & d_mask) + + # Load & write only where within seq_len + x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask) + tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask) + +def pack_seq_triton(x, starts, pad_value=-float('inf'), block_t=64, block_d=64): + + # Handle multi-dimensional input by reshaping to (N, -1) + original_shape = x.shape + if len(original_shape) > 2: + N = original_shape[0] + x_reshaped = x.reshape(N, -1) + D = x_reshaped.shape[1] # Get the actual feature dimension + else: + N, D = x.shape + x_reshaped = x + + B = starts.numel() + # Calculate Lmax from starts without creating lengths tensor + if B == 1: + Lmax = N - starts[0].item() + else: + # Calculate max length from consecutive starts + lengths = starts[1:] - starts[:-1] + last_length = N - starts[-1].item() + Lmax = max(int(lengths.max().item()), int(last_length)) + + out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype) + + grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) + _pack_seq_kernel[grid]( + x_reshaped, out, starts.int(), + N, D, Lmax, + PAD_VALUE=float(pad_value), + BLOCK_T=block_t, BLOCK_D=block_d, + num_warps=4, num_stages=2 + ) + + # Reshape output back to original dimensions (except first dimension) + if len(original_shape) > 
2: + output_shape = (B, Lmax) + original_shape[1:] + out = out.reshape(output_shape) + + return out + + +@triton.jit +def _unpack_seq_triton_kernel( + packed_ptr, # *fp8, [B, Lmax, D] + out_ptr, # *fp8, [N, D] + starts_ptr, # *i32, [B] + lengths_ptr, # *i32, [B] + B: tl.constexpr, Lmax: tl.constexpr, D: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr # features per program +): + pid_b = tl.program_id(0) # batch id + pid_t = tl.program_id(1) # block over time dimension + pid_d = tl.program_id(2) # block over feature dimension + off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] + off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] + + # bounds + in_start = tl.load(starts_ptr + pid_b) + seq_len = tl.load(lengths_ptr + pid_b) + + # valid time positions for this block + t_mask = off_t < Lmax + valid_row = (off_t < seq_len) & t_mask + + # compute output row indices for valid (b, t) + out_row = in_start + off_t + + # Pointers + # packed_ptr: row-major [B, Lmax, D] + packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] + + # out_ptr: row-major [N, D] + out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :] + + # Load from packed tensor and store to output + d_mask = off_d[None, :] < D + packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask) + tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask) + + +def unpack_seq_triton(packed_tensor, starts, lengths, block_t=64, block_d=64): + """ + Unpack a packed decode query tensor back to the original format. + Efficient Triton implementation. + + Args: + packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton + starts: [B] - start locations for each batch + lengths: [B] - sequence lengths for each batch (needed to calculate total N) + block_t: block size for time dimension + block_d: block size for feature dimension + + Returns: + unpacked_tensor: [N, ...] 
where N = sum(lengths) + """ + + # Handle multi-dimensional input by reshaping to (B, Lmax, -1) + original_shape = packed_tensor.shape + if len(original_shape) > 3: + B, Lmax = original_shape[:2] + packed_reshaped = packed_tensor.reshape(B, Lmax, -1) + D = packed_reshaped.shape[2] # Get the actual feature dimension + else: + B, Lmax, D = packed_tensor.shape + packed_reshaped = packed_tensor + + # Calculate total number of elements + N = int(lengths.sum().item()) + + out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype) + + grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) + _unpack_seq_triton_kernel[grid]( + packed_reshaped, out, starts.int(), lengths.int(), + B, Lmax, D, + BLOCK_T=block_t, BLOCK_D=block_d, + num_warps=4, num_stages=2 + ) + + # Reshape output back to original dimensions (except first dimension) + if len(original_shape) > 3: + output_shape = (N,) + original_shape[2:] + out = out.reshape(output_shape) + + return out \ No newline at end of file diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 8fbf16d206a8..42e7fe5b6f59 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -54,8 +54,20 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + + self.is_v32 = hasattr( + config, "attn_module_list_cfg" + ) and "attn_index" in config.attn_module_list_cfg[0] + if self.is_v32: + topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] + topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda") + else: + topk_indices_buffer = None self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, topk_indices_buffer) def forward( self, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 45c619c6f7a7..4a9d11b7dbc2 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -67,6 +67,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import cdiv, direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, extract_layer_index, @@ -736,9 +737,22 @@ def sparse_attn_indexer( decode_metadata = attn_metadata.decode # kv_cache size requirement [num_block, block_size, n_head, head_dim], # we only have [num_block, block_size, head_dim], + query_start_loc = attn_metadata.query_start_loc + decode_lens = query_start_loc[1:attn_metadata.num_decodes+1] - query_start_loc[:attn_metadata.num_decodes] kv_cache = kv_cache.unsqueeze(-2) + require_padding = (decode_lens.max() > decode_lens.min()).item() + if require_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton(q_fp8[:num_decode_tokens], decode_lens) + else: + padded_q_fp8_decode_tokens = 
q_fp8[:num_decode_tokens].reshape(decode_lens.shape[0], -1, *q_fp8.shape[1:]) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + assert batch_size == decode_metadata.seq_lens.shape[0] logits = fp8_paged_mqa_logits( - q_fp8[:num_decode_tokens].unsqueeze(1), + padded_q_fp8_decode_tokens, kv_cache, weights[:num_decode_tokens], decode_metadata.seq_lens, @@ -746,20 +760,32 @@ def sparse_attn_indexer( decode_metadata.schedule_metadata, max_model_len=max_model_len, ) - positions = torch.arange(max_model_len, device="cuda").unsqueeze( - 0) # [1, max_model_len] - next_n_offset = torch.arange(num_decode_tokens, device="cuda") - # NOTE(Chen): not true for spec decode - # [1, max_model_len] < [num_decode_tokens, 1] -> [num_decode_tokens, max_model_len] - mask = positions <= (decode_metadata.seq_lens - 1 + - next_n_offset).unsqueeze(1) - logits = logits.masked_fill(~mask, float("-inf")) - topk_indices = logits.topk( - min(topk_tokens, logits.shape[-1]), - dim=-1)[1].to(torch.int32) # [num_decode_tokens, topk_tokens] - topk_indices[topk_indices >= decode_metadata.seq_lens[:, None]] = -1 + # [B, N, L] + next_n = padded_q_fp8_decode_tokens.shape[1] + # padded query len + current_device = padded_q_fp8_decode_tokens.device + padded_num_tokens = batch_size * next_n + positions = torch.arange(max_model_len, device=current_device).unsqueeze(0).expand( + batch_size * next_n, -1) + row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n + next_n_offset = torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device) % next_n + index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1) + # index_end_pos: [B * N, 1] + mask = positions <= index_end_pos + # mask: [B * N, L] + logits = logits.masked_fill(~mask, float('-inf')) + topk_indices = logits.topk(topk_tokens, dim=-1)[1].to( + torch.int32) # [B * N, K] + # ensure we don't set indices for the top k that out of range(masked already) + # this will happen if context length is shorter than K + topk_indices[topk_indices > index_end_pos] = -1 + if require_padding: + # if padded, we need to unpack the topk indices removing padded tokens + topk_indices = unpack_seq_triton(topk_indices.reshape(batch_size, -1, logits.shape[-1]), decode_lens) topk_indices_buffer[:num_decode_tokens, :topk_indices. 
- shape[-1]] = topk_indices.to(dtype=torch.int32) + shape[-1]] = topk_indices.to( + dtype=torch.int32) + return topk_indices_buffer @@ -853,10 +879,6 @@ def __init__(self, def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb) -> torch.Tensor: - # hidden_states.shape: torch.Size([16, 7168]), qr.shape: torch.Size([16, 1536]), positions.shape: torch.Size([16]) - - # print(f"hidden_states: {torch.isinf(hidden_states).any()}, qr: {torch.isinf(qr).any()}") - # print(f"weight_proj: {torch.isneginf(self.weights_proj.weight.to(torch.float32)).any()}") q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) q_pe, q_nope = torch.split( @@ -1182,9 +1204,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config, "attn_module_list_cfg" ) and "attn_index" in config.attn_module_list_cfg[0] if self.is_v32: - # TODO(Chen): remove this hardcode + topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, - 2048, + topk_tokens, dtype=torch.int32, device="cuda") else: diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index e1da4264913a..0bf03c10c984 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, ClassVar import numpy as np import torch @@ -222,12 +222,16 @@ def triton_convert_req_index_to_global_index( class FlashMLASparseMetadataBuilder( MLACommonMetadataBuilder[FlashMLASparseMetadata]): + reorder_batch_threshold: ClassVar[int] = 1 + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashMLASparseMetadata) self.topk_tokens = vllm_config.model_config.hf_config\ .attn_module_list_cfg[0]["topk_tokens"] + self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + self.reorder_batch_threshold += self.num_speculative_tokens def _build_prefill( self, common_attn_metadata: CommonAttentionMetadata diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index daf9b507f215..13432a18e92e 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.config import VllmConfig @@ -154,6 +154,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): + reorder_batch_threshold: ClassVar[int] = 1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) max_model_len = self.vllm_config.model_config.max_model_len @@ -161,6 +162,8 @@ def __init__(self, *args, **kwargs): # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. 
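# --- illustrative sketch, not part of the patch ---------------------------
# The lines added just below raise the decode/prefill split threshold by the
# number of speculative tokens: with MTP, one "decode" step can carry
# 1 + num_speculative_tokens query tokens per request, so a request whose
# query length is at most that threshold should still take the decode path.
# A toy, framework-free version of the classification that
# split_decodes_and_prefills(..., decode_threshold=...) performs (the helper
# name below and the assumption that decode requests are already reordered
# to the front are illustrative, not the real vLLM implementation):
def toy_split_decodes_and_prefills(query_lens: list[int], decode_threshold: int):
    num_decodes = 0
    for qlen in query_lens:
        if qlen <= decode_threshold:
            num_decodes += 1
        else:
            break  # decodes are assumed to be grouped at the front
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return (num_decodes, len(query_lens) - num_decodes,
            num_decode_tokens, num_prefill_tokens)

# With 2 speculative tokens the threshold becomes 3, so a short chunked
# prefill of 2 tokens also lands on the decode path; that is why
# sparse_attn_indexer pads ragged decode batches (pack_seq_triton) before
# calling fp8_paged_mqa_logits.
# ---------------------------------------------------------------------------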
self.max_prefill_buffer_size = get_max_prefill_buffer_size( self.vllm_config) + self.num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens + self.reorder_batch_threshold += self.num_speculative_tokens def build(self, common_prefix_len: int, @@ -176,7 +179,9 @@ def build(self, query_start_loc = common_attn_metadata.query_start_loc num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens From a95229bbf9664f45a552cb07075bd451004c246f Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Sun, 28 Sep 2025 00:34:42 -0400 Subject: [PATCH 35/82] indexer ref code cleanup (#47) --- vllm/model_executor/models/deepseek_v2.py | 28 ----------------------- 1 file changed, 28 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4a9d11b7dbc2..25d6340f198b 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -586,34 +586,6 @@ def tilelang_act_quant_fake( dispatch_key=current_platform.dispatch_key, ) - -def ref_fp8_mqa_logits( - q: torch.Tensor, - kv: torch.Tensor, - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, -): - # print(f"q_shape: {q.shape}, v_shape: {kv.shape}, weights.shape: {weights.shape}") - k = kv - q = q.float() - k = k.float() - - seq_len_kv = kv.shape[0] - mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - >= cu_seqlen_ks[:, None]) - mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - < cu_seqlen_ke[:, None]) - mask = mask_lo & mask_hi - - score = torch.einsum("mhd,nd->hmn", q, k) - logits = (score.relu() * weights.transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) - - cost = mask.sum() - return logits, cost - - @torch.inference_mode() def indexer_k_quant_and_cache( k, From ff5eb403d1e27351e2fe37aa8895f25ad7a7e8cd Mon Sep 17 00:00:00 2001 From: Lucia Fang <116399278+luccafong@users.noreply.github.com> Date: Sun, 28 Sep 2025 09:21:57 -0700 Subject: [PATCH 36/82] fix non spec decode error (#48) Co-authored-by: Lucia Fang --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 5 ++++- vllm/v1/attention/backends/mla/indexer.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 0bf03c10c984..1a0b91ebc544 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -230,7 +230,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], FlashMLASparseMetadata) self.topk_tokens = vllm_config.model_config.hf_config\ .attn_module_list_cfg[0]["topk_tokens"] - self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0 + ) self.reorder_batch_threshold += self.num_speculative_tokens def _build_prefill( diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 13432a18e92e..83b363c85d9c 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -162,7 +162,10 
@@ def __init__(self, *args, **kwargs): # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. self.max_prefill_buffer_size = get_max_prefill_buffer_size( self.vllm_config) - self.num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens + self.num_speculative_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0 + ) self.reorder_batch_threshold += self.num_speculative_tokens def build(self, From 10e6d47031fa61f96c60caf9e5f2d6a79a031611 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Sun, 28 Sep 2025 10:45:31 -0700 Subject: [PATCH 37/82] fix test --- .../attention/test_deepgemm_attention.py | 143 +++++++--------- tests/kernels/attention/test_indexer.py | 161 +++++++++--------- vllm/utils/deep_gemm.py | 53 +++--- 3 files changed, 170 insertions(+), 187 deletions(-) diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index d6c7c4368de9..03cc6b930c94 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -1,16 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random + import pytest import torch from vllm.platforms import current_platform -from vllm.utils import has_deep_gemm, cdiv -from vllm.utils.deep_gemm import ( - _ceil_to_ue8m0, - fp8_mqa_logits, - calc_diff, - get_paged_mqa_logits_metadata, - fp8_paged_mqa_logits, -) +from vllm.utils import cdiv, has_deep_gemm +from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits, + fp8_paged_mqa_logits, get_num_sms, + get_paged_mqa_logits_metadata) def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: @@ -25,18 +24,17 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: device=x.device, dtype=torch.uint8, ) - x_fp8[:, : block_size * head_dim] = x_scaled.view( - num_blocks, block_size * head_dim - ).view(dtype=torch.uint8) - x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view( - dtype=torch.uint8 - ) + x_fp8[:, :block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim).view(dtype=torch.uint8) + x_fp8[:, + block_size * head_dim:] = sf.view(num_blocks, + block_size).view(dtype=torch.uint8) return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) def per_custom_dims_cast_to_fp8( - x: torch.Tensor, dims: tuple, use_ue8m0: bool -) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, dims: tuple, + use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]: excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) sf = x_amax / 448.0 @@ -71,17 +69,13 @@ def _ref_fp8_mqa_logits( q = q.float() k = k.float() - mask_lo = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] - >= cu_seqlen_ks[:, None] - ) - mask_hi = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] - < cu_seqlen_ke[:, None] - ) + mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] + >= cu_seqlen_ks[:, None]) + mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] + < cu_seqlen_ke[:, None]) mask = mask_lo & mask_hi - score = torch.einsum("mhd,nd->hmn", q, k) + score = torch.einsum("mhd,and->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) logits = logits.masked_fill(~mask, float("-inf")) @@ -94,8 +88,8 @@ def test_deepgemm_fp8_mqa_logits(): 
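# --- illustrative sketch, not part of the patch ---------------------------
# _ref_fp8_mqa_logits above restricts every query row to the half-open KV
# window [cu_seqlen_ks[i], cu_seqlen_ke[i]) before the ReLU-weighted score
# reduction. A minimal standalone reproduction of just that masking step
# (the tensors below are made-up example values):
import torch

ks = torch.tensor([0, 0, 0])            # per-query window start (inclusive)
ke = torch.tensor([2, 3, 4])            # per-query window end (exclusive)
seq_len_kv = 4
pos = torch.arange(seq_len_kv)[None, :]               # [1, KV]
mask = (pos >= ks[:, None]) & (pos < ke[:, None])     # [Q, KV]
# mask == tensor([[ True,  True, False, False],
#                 [ True,  True,  True, False],
#                 [ True,  True,  True,  True]])
# Scores outside the window are then set to -inf, matching
# logits.masked_fill(~mask, float("-inf")) in the reference, so the
# subsequent top-k never selects out-of-window KV positions.
# ---------------------------------------------------------------------------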
torch.manual_seed(0) random.seed(0) num_heads, head_dim = 32, 128 - for seq_len in (512,): - for seq_len_kv in (1024,): + for seq_len in (512, ): + for seq_len_kv in (1024, ): for disable_cp in (False, True): q = torch.randn( seq_len, @@ -104,23 +98,24 @@ def test_deepgemm_fp8_mqa_logits(): device="cuda", dtype=torch.bfloat16, ) - kv = torch.randn( - seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16 - ) - weights = torch.randn( - seq_len, num_heads, device="cuda", dtype=torch.float32 - ) + kv = torch.randn(seq_len_kv, + head_dim, + device="cuda", + dtype=torch.bfloat16) + weights = torch.randn(seq_len, + num_heads, + device="cuda", + dtype=torch.float32) if disable_cp: ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") - ke = torch.arange( - seq_len, dtype=torch.int, device="cuda" - ) + (seq_len_kv - seq_len) + ke = torch.arange(seq_len, dtype=torch.int, + device="cuda") + (seq_len_kv - seq_len) else: ks, ke = _generate_cp_test_data(seq_len, seq_len_kv) q_fp8 = q.to(torch.float8_e4m3fn) - kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False) logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) ref_logits = _ref_fp8_mqa_logits( @@ -160,14 +155,11 @@ def _ref_fp8_paged_mqa_logits( context_lens_list = context_lens.tolist() for i in range(batch_size): context_len = context_lens_list[i] - q_offsets = torch.arange( - context_len - next_n, context_len, device="cuda" - ) - weight_slice = ( - weights[i * next_n : (i + 1) * next_n, :] - .transpose(0, 1) - .contiguous() - ) + q_offsets = torch.arange(context_len - next_n, + context_len, + device="cuda") + weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose( + 0, 1).contiguous()) for block_rk in range(cdiv(context_len, block_size)): block_idx = block_tables[i][block_rk] qx, kx = q[i], kv_cache[block_idx] @@ -176,24 +168,21 @@ def _ref_fp8_paged_mqa_logits( (block_rk + 1) * block_size, device="cuda", ) - mask = (k_offsets[None, :] < context_len) & ( - k_offsets[None, :] <= q_offsets[:, None] - ) + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] + <= q_offsets[:, None]) s = torch.where( mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( - logits.dtype - ), + logits.dtype), float("-inf"), ) s = torch.relu(s) * weight_slice[..., None] s = s.sum(dim=0) logits[ - i * next_n : (i + 1) * next_n, - block_rk * block_size : (block_rk + 1) * block_size, - ] = torch.where( - k_offsets[None, :] <= q_offsets[:, None], s, float("-inf") - ) + i * next_n:(i + 1) * next_n, + block_rk * block_size:(block_rk + 1) * block_size, + ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, + float("-inf")) return logits @@ -205,8 +194,8 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len = 4096 for batch_size, next_n in [(4, 1), (2, 2)]: - for heads, index_dim in [(16, 128)]: - for avg_kv in (2048,): + for heads, index_dim in [(32, 128)]: + for avg_kv in (2048, ): num_blocks, blocksize = max_model_len * 2, 64 q = torch.randn( @@ -225,18 +214,12 @@ def test_deepgemm_fp8_paged_mqa_logits(): dtype=torch.float32, ) - context_lens = ( - torch.randint( - int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,) - ) - .cuda() - .to(torch.int32) - ) - max_block_len = ( - (context_lens.max().item() + blocksize - 1) - // blocksize - * blocksize - ) + context_lens = (torch.randint(int(0.8 * avg_kv), + int(1.2 * avg_kv), + (batch_size, )).cuda().to( + torch.int32)) + max_block_len = ((context_lens.max().item() + blocksize - 1) // + blocksize * 
blocksize) block_tables = torch.zeros( (batch_size, max_block_len), device="cuda", @@ -256,8 +239,7 @@ def test_deepgemm_fp8_paged_mqa_logits(): kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) schedule_metadata = get_paged_mqa_logits_metadata( - context_lens, blocksize, 132 - ) + context_lens, blocksize, get_num_sms()) logits = fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, @@ -277,20 +259,15 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len, ) - positions = ( - torch.arange(max_model_len, device="cuda") - .unsqueeze(0) - .expand(batch_size * next_n, -1) - ) + positions = (torch.arange(max_model_len, + device="cuda").unsqueeze(0).expand( + batch_size * next_n, -1)) row_indices = ( - torch.arange(batch_size * next_n, device="cuda") // next_n - ) + torch.arange(batch_size * next_n, device="cuda") // next_n) next_n_offset = ( - torch.arange(batch_size * next_n, device="cuda") % next_n - ) - mask = positions <= ( - context_lens[row_indices] - next_n + next_n_offset - ).unsqueeze(1) + torch.arange(batch_size * next_n, device="cuda") % next_n) + mask = positions <= (context_lens[row_indices] - next_n + + next_n_offset).unsqueeze(1) logits = logits.masked_fill(~mask, 0) ref_logits = ref_logits.masked_fill(~mask, 0) diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py index 1121892f5172..696e1587037f 100644 --- a/tests/kernels/attention/test_indexer.py +++ b/tests/kernels/attention/test_indexer.py @@ -1,18 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random import torch +from vllm import _custom_ops as ops from vllm.utils import cdiv -from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches +from vllm.utils.deep_gemm import (calc_diff, fp8_mqa_logits, + fp8_paged_mqa_logits, get_num_sms, + get_paged_mqa_logits_metadata) from vllm.utils.tile_lang_kernels import act_quant, fp8_index -from vllm import _custom_ops as ops -from vllm.model_executor.models.deepseek_v2 import indexer_k_quant_and_cache -from vllm.utils.deep_gemm import ( - fp8_mqa_logits, - calc_diff, - get_paged_mqa_logits_metadata, - fp8_paged_mqa_logits, -) +from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches + def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: num_blocks, block_size, num_heads, head_dim = x.shape @@ -30,47 +29,56 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: block_size).view(dtype=torch.uint8) return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + def ref_compute_logits_fp8(q, kv, weights, mask, block_size): q_fp8, q_scale = act_quant(q, block_size, "ue8m0") k_fp8, k_scale = act_quant(kv, block_size, "ue8m0") - + weights = weights.unsqueeze(-1) * q_scale weights = weights * (128**(-0.5)) - index_score = fp8_index( - q_fp8.contiguous(), weights, - k_fp8.contiguous(), - k_scale.contiguous()) + index_score = fp8_index(q_fp8.contiguous(), weights, k_fp8.contiguous(), + k_scale.contiguous()) if mask is not None: index_score += mask return index_score + def ref_indexer(seq_len, q, kv, weights, block_size, topk): B = seq_len.shape[0] total_seqlen = torch.sum(seq_len) - varlen_logits = torch.full((total_seqlen, total_seqlen), float("-inf"), device="cuda") - + varlen_logits = torch.full((total_seqlen, total_seqlen), + float("-inf"), + device="cuda") + current_context_ptr = 0 for i in range(B): S = seq_len[i] q_s = q[i][:S].contiguous().unsqueeze(0) kv_s = kv[i][:S].contiguous().unsqueeze(0) weights_s = 
weights[i][:S].contiguous().unsqueeze(0) - mask = torch.full( - (S, S), float("-inf"), - device="cuda").triu_(1) + mask = torch.full((S, S), float("-inf"), device="cuda").triu_(1) logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask, block_size) logits = logits.squeeze(0) - - varlen_logits[current_context_ptr:current_context_ptr + S, current_context_ptr: current_context_ptr + S] = logits + + varlen_logits[current_context_ptr:current_context_ptr + S, + current_context_ptr:current_context_ptr + S] = logits current_context_ptr += S return varlen_logits -def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, topk, is_kv_batched=True): + +def deepgemm_mqa_indexer(seq_len, + query_seq_len, + q, + kv, + weights, + block_size, + topk, + is_kv_batched=True): B = seq_len.shape[0] concat_q = [] concat_kv = [] concat_weights = [] - + for i in range(B): S = seq_len[i] q_s = q[i][:S].contiguous() @@ -81,29 +89,25 @@ def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, top if is_kv_batched: concat_kv.append(kv_s) concat_weights.append(weight_s) - + q = torch.cat(concat_q, dim=0) if is_kv_batched: kv = torch.cat(concat_kv, dim=0) weights = torch.cat(concat_weights, dim=0) q_fp8, q_scale = act_quant(q, block_size, "ue8m0") kv_fp8, kv_scale = act_quant(kv, block_size, "ue8m0") - + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale weights = weights.squeeze(-1) query_start_loc = torch.empty((B + 1), device="cuda") query_start_loc[0] = 0 query_start_loc[1:] = query_seq_len.cumsum(dim=0).to(dtype=torch.int32) - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_len) + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, + seq_len) - logits = fp8_mqa_logits( - q_fp8, - (kv_fp8, kv_scale), - weights, - cu_seqlen_ks, - cu_seqlen_ke - ) + logits = fp8_mqa_logits(q_fp8, (kv_fp8, kv_scale), weights, cu_seqlen_ks, + cu_seqlen_ke) topk_indices = logits.topk(topk, dim=-1)[1] mask_lo = topk_indices >= cu_seqlen_ks[:, None] mask_hi = topk_indices < cu_seqlen_ke[:, None] @@ -111,6 +115,7 @@ def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, top topk_indices = topk_indices.masked_fill(~mask, -1) return logits + def test_prefill_indexer(): B = 3 S = 128 @@ -121,16 +126,16 @@ def test_prefill_indexer(): block_size = 128 topk = 64 device = "cuda" - seq_len = torch.randint(low=64, high=S, size=(B,)) - - q = torch.randn(B, S, H, D, device="cuda", - dtype=torch.bfloat16) - kv = torch.randn(B, SKV, D, device="cuda", - dtype=torch.bfloat16) - weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5 + seq_len = torch.randint(low=64, high=S, size=(B, )) + + q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(B, SKV, D, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(B, S, H, device=device, + dtype=torch.float32) * H**-0.5 ref_logits = ref_indexer(seq_len, q, kv, weights, block_size, topk) - deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, block_size, topk) + deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, + block_size, topk) torch.testing.assert_close(ref_logits, deepgemm_logits) @@ -145,26 +150,23 @@ def test_decode_paged_indexer(): block_size = 128 topk = 64 device = "cuda" - seq_len = torch.randint(low=64, high=S, size=(B,), device="cuda") + seq_len = torch.randint(low=64, high=S, size=(B, ), device="cuda") query_seq_len = torch.ones(B, device="cuda") - q = torch.randn((B, 1, H, D), - device='cuda', 
- dtype=torch.bfloat16) + q = torch.randn((B, 1, H, D), device='cuda', dtype=torch.bfloat16) kv_cache = torch.randn((num_blocks, blocksize, 1, D), - device='cuda', - dtype=torch.bfloat16) - weights = torch.randn((B * 1, H), - device='cuda', - dtype=torch.float32) * H**-0.5 + device='cuda', + dtype=torch.bfloat16) + weights = torch.randn( + (B * 1, H), device='cuda', dtype=torch.float32) * H**-0.5 max_block_len = (seq_len.max().item() + blocksize - - 1) // blocksize * blocksize - + 1) // blocksize * blocksize + block_tables = torch.zeros((B, max_block_len), - device='cuda', - dtype=torch.int32) - + device='cuda', + dtype=torch.int32) + counter = 0 block_idx_pool = list(range(num_blocks)) random.shuffle(block_idx_pool) @@ -173,51 +175,58 @@ def test_decode_paged_indexer(): for j in range(cdiv(ctx_len, blocksize)): block_tables[i][j] = block_idx_pool[counter] counter += 1 - - flatten_kv = torch.empty( - [seq_len.sum(), D], device="cuda", dtype=torch.bfloat16 - ) + + flatten_kv = torch.empty([seq_len.sum(), D], + device="cuda", + dtype=torch.bfloat16) cu_seq_lens = torch.cat([ - torch.zeros(1, dtype=torch.int32, device=device), - seq_len.cumsum(dim=0) - ]).to(torch.int32).cuda() + torch.zeros(1, dtype=torch.int32, device=device), + seq_len.cumsum(dim=0) + ]).to(torch.int32).cuda() ops.cp_gather_cache( - kv_cache, + kv_cache, flatten_kv, block_tables, cu_seq_lens, B, ) - - ref_logits = deepgemm_mqa_indexer(seq_len, query_seq_len, q, flatten_kv, weights, block_size, topk, is_kv_batched=False) + + ref_logits = deepgemm_mqa_indexer(seq_len, + query_seq_len, + q, + flatten_kv, + weights, + block_size, + topk, + is_kv_batched=False) q_fp8, q_scale = act_quant(q, block_size, "ue8m0") kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - schedule_metadata = get_paged_mqa_logits_metadata( - seq_len.int(), blocksize, 132) - + schedule_metadata = get_paged_mqa_logits_metadata(seq_len.int(), blocksize, + get_num_sms()) + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale.squeeze(1) weights = weights.squeeze(-1) - - logits = fp8_paged_mqa_logits( - q_fp8, kv_cache_fp8, weights, seq_len.int(), block_tables, - schedule_metadata, 4096) - + + logits = fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, seq_len.int(), + block_tables, schedule_metadata, 4096) + concat_logit = [] context = 0 for i in range(B): per_seq_logits = torch.zeros(4096, device="cuda") S = seq_len[i] - per_seq_logits[:S] = ref_logits[i][context: context + S] + per_seq_logits[:S] = ref_logits[i][context:context + S] concat_logit.append(per_seq_logits) context += S ref_logits = torch.stack(concat_logit, dim=0) logits[logits == float("-inf")] = 0 diff = calc_diff(logits, ref_logits) assert diff < 1e-3, f"{diff=}" - + + if __name__ == "__main__": test_prefill_indexer() - test_decode_paged_indexer() \ No newline at end of file + test_decode_paged_indexer() diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index c264a814bebb..56ccff507612 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -82,14 +82,11 @@ def _lazy_init() -> None: global _get_paged_mqa_logits_metadata_impl # fast path - if ( - _fp8_gemm_nt_impl is not None - or _grouped_impl is not None - or _grouped_masked_impl is not None - or _fp8_mqa_logits_impl is not None - or _fp8_paged_mqa_logits_impl is not None - or _get_paged_mqa_logits_metadata_impl is not None - ): + if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or 
_get_paged_mqa_logits_metadata_impl is not None): return if not has_deep_gemm(): @@ -109,8 +106,13 @@ def _lazy_init() -> None: _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) _get_paged_mqa_logits_metadata_impl = getattr( - _dg, "get_paged_mqa_logits_metadata", None - ) + _dg, "get_paged_mqa_logits_metadata", None) + + +def get_num_sms() -> int: + _lazy_init() + _dg = importlib.import_module("deep_gemm") + return int(_dg.get_num_sms()) def fp8_gemm_nt(*args, **kwargs): @@ -169,10 +171,8 @@ def fp8_mqa_logits( return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) - -def get_paged_mqa_logits_metadata( - context_lens: torch.Tensor, block_size: int, num_sms: int -) -> torch.Tensor: +def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int, + num_sms: int) -> torch.Tensor: """Build scheduling metadata for paged MQA logits. Args: @@ -188,9 +188,8 @@ def get_paged_mqa_logits_metadata( _lazy_init() if _get_paged_mqa_logits_metadata_impl is None: return _missing() - return _get_paged_mqa_logits_metadata_impl( - context_lens, block_size, num_sms - ) + return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, + num_sms) def fp8_paged_mqa_logits( @@ -226,17 +225,14 @@ def fp8_paged_mqa_logits( _lazy_init() if _fp8_paged_mqa_logits_impl is None: return _missing() - return _fp8_paged_mqa_logits_impl( - q_fp8, - kv_cache_fp8, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=True - ) - + return _fp8_paged_mqa_logits_impl(q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True) def _ceil_to_ue8m0(x: torch.Tensor): @@ -305,5 +301,6 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", + "get_num_sms", "should_use_deepgemm_for_fp8_linear", ] From 6853a0e91a3fa8bdc54f0f82cde997cafa6e2196 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Sun, 28 Sep 2025 13:49:01 -0400 Subject: [PATCH 38/82] Revert "[Bug] Fix test for Blackwell" --- .../attention/test_deepgemm_attention.py | 143 +++++++++------- tests/kernels/attention/test_indexer.py | 161 +++++++++--------- vllm/utils/deep_gemm.py | 53 +++--- 3 files changed, 187 insertions(+), 170 deletions(-) diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index 03cc6b930c94..d6c7c4368de9 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -1,15 +1,16 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random - import pytest import torch from vllm.platforms import current_platform -from vllm.utils import cdiv, has_deep_gemm -from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits, - fp8_paged_mqa_logits, get_num_sms, - get_paged_mqa_logits_metadata) +from vllm.utils import has_deep_gemm, cdiv +from vllm.utils.deep_gemm import ( + _ceil_to_ue8m0, + fp8_mqa_logits, + calc_diff, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, +) def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: @@ -24,17 +25,18 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: device=x.device, dtype=torch.uint8, ) - x_fp8[:, :block_size * head_dim] = x_scaled.view( - 
num_blocks, block_size * head_dim).view(dtype=torch.uint8) - x_fp8[:, - block_size * head_dim:] = sf.view(num_blocks, - block_size).view(dtype=torch.uint8) + x_fp8[:, : block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim + ).view(dtype=torch.uint8) + x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view( + dtype=torch.uint8 + ) return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) def per_custom_dims_cast_to_fp8( - x: torch.Tensor, dims: tuple, - use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, dims: tuple, use_ue8m0: bool +) -> tuple[torch.Tensor, torch.Tensor]: excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) sf = x_amax / 448.0 @@ -69,13 +71,17 @@ def _ref_fp8_mqa_logits( q = q.float() k = k.float() - mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - >= cu_seqlen_ks[:, None]) - mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - < cu_seqlen_ke[:, None]) + mask_lo = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] + >= cu_seqlen_ks[:, None] + ) + mask_hi = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] + < cu_seqlen_ke[:, None] + ) mask = mask_lo & mask_hi - score = torch.einsum("mhd,and->hmn", q, k) + score = torch.einsum("mhd,nd->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) logits = logits.masked_fill(~mask, float("-inf")) @@ -88,8 +94,8 @@ def test_deepgemm_fp8_mqa_logits(): torch.manual_seed(0) random.seed(0) num_heads, head_dim = 32, 128 - for seq_len in (512, ): - for seq_len_kv in (1024, ): + for seq_len in (512,): + for seq_len_kv in (1024,): for disable_cp in (False, True): q = torch.randn( seq_len, @@ -98,24 +104,23 @@ def test_deepgemm_fp8_mqa_logits(): device="cuda", dtype=torch.bfloat16, ) - kv = torch.randn(seq_len_kv, - head_dim, - device="cuda", - dtype=torch.bfloat16) - weights = torch.randn(seq_len, - num_heads, - device="cuda", - dtype=torch.float32) + kv = torch.randn( + seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16 + ) + weights = torch.randn( + seq_len, num_heads, device="cuda", dtype=torch.float32 + ) if disable_cp: ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") - ke = torch.arange(seq_len, dtype=torch.int, - device="cuda") + (seq_len_kv - seq_len) + ke = torch.arange( + seq_len, dtype=torch.int, device="cuda" + ) + (seq_len_kv - seq_len) else: ks, ke = _generate_cp_test_data(seq_len, seq_len_kv) q_fp8 = q.to(torch.float8_e4m3fn) - kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) ref_logits = _ref_fp8_mqa_logits( @@ -155,11 +160,14 @@ def _ref_fp8_paged_mqa_logits( context_lens_list = context_lens.tolist() for i in range(batch_size): context_len = context_lens_list[i] - q_offsets = torch.arange(context_len - next_n, - context_len, - device="cuda") - weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose( - 0, 1).contiguous()) + q_offsets = torch.arange( + context_len - next_n, context_len, device="cuda" + ) + weight_slice = ( + weights[i * next_n : (i + 1) * next_n, :] + .transpose(0, 1) + .contiguous() + ) for block_rk in range(cdiv(context_len, block_size)): block_idx = block_tables[i][block_rk] qx, kx = q[i], kv_cache[block_idx] @@ -168,21 +176,24 @@ def _ref_fp8_paged_mqa_logits( (block_rk + 1) * block_size, device="cuda", ) - mask = (k_offsets[None, 
:] < context_len) & (k_offsets[None, :] - <= q_offsets[:, None]) + mask = (k_offsets[None, :] < context_len) & ( + k_offsets[None, :] <= q_offsets[:, None] + ) s = torch.where( mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( - logits.dtype), + logits.dtype + ), float("-inf"), ) s = torch.relu(s) * weight_slice[..., None] s = s.sum(dim=0) logits[ - i * next_n:(i + 1) * next_n, - block_rk * block_size:(block_rk + 1) * block_size, - ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, - float("-inf")) + i * next_n : (i + 1) * next_n, + block_rk * block_size : (block_rk + 1) * block_size, + ] = torch.where( + k_offsets[None, :] <= q_offsets[:, None], s, float("-inf") + ) return logits @@ -194,8 +205,8 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len = 4096 for batch_size, next_n in [(4, 1), (2, 2)]: - for heads, index_dim in [(32, 128)]: - for avg_kv in (2048, ): + for heads, index_dim in [(16, 128)]: + for avg_kv in (2048,): num_blocks, blocksize = max_model_len * 2, 64 q = torch.randn( @@ -214,12 +225,18 @@ def test_deepgemm_fp8_paged_mqa_logits(): dtype=torch.float32, ) - context_lens = (torch.randint(int(0.8 * avg_kv), - int(1.2 * avg_kv), - (batch_size, )).cuda().to( - torch.int32)) - max_block_len = ((context_lens.max().item() + blocksize - 1) // - blocksize * blocksize) + context_lens = ( + torch.randint( + int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,) + ) + .cuda() + .to(torch.int32) + ) + max_block_len = ( + (context_lens.max().item() + blocksize - 1) + // blocksize + * blocksize + ) block_tables = torch.zeros( (batch_size, max_block_len), device="cuda", @@ -239,7 +256,8 @@ def test_deepgemm_fp8_paged_mqa_logits(): kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) schedule_metadata = get_paged_mqa_logits_metadata( - context_lens, blocksize, get_num_sms()) + context_lens, blocksize, 132 + ) logits = fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, @@ -259,15 +277,20 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len, ) - positions = (torch.arange(max_model_len, - device="cuda").unsqueeze(0).expand( - batch_size * next_n, -1)) + positions = ( + torch.arange(max_model_len, device="cuda") + .unsqueeze(0) + .expand(batch_size * next_n, -1) + ) row_indices = ( - torch.arange(batch_size * next_n, device="cuda") // next_n) + torch.arange(batch_size * next_n, device="cuda") // next_n + ) next_n_offset = ( - torch.arange(batch_size * next_n, device="cuda") % next_n) - mask = positions <= (context_lens[row_indices] - next_n + - next_n_offset).unsqueeze(1) + torch.arange(batch_size * next_n, device="cuda") % next_n + ) + mask = positions <= ( + context_lens[row_indices] - next_n + next_n_offset + ).unsqueeze(1) logits = logits.masked_fill(~mask, 0) ref_logits = ref_logits.masked_fill(~mask, 0) diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py index 696e1587037f..1121892f5172 100644 --- a/tests/kernels/attention/test_indexer.py +++ b/tests/kernels/attention/test_indexer.py @@ -1,17 +1,18 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random import torch -from vllm import _custom_ops as ops from vllm.utils import cdiv -from vllm.utils.deep_gemm import (calc_diff, fp8_mqa_logits, - fp8_paged_mqa_logits, get_num_sms, - get_paged_mqa_logits_metadata) -from vllm.utils.tile_lang_kernels import act_quant, fp8_index from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches - +from vllm.utils.tile_lang_kernels import 
act_quant, fp8_index +from vllm import _custom_ops as ops +from vllm.model_executor.models.deepseek_v2 import indexer_k_quant_and_cache +from vllm.utils.deep_gemm import ( + fp8_mqa_logits, + calc_diff, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, +) def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: num_blocks, block_size, num_heads, head_dim = x.shape @@ -29,56 +30,47 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: block_size).view(dtype=torch.uint8) return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) - def ref_compute_logits_fp8(q, kv, weights, mask, block_size): q_fp8, q_scale = act_quant(q, block_size, "ue8m0") k_fp8, k_scale = act_quant(kv, block_size, "ue8m0") - + weights = weights.unsqueeze(-1) * q_scale weights = weights * (128**(-0.5)) - index_score = fp8_index(q_fp8.contiguous(), weights, k_fp8.contiguous(), - k_scale.contiguous()) + index_score = fp8_index( + q_fp8.contiguous(), weights, + k_fp8.contiguous(), + k_scale.contiguous()) if mask is not None: index_score += mask return index_score - def ref_indexer(seq_len, q, kv, weights, block_size, topk): B = seq_len.shape[0] total_seqlen = torch.sum(seq_len) - varlen_logits = torch.full((total_seqlen, total_seqlen), - float("-inf"), - device="cuda") - + varlen_logits = torch.full((total_seqlen, total_seqlen), float("-inf"), device="cuda") + current_context_ptr = 0 for i in range(B): S = seq_len[i] q_s = q[i][:S].contiguous().unsqueeze(0) kv_s = kv[i][:S].contiguous().unsqueeze(0) weights_s = weights[i][:S].contiguous().unsqueeze(0) - mask = torch.full((S, S), float("-inf"), device="cuda").triu_(1) + mask = torch.full( + (S, S), float("-inf"), + device="cuda").triu_(1) logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask, block_size) logits = logits.squeeze(0) - - varlen_logits[current_context_ptr:current_context_ptr + S, - current_context_ptr:current_context_ptr + S] = logits + + varlen_logits[current_context_ptr:current_context_ptr + S, current_context_ptr: current_context_ptr + S] = logits current_context_ptr += S return varlen_logits - -def deepgemm_mqa_indexer(seq_len, - query_seq_len, - q, - kv, - weights, - block_size, - topk, - is_kv_batched=True): +def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, topk, is_kv_batched=True): B = seq_len.shape[0] concat_q = [] concat_kv = [] concat_weights = [] - + for i in range(B): S = seq_len[i] q_s = q[i][:S].contiguous() @@ -89,25 +81,29 @@ def deepgemm_mqa_indexer(seq_len, if is_kv_batched: concat_kv.append(kv_s) concat_weights.append(weight_s) - + q = torch.cat(concat_q, dim=0) if is_kv_batched: kv = torch.cat(concat_kv, dim=0) weights = torch.cat(concat_weights, dim=0) q_fp8, q_scale = act_quant(q, block_size, "ue8m0") kv_fp8, kv_scale = act_quant(kv, block_size, "ue8m0") - + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale weights = weights.squeeze(-1) query_start_loc = torch.empty((B + 1), device="cuda") query_start_loc[0] = 0 query_start_loc[1:] = query_seq_len.cumsum(dim=0).to(dtype=torch.int32) - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, - seq_len) + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_len) - logits = fp8_mqa_logits(q_fp8, (kv_fp8, kv_scale), weights, cu_seqlen_ks, - cu_seqlen_ke) + logits = fp8_mqa_logits( + q_fp8, + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke + ) topk_indices = logits.topk(topk, dim=-1)[1] mask_lo = topk_indices >= cu_seqlen_ks[:, None] mask_hi = topk_indices < cu_seqlen_ke[:, None] @@ 
-115,7 +111,6 @@ def deepgemm_mqa_indexer(seq_len, topk_indices = topk_indices.masked_fill(~mask, -1) return logits - def test_prefill_indexer(): B = 3 S = 128 @@ -126,16 +121,16 @@ def test_prefill_indexer(): block_size = 128 topk = 64 device = "cuda" - seq_len = torch.randint(low=64, high=S, size=(B, )) - - q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) - kv = torch.randn(B, SKV, D, device="cuda", dtype=torch.bfloat16) - weights = torch.randn(B, S, H, device=device, - dtype=torch.float32) * H**-0.5 + seq_len = torch.randint(low=64, high=S, size=(B,)) + + q = torch.randn(B, S, H, D, device="cuda", + dtype=torch.bfloat16) + kv = torch.randn(B, SKV, D, device="cuda", + dtype=torch.bfloat16) + weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5 ref_logits = ref_indexer(seq_len, q, kv, weights, block_size, topk) - deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, - block_size, topk) + deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, block_size, topk) torch.testing.assert_close(ref_logits, deepgemm_logits) @@ -150,23 +145,26 @@ def test_decode_paged_indexer(): block_size = 128 topk = 64 device = "cuda" - seq_len = torch.randint(low=64, high=S, size=(B, ), device="cuda") + seq_len = torch.randint(low=64, high=S, size=(B,), device="cuda") query_seq_len = torch.ones(B, device="cuda") - q = torch.randn((B, 1, H, D), device='cuda', dtype=torch.bfloat16) + q = torch.randn((B, 1, H, D), + device='cuda', + dtype=torch.bfloat16) kv_cache = torch.randn((num_blocks, blocksize, 1, D), - device='cuda', - dtype=torch.bfloat16) - weights = torch.randn( - (B * 1, H), device='cuda', dtype=torch.float32) * H**-0.5 + device='cuda', + dtype=torch.bfloat16) + weights = torch.randn((B * 1, H), + device='cuda', + dtype=torch.float32) * H**-0.5 max_block_len = (seq_len.max().item() + blocksize - - 1) // blocksize * blocksize - + 1) // blocksize * blocksize + block_tables = torch.zeros((B, max_block_len), - device='cuda', - dtype=torch.int32) - + device='cuda', + dtype=torch.int32) + counter = 0 block_idx_pool = list(range(num_blocks)) random.shuffle(block_idx_pool) @@ -175,58 +173,51 @@ def test_decode_paged_indexer(): for j in range(cdiv(ctx_len, blocksize)): block_tables[i][j] = block_idx_pool[counter] counter += 1 - - flatten_kv = torch.empty([seq_len.sum(), D], - device="cuda", - dtype=torch.bfloat16) + + flatten_kv = torch.empty( + [seq_len.sum(), D], device="cuda", dtype=torch.bfloat16 + ) cu_seq_lens = torch.cat([ - torch.zeros(1, dtype=torch.int32, device=device), - seq_len.cumsum(dim=0) - ]).to(torch.int32).cuda() + torch.zeros(1, dtype=torch.int32, device=device), + seq_len.cumsum(dim=0) + ]).to(torch.int32).cuda() ops.cp_gather_cache( - kv_cache, + kv_cache, flatten_kv, block_tables, cu_seq_lens, B, ) - - ref_logits = deepgemm_mqa_indexer(seq_len, - query_seq_len, - q, - flatten_kv, - weights, - block_size, - topk, - is_kv_batched=False) + + ref_logits = deepgemm_mqa_indexer(seq_len, query_seq_len, q, flatten_kv, weights, block_size, topk, is_kv_batched=False) q_fp8, q_scale = act_quant(q, block_size, "ue8m0") kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - schedule_metadata = get_paged_mqa_logits_metadata(seq_len.int(), blocksize, - get_num_sms()) - + schedule_metadata = get_paged_mqa_logits_metadata( + seq_len.int(), blocksize, 132) + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale.squeeze(1) weights = weights.squeeze(-1) - - logits = fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, seq_len.int(), - 
block_tables, schedule_metadata, 4096) - + + logits = fp8_paged_mqa_logits( + q_fp8, kv_cache_fp8, weights, seq_len.int(), block_tables, + schedule_metadata, 4096) + concat_logit = [] context = 0 for i in range(B): per_seq_logits = torch.zeros(4096, device="cuda") S = seq_len[i] - per_seq_logits[:S] = ref_logits[i][context:context + S] + per_seq_logits[:S] = ref_logits[i][context: context + S] concat_logit.append(per_seq_logits) context += S ref_logits = torch.stack(concat_logit, dim=0) logits[logits == float("-inf")] = 0 diff = calc_diff(logits, ref_logits) assert diff < 1e-3, f"{diff=}" - - + if __name__ == "__main__": test_prefill_indexer() - test_decode_paged_indexer() + test_decode_paged_indexer() \ No newline at end of file diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 56ccff507612..c264a814bebb 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -82,11 +82,14 @@ def _lazy_init() -> None: global _get_paged_mqa_logits_metadata_impl # fast path - if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None - or _grouped_masked_impl is not None - or _fp8_mqa_logits_impl is not None - or _fp8_paged_mqa_logits_impl is not None - or _get_paged_mqa_logits_metadata_impl is not None): + if ( + _fp8_gemm_nt_impl is not None + or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or _get_paged_mqa_logits_metadata_impl is not None + ): return if not has_deep_gemm(): @@ -106,13 +109,8 @@ def _lazy_init() -> None: _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) _get_paged_mqa_logits_metadata_impl = getattr( - _dg, "get_paged_mqa_logits_metadata", None) - - -def get_num_sms() -> int: - _lazy_init() - _dg = importlib.import_module("deep_gemm") - return int(_dg.get_num_sms()) + _dg, "get_paged_mqa_logits_metadata", None + ) def fp8_gemm_nt(*args, **kwargs): @@ -171,8 +169,10 @@ def fp8_mqa_logits( return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) -def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int, - num_sms: int) -> torch.Tensor: + +def get_paged_mqa_logits_metadata( + context_lens: torch.Tensor, block_size: int, num_sms: int +) -> torch.Tensor: """Build scheduling metadata for paged MQA logits. 
Args: @@ -188,8 +188,9 @@ def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int, _lazy_init() if _get_paged_mqa_logits_metadata_impl is None: return _missing() - return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, - num_sms) + return _get_paged_mqa_logits_metadata_impl( + context_lens, block_size, num_sms + ) def fp8_paged_mqa_logits( @@ -225,14 +226,17 @@ def fp8_paged_mqa_logits( _lazy_init() if _fp8_paged_mqa_logits_impl is None: return _missing() - return _fp8_paged_mqa_logits_impl(q_fp8, - kv_cache_fp8, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=True) + return _fp8_paged_mqa_logits_impl( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True + ) + def _ceil_to_ue8m0(x: torch.Tensor): @@ -301,6 +305,5 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", - "get_num_sms", "should_use_deepgemm_for_fp8_linear", ] From ed9e42c1ac55fbaea0b9bc39dce79eba12b874ec Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Sun, 28 Sep 2025 10:50:39 -0700 Subject: [PATCH 39/82] fix test --- tests/kernels/attention/test_deepgemm_attention.py | 5 +++-- tests/kernels/attention/test_indexer.py | 3 ++- vllm/utils/deep_gemm.py | 7 +++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index d6c7c4368de9..50c547b84be6 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -10,6 +10,7 @@ calc_diff, get_paged_mqa_logits_metadata, fp8_paged_mqa_logits, + get_num_sms, ) @@ -205,7 +206,7 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len = 4096 for batch_size, next_n in [(4, 1), (2, 2)]: - for heads, index_dim in [(16, 128)]: + for heads, index_dim in [(32, 128)]: for avg_kv in (2048,): num_blocks, blocksize = max_model_len * 2, 64 @@ -256,7 +257,7 @@ def test_deepgemm_fp8_paged_mqa_logits(): kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) schedule_metadata = get_paged_mqa_logits_metadata( - context_lens, blocksize, 132 + context_lens, blocksize, get_num_sms() ) logits = fp8_paged_mqa_logits( q_fp8, diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py index 1121892f5172..5ed6c212e528 100644 --- a/tests/kernels/attention/test_indexer.py +++ b/tests/kernels/attention/test_indexer.py @@ -12,6 +12,7 @@ calc_diff, get_paged_mqa_logits_metadata, fp8_paged_mqa_logits, + get_num_sms, ) def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: @@ -196,7 +197,7 @@ def test_decode_paged_indexer(): kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) schedule_metadata = get_paged_mqa_logits_metadata( - seq_len.int(), blocksize, 132) + seq_len.int(), blocksize, get_num_sms()) weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale.squeeze(1) weights = weights.squeeze(-1) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index c264a814bebb..9b95de373fe9 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -113,6 +113,12 @@ def _lazy_init() -> None: ) +def get_num_sms() -> int: + _lazy_init() + _dg = importlib.import_module("deep_gemm") + return int(_dg.get_num_sms()) + + def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: @@ -305,5 +311,6 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: 
torch.dtype, "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", + "get_num_sms", "should_use_deepgemm_for_fp8_linear", ] From 656ab3c6296d48a437b8f35020b0d04da7cb7f8d Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Sun, 28 Sep 2025 15:09:31 -0400 Subject: [PATCH 40/82] fix num sms (#53) --- vllm/v1/attention/backends/mla/indexer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 83b363c85d9c..2cefef6206a7 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -3,12 +3,11 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.config import VllmConfig -from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata +from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, get_num_sms from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, split_decodes_and_prefills) import torch -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.logger import init_logger logger = init_logger(__name__) @@ -217,7 +216,7 @@ def build(self, if num_decodes > 0: seq_lens = common_attn_metadata.seq_lens[:num_decodes] schedule_metadata = get_paged_mqa_logits_metadata( - seq_lens, self.kv_cache_spec.block_size, 132) + seq_lens, self.kv_cache_spec.block_size, get_num_sms()) decode_metadata = DeepSeekV32IndexerDecodeMetadata( block_table=common_attn_metadata. block_table_tensor[:num_decodes, ...], From e744e06079606f2f179ad86abf7089d4352feb27 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 28 Sep 2025 20:33:55 -0400 Subject: [PATCH 41/82] FP8 Cache by using decode kernel only (#40) * Squashed commit of lwilkinson/decode-only changes relative to origin/dev Co-authored-by: Matthew Bonanni Signed-off-by: Lucas Wilkinson * update FlashMLA Signed-off-by: Lucas Wilkinson * fix non-spec error Signed-off-by: Lucas Wilkinson --------- Signed-off-by: Lucas Wilkinson Signed-off-by: Chen Zhang Co-authored-by: Matthew Bonanni --- cmake/external_projects/flashmla.cmake | 155 +++---- csrc/cache_kernels.cu | 152 ++++++- csrc/quantization/fp8/nvidia/quant_utils.cuh | 11 + setup.py | 3 +- tests/compile/test_fusion_attn.py | 1 - tests/kernels/attention/test_cache.py | 113 +++++ tests/kernels/attention/test_flashmla.py | 22 +- .../kernels/attention/test_flashmla_sparse.py | 38 +- tests/v1/attention/test_mla_backends.py | 70 ++- .../v1/attention/test_sparse_mla_backends.py | 426 ++++++++++++++++++ tests/v1/attention/utils.py | 1 - tests/v1/core/test_kv_cache_utils.py | 56 ++- tests/v1/core/test_prefix_caching.py | 1 - .../core/test_single_type_kv_cache_manager.py | 6 - tests/v1/engine/test_engine_core_client.py | 3 +- tests/v1/worker/test_gpu_model_runner.py | 1 - vllm/attention/backends/abstract.py | 1 + .../backends/differential_flash_attn.py | 1 + vllm/attention/backends/flash_attn.py | 1 + vllm/attention/backends/mla/common.py | 1 + vllm/attention/backends/placeholder_attn.py | 1 + vllm/attention/backends/rocm_flash_attn.py | 1 + vllm/attention/backends/xformers.py | 1 + vllm/attention/ops/flashmla.py | 194 +++----- vllm/attention/ops/paged_attn.py | 1 + vllm/model_executor/layers/mla.py | 11 +- vllm/model_executor/models/config.py | 19 + vllm/model_executor/models/deepseek_v2.py | 5 +- vllm/utils/__init__.py | 1 + vllm/v1/attention/backends/cpu_attn.py | 1 + 
vllm/v1/attention/backends/flash_attn.py | 1 + vllm/v1/attention/backends/flashinfer.py | 1 + vllm/v1/attention/backends/flex_attention.py | 1 + vllm/v1/attention/backends/mla/common.py | 247 +++++----- .../attention/backends/mla/flashmla_sparse.py | 366 ++++++++------- vllm/v1/attention/backends/mla/indexer.py | 1 + vllm/v1/attention/backends/pallas.py | 1 + vllm/v1/attention/backends/rocm_aiter_fa.py | 1 + vllm/v1/attention/backends/tree_attn.py | 1 + vllm/v1/attention/backends/triton_attn.py | 1 + vllm/v1/attention/backends/utils.py | 1 - vllm/v1/attention/backends/xformers.py | 1 + vllm/v1/core/kv_cache_utils.py | 2 - vllm/v1/core/single_type_kv_cache_manager.py | 2 + vllm/v1/kv_cache_interface.py | 51 ++- vllm/v1/worker/gpu_model_runner.py | 36 +- vllm/v1/worker/tpu_model_runner.py | 2 - vllm/worker/cache_engine.py | 3 +- 48 files changed, 1412 insertions(+), 605 deletions(-) create mode 100644 tests/v1/attention/test_sparse_mla_backends.py diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index b8a0b0394771..1e15cd168489 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-model-0920/FlashMLA - GIT_TAG a25b977fae6925c45c3d0404c98c6ce6f4563dac + GIT_TAG c2726ac45add214249698c7d7053851b9f3e54a4 GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -33,27 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later. # Only build FlashMLA kernels if we are building for something compatible with # sm90a -cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) - ####################################################################### - # FlashMLA Dense -- _flashmla_C - ####################################################################### + +set(SUPPORT_ARCHS) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3) + list(APPEND SUPPORT_ARCHS 9.0a) +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.9) + list(APPEND SUPPORT_ARCHS 10.0a) +endif() + + +cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}") +if(FLASH_MLA_ARCHS) + set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS}) + list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math") set(FlashMLA_SOURCES - ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu - ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) + ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp + ${flashmla_SOURCE_DIR}/csrc/pybind.cpp + ${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu + ${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu + ) + + set(FlashMLA_Extension_SOURCES + 
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu + ) set(FlashMLA_INCLUDES + ${flashmla_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/sm90 ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc) + ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include + ) + + set(FlashMLA_Extension_INCLUDES + ${flashmla_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/sm90 + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/ + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include + ) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" CUDA_ARCHS "${FLASH_MLA_ARCHS}") + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_Extension_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + define_gpu_extension_target( _flashmla_C DESTINATION vllm @@ -64,90 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} USE_SABI 3 WITH_SOABI) - - ####################################################################### - # FlashMLA Sparse -- _flashmla_sparse_C - ####################################################################### - - # We use seperate libraries to avoid crosss contaminating includes, - # namely kernels/utils.h - - set(DECODE_FOLDER ${flashmla_SOURCE_DIR}/csrc/sparse/decode) - set(PREFILL_FOLDER ${flashmla_SOURCE_DIR}/csrc/sparse/prefill) - - # ---- Decode object library ---- - set(SPARSE_FLASHMLA_DECODE_SOURCES - ${DECODE_FOLDER}/flash_api.cpp - ${DECODE_FOLDER}/kernels/get_mla_metadata.cu - ${DECODE_FOLDER}/kernels/mla_combine.cu - ${DECODE_FOLDER}/kernels/fp8_sparse/splitkv_mla.cu - ) - add_library(_flashmla_sparse_decode OBJECT ${SPARSE_FLASHMLA_DECODE_SOURCES}) - set_property(TARGET _flashmla_sparse_decode PROPERTY POSITION_INDEPENDENT_CODE ON) + # Keep Stable ABI for the module, but *not* for CUDA/C++ files. + # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles. 
+ target_compile_options(_flashmla_C PRIVATE + $<$:-UPy_LIMITED_API> + $<$:-UPy_LIMITED_API>) - set_gencode_flags_for_srcs( - SRCS "${SPARSE_FLASHMLA_DECODE_SOURCES}" - CUDA_ARCHS "${FLASH_MLA_ARCHS}" - ) - - # Include paths for decode ONLY (do not leak DECODE_FOLDER to others) - target_include_directories(_flashmla_sparse_decode - PRIVATE - ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${TORCH_INCLUDE_DIRS} - ${Python_INCLUDE_DIRS} - ${DECODE_FOLDER} - ) - target_compile_options(_flashmla_sparse_decode PRIVATE - $<$:${VLLM_GPU_FLAGS}>) - - # ---- Prefill object library ---- - set(SPARSE_FLASHMLA_PREFILL_SOURCES - ${PREFILL_FOLDER}/api.cpp - ${PREFILL_FOLDER}/kernels/sm90/fwd/fwd.cu - ) - - add_library(_flashmla_sparse_prefill OBJECT ${SPARSE_FLASHMLA_PREFILL_SOURCES}) - set_property(TARGET _flashmla_sparse_prefill PROPERTY POSITION_INDEPENDENT_CODE ON) - - set_gencode_flags_for_srcs( - SRCS "${SPARSE_FLASHMLA_PREFILL_SOURCES}" - CUDA_ARCHS "${FLASH_MLA_ARCHS}" - ) - - target_include_directories(_flashmla_sparse_prefill - PRIVATE - ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${TORCH_INCLUDE_DIRS} - ${Python_INCLUDE_DIRS} - ${PREFILL_FOLDER} - ) - target_compile_options(_flashmla_sparse_prefill PRIVATE - $<$:${VLLM_GPU_FLAGS}>) - - # ---- Final extension target with unified API ---- define_gpu_extension_target( - _flashmla_sparse_C + _flashmla_extension_C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} - SOURCES - ${flashmla_SOURCE_DIR}/csrc/sparse/api.cpp - $ - $ - COMPILE_FLAGS ${VLLM_GPU_FLAGS} + SOURCES ${FlashMLA_Extension_SOURCES} + COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - # Only the common/public includes here; do NOT add decode/prefill folders - INCLUDE_DIRECTORIES - ${flashmla_SOURCE_DIR}/csrc/ - ${CUTLASS_INCLUDE_DIR} - ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES} USE_SABI 3 - WITH_SOABI - ) + WITH_SOABI) + + # Keep Stable ABI for the module, but *not* for CUDA/C++ files. + # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles. 
+ target_compile_options(_flashmla_extension_C PRIVATE + $<$:-UPy_LIMITED_API> + $<$:-UPy_LIMITED_API>) else() - # Create an empty target for setup.py when not targeting sm90a systems + # Create empty targets for setup.py when not targeting sm90a systems add_custom_target(_flashmla_C) - add_custom_target(_flashmla_sparse_C) + add_custom_target(_flashmla_extension_C) endif() diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 80b4c47c5547..422f6907083f 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -16,6 +16,7 @@ #include #include +#include // FLT_MIN #include #include @@ -396,6 +397,109 @@ __global__ void concat_and_cache_mla_kernel( copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); } +template +__global__ void concat_and_cache_ds_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int64_t dst_idx_start = + block_idx * block_stride + block_offset * entry_stride; + + // Create 4 tile scales in shared memory + __shared__ float smem[20]; + float* shard_abs_max = smem; + float* tile_scales = smem + 16; + + // For the NoPE part, each tile of 128 elements is handled by 4 warps + // (128 threads). There are 4 total tiles, so 16 warps (512 threads). + // The first thread of the first warp in each tile writes the scale + // value for the tile. The RoPE part (last 64 elements) is handled + // by another 2 warps (64 threads). + // So in total, we use 18 warps (576 threads) per block. 
+ + // Cast kv_cache to 16_bit for RoPE values + scalar_t* kv_cache_16bit = + reinterpret_cast(&kv_cache[dst_idx_start]); + + // The last 64 threads handle the RoPE part + if (threadIdx.x >= kv_lora_rank) { + const int8_t pe_idx = threadIdx.x - kv_lora_rank; + const int64_t src_idx = token_idx * k_pe_stride + pe_idx; + // RoPE values start after the packed 8-bit NoPE values and the + // 32-bit scales + const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx; + kv_cache_16bit[dst_idx] = k_pe[src_idx]; + return; + } + + // Determine the scale for each chunk of NoPE + const int16_t tile_idx = threadIdx.x >> 7; + const int16_t warp_idx = (threadIdx.x & 127) >> 5; + const int16_t lane_idx = threadIdx.x & 31; + + // Load the NoPE element for this thread into registers + const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x; + const scalar_t src_val = kv_c[src_idx]; + + // Warp-level reduction to find the max absolute value in the warp + float max_abs = fabsf(src_val); +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset)); + } + + // The first lane of each warp in each tile writes the max_abs of this part + // of the tile to shared memory + if (lane_idx == 0) { + shard_abs_max[tile_idx * 4 + warp_idx] = max_abs; + } + __syncthreads(); + + // The first lane of the first warp in each tile computes the scale for the + // tile and writes it to shared memory and to kv_cache + if (warp_idx == 0 && lane_idx == 0) { + float4 shard_abs_max_vec = + reinterpret_cast(shard_abs_max)[tile_idx]; + float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y), + fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) / + 448.f; + + // Avoid division by zero in `scaled_convert` + tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN); + float* kv_cache_32bit = reinterpret_cast(&kv_cache[dst_idx_start]); + const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx; + kv_cache_32bit[dst_idx] = tile_scales[tile_idx]; + } + + __syncthreads(); + + // Now all threads in the block scale and write their element + const float scale_val = tile_scales[tile_idx]; + const int64_t dst_idx = dst_idx_start + threadIdx.x; + kv_cache[dst_idx] = + fp8::scaled_convert( + src_val, scale_val); +} + } // namespace vllm // KV_T is the data type of key and value tensors. @@ -438,7 +542,7 @@ void reshape_and_cache( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, - CALL_RESHAPE_AND_CACHE) + CALL_RESHAPE_AND_CACHE); } // KV_T is the data type of key and value tensors. @@ -509,6 +613,18 @@ void reshape_and_cache_flash( kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ reinterpret_cast(scale.data_ptr())); +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. 
+#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_ds_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + void concat_and_cache_mla( torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] torch::Tensor& k_pe, // [num_tokens, pe_dim] @@ -531,20 +647,44 @@ void concat_and_cache_mla( int pe_dim = k_pe.size(1); int block_size = kv_cache.size(1); - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + if (kv_cache_dtype == "fp8_ds_mla") { + TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"); + TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"); + TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(), + "kv_cache.size(2) must be 656 bytes for fp8_ds_mla"); + TORCH_CHECK(kv_c.itemsize() == 2, + "kv_c.itemsize() must be 2 for fp8_ds_mla"); + TORCH_CHECK(k_pe.itemsize() == 2, + "k_pe.itemsize() must be 2 for fp8_ds_mla"); + } else { + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + } int kv_c_stride = kv_c.stride(0); int k_pe_stride = k_pe.stride(0); int block_stride = kv_cache.stride(0); int entry_stride = kv_cache.stride(1); - dim3 grid(num_tokens); - dim3 block(std::min(kv_lora_rank, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, - CALL_CONCAT_AND_CACHE_MLA); + if (kv_cache_dtype == "fp8_ds_mla") { + dim3 grid(num_tokens); + // For the NoPE part, each tile of 128 elements is handled by 4 warps + // (128 threads). There are 4 total tiles, so 16 warps (512 threads). + // The first thread of the first warp in each tile writes the scale + // value for the tile. The RoPE part (last 64 elements) is handled + // by another 2 warps (64 threads). + // So in total, we use 18 warps (576 threads) per block. 
+ dim3 block(576); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_DS_MLA); + } else { + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA); + } } namespace vllm { diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index 5b9c2df8468c..5361a8b1a598 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -576,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { TORCH_CHECK(false, \ "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ + } else if (KV_DTYPE == "fp8_ds_mla") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ } else { \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ } \ diff --git a/setup.py b/setup.py index ca8fd08a57fb..6434bada6898 100644 --- a/setup.py +++ b/setup.py @@ -322,6 +322,7 @@ def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict: "vllm/_C.abi3.so", "vllm/_moe_C.abi3.so", "vllm/_flashmla_C.abi3.so", + "vllm/_flashmla_extension_C.abi3.so", "vllm/_sparse_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", @@ -591,7 +592,7 @@ def _read_requirements(filename: str) -> list[str]: ext_modules.append( CMakeExtension(name="vllm._flashmla_C", optional=True)) ext_modules.append( - CMakeExtension(name="vllm._flashmla_sparse_C", optional=True)) + CMakeExtension(name="vllm._flashmla_extension_C", optional=True)) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 022f183b3193..76e82bfa8087 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -191,7 +191,6 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_dtype, - use_mla=False, ), layer_names=[self.attn.layer_name], vllm_config=self.vllm_config, diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 69e96dfd2cb1..75bdcb6808b9 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -578,6 +578,119 @@ def test_concat_and_cache_mla( torch.testing.assert_close(kv_cache, ref_kv_cache) +@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) +@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA) +@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_concat_and_cache_ds_mla( + kv_lora_rank: int, + qk_rope_head_dim: int, + num_tokens: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> 
None: + if dtype.itemsize != 2: + pytest.skip("ds_mla only supports 16-bit input") + kv_cache_dtype = "fp8_ds_mla" + current_platform.seed_everything(seed) + torch.set_default_device(device) + + total_slots = num_blocks * block_size + slot_mapping_lst = random.sample(range(total_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, + dtype=torch.long, + device=device) + + kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) + k_pe = torch.randn(num_tokens, + qk_rope_head_dim, + dtype=dtype, + device=device) + entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim) + + scale = torch.tensor(1.0, dtype=torch.float32, device=device) + kv_cache = _create_mla_cache(num_blocks, + block_size, + entry_size, + dtype=torch.uint8, + kv_cache_dtype=kv_cache_dtype, + device=device) + + ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype) + tile_data = torch.zeros(128, dtype=dtype, device=device) + + for i in range(num_tokens): + slot = slot_mapping[i].item() + block_idx = slot // block_size + block_offset = slot % block_size + + ref_cache_slice = ref_cache[block_idx, block_offset] + ref_cache_16bit = ref_cache_slice.view(dtype) + ref_cache_32bit = ref_cache_slice.view(torch.float32) + + kv_c_data = kv_c[i] + for tile_idx in range(4): + tile_start = tile_idx * 128 + tile_end = (tile_idx + 1) * 128 + tile_data[:] = kv_c_data[tile_start:tile_end] + + # tile_scale = tile_data.amax().to(torch.float32) / 448. + # NOTE: Using torch's amax() gives different results, + # so this must be manually computed. + tile_data_float = tile_data.to(torch.float32) + manual_max = abs(tile_data_float[0]) + for j in range(1, 128): + manual_max = max(manual_max, abs(tile_data_float[j])) + tile_scale = manual_max / 448. + + ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale + + ops.convert_fp8(ref_cache_slice[tile_start:tile_end], + tile_data, + tile_scale.item(), + kv_dtype="fp8") + + for j in range(qk_rope_head_dim): + ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j] + + opcheck( + torch.ops._C_cache_ops.concat_and_cache_mla, + (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, + kv_cache_dtype, scale) + + for i in range(num_tokens): + slot = slot_mapping[i].item() + block_idx = slot // block_size + block_offset = slot % block_size + kv_cache_slice = kv_cache[block_idx, block_offset] + ref_cache_slice = ref_cache[block_idx, block_offset] + + kv_nope = kv_cache_slice[:kv_lora_rank] + ref_nope = ref_cache_slice[:kv_lora_rank] + kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank // + 4:kv_lora_rank // 4 + 4] + ref_scales = ref_cache_slice.view( + torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4] + kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] + ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] + + torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1) + torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1) + torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1) + + @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index abcfe828d5ac..bddd7e5c50ed 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py 
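For reference, the caller contract exercised by the new fp8_ds_mla cache test above can be summarized in a minimal sketch, assuming a CUDA build of vLLM that ships these cache kernels and the shapes asserted in concat_and_cache_mla (kv_lora_rank 512, RoPE dim 64, 656-byte uint8 entries); the tensor sizes and names below are illustrative only:

import torch
from vllm import _custom_ops as ops

num_tokens, num_blocks, block_size = 8, 4, 64
# 16-bit inputs: NoPE latent (512) and decoupled RoPE part (64)
kv_c = torch.randn(num_tokens, 512, dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn(num_tokens, 64, dtype=torch.bfloat16, device="cuda")
# 656 bytes per token: 512 fp8 NoPE values + 4 float32 tile scales + 64 bf16 RoPE values
kv_cache = torch.zeros(num_blocks, block_size, 656, dtype=torch.uint8, device="cuda")
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device="cuda")
scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                         kv_cache_dtype="fp8_ds_mla", scale=scale)

Each entry is then decoded exactly as the byte-level checks above do: the fp8 NoPE payload first, the four per-tile scales as float32 starting at byte 512, and the RoPE values in the model dtype from byte 528 onward.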
@@ -97,18 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, descale_k = None def flash_mla(): - return flash_mla_with_kvcache( - q, - blocked_k, - block_table, - cache_seqlens, - dv, - tile_scheduler_metadata, - num_splits, - causal=causal, - descale_q=descale_q, - descale_k=descale_k, - ) + return flash_mla_with_kvcache(q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + descale_q=descale_q, + descale_k=descale_k) def scaled_dot_product_attention(query, key, value, is_causal=False): query = query.float() diff --git a/tests/kernels/attention/test_flashmla_sparse.py b/tests/kernels/attention/test_flashmla_sparse.py index 6488e0c01e0c..62ff7f65a0a2 100644 --- a/tests/kernels/attention/test_flashmla_sparse.py +++ b/tests/kernels/attention/test_flashmla_sparse.py @@ -9,19 +9,6 @@ def _cuda_sm90_available() -> bool: return major == 9 -@pytest.mark.cuda -def test_sparse_flashmla_imports_and_flags(): - import vllm.attention.ops.flashmla as fm - # Functions should exist - assert hasattr(fm, "get_sparse_mla_metadata") - assert hasattr(fm, "flash_mla_sparse_with_kvcache") - assert hasattr(fm, "flash_mla_sparse_prefill") - # Support check should return a (bool, reason) - ok, reason = fm.is_flashmla_supported() - assert isinstance(ok, bool) - assert (reason is None) or isinstance(reason, str) - - def test_sparse_flashmla_metadata_smoke(): import vllm.attention.ops.flashmla as fm ok, reason = fm.is_flashmla_supported() @@ -34,16 +21,16 @@ def test_sparse_flashmla_metadata_smoke(): num_heads_q = 128 num_heads_k = 1 q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k - q_heads_per_hk = num_heads_q // num_heads_k topk = 128 cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) - tile_md, num_splits = fm.get_sparse_mla_metadata(cache_seqlens, - q_seq_per_hk, - num_heads_k, - topk, - q_heads_per_hk) + tile_md, num_splits = fm.get_mla_metadata(cache_seqlens, + q_seq_per_hk, + num_heads_k, + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True) assert tile_md.dtype == torch.int32 assert num_splits.dtype == torch.int32 @@ -69,11 +56,13 @@ def test_sparse_flashmla_decode_smoke(): q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k q_heads_per_hk = num_heads_q // num_heads_k cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) - tile_md, num_splits = fm.get_sparse_mla_metadata(cache_seqlens, + tile_md, num_splits = fm.get_mla_metadata(cache_seqlens, + q_seq_per_hk, num_heads_k, - topk, - q_heads_per_hk) + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True) # Inputs q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k), @@ -86,9 +75,10 @@ def test_sparse_flashmla_decode_smoke(): dtype=torch.int32, device=device) - out, lse = fm.flash_mla_sparse_with_kvcache(q, k_cache, cache_seqlens, + block_table = torch.zeros((batch_size, 128), dtype=torch.int32, device=device) + out, lse = fm.flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens, head_dim_v, tile_md, - num_splits, indices) + num_splits, indices=indices, is_fp8_kvcache=True) assert out.shape[0] == batch_size assert out.shape[-1] == head_dim_v assert lse.shape[0] == batch_size diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index a62993950aff..a5a7af7c0250 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -5,6 +5,8 @@ import pytest import torch +from vllm import _custom_ops as ops + from 
tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, @@ -78,7 +80,9 @@ def create_and_prepopulate_kv_cache( device: torch.device, num_blocks: int, common_attn_metadata: CommonAttentionMetadata, - randomize_blocks: bool = True) -> torch.Tensor: + randomize_blocks: bool = True, + kv_cache_dtype: str | None = None, + scale: float | torch.Tensor = 1.0) -> torch.Tensor: """Create and prepopulate an MLA KV cache with context data. Args: @@ -93,6 +97,11 @@ def create_and_prepopulate_kv_cache( common_attn_metadata: Common attention metadata randomize_blocks: Whether to randomly permute blocks or use sequential order + kv_cache_dtype: Optional kv cache dtype string. When set to + "fp8_ds_mla" the cache is populated using the + fp8 DeepSeek MLA layout via concat_and_cache_mla. + scale: Scaling factor forwarded to concat_and_cache_mla when the + fp8 cache layout is requested. Returns: MLA KV cache tensor @@ -105,23 +114,62 @@ def create_and_prepopulate_kv_cache( block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - # Create MLA KV cache: (num_blocks, block_size, head_size) - kv_cache = torch.empty(num_blocks, - block_size, - head_size, - dtype=dtype, - device=device) - kv_cache_flat = kv_cache.view(-1, head_size) + use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla" + + if use_fp8_ds_mla: + if not kv_c_contexts: + raise ValueError("kv_c_contexts cannot be empty when using" + " fp8_ds_mla cache dtype") + kv_lora_rank = kv_c_contexts[0].shape[-1] + rope_dim = k_pe_contexts[0].shape[-1] + entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim + kv_cache = torch.zeros(num_blocks, + block_size, + entry_size, + dtype=torch.uint8, + device=device) + scale_tensor = (scale if isinstance(scale, torch.Tensor) else + torch.tensor(scale, dtype=torch.float32, + device=device)) + scale_tensor = scale_tensor.to(device=device, dtype=torch.float32) + else: + # Create MLA KV cache: (num_blocks, block_size, head_size) + kv_cache = torch.empty(num_blocks, + block_size, + head_size, + dtype=dtype, + device=device) + kv_cache_flat = kv_cache.view(-1, head_size) # Populate the cache with the context tokens # Start from block_id=1 since block_id=0 is considered the null block start_block_idx = 1 for i in range(batch_size): kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i] - kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + context_len = kv_c_context.shape[0] + if context_len == 0: + start_block_idx += cdiv(int(seq_lens[i]), block_size) + continue + start = start_block_idx * block_size - end = start + kv_context.shape[0] - kv_cache_flat[start:end, ...] = kv_context + + if use_fp8_ds_mla: + slots = torch.arange(context_len, + device=device, + dtype=torch.long) + start + ops.concat_and_cache_mla( + kv_c_context, + k_pe_context.squeeze(1), + kv_cache, + slots, + kv_cache_dtype="fp8_ds_mla", + scale=scale_tensor, + ) + else: + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], + dim=-1) + end = start + kv_context.shape[0] + kv_cache_flat[start:end, ...] 
= kv_context # Stay block aligned and allocate enough blocks for the new tokens start_block_idx += cdiv(int(seq_lens[i]), block_size) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py new file mode 100644 index 000000000000..74eea6f716fe --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -0,0 +1,426 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for the FlashMLA sparse backend utilities.""" + +import math +from types import MethodType, SimpleNamespace + +import numpy as np +import pytest +import torch + +from tests.v1.attention.test_mla_backends import ( + BATCH_SPECS, BatchSpec, MockAttentionLayer, + create_and_prepopulate_kv_cache) +from tests.v1.attention.utils import (create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config) +from vllm import _custom_ops as ops +from vllm.attention.ops import flashmla +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.utils import cdiv +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata, + FlashMLASparseImpl, FlashMLASparseMetadata) + +SPARSE_BACKEND_BATCH_SPECS = { + name: BATCH_SPECS[name] + for name in [ + "mixed_small", + "mixed_medium", + "small_prefill", + "medium_prefill", + "single_prefill", + ] +} + +SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2, + query_lens=[256] * 2) +SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec( + seq_lens=[256] * 2, query_lens=[256] * 2) + + +def _dequantize_fp8_ds_mla_entry( + cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, + dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: + """Dequantize a single fp8_ds_mla cache entry back to latent + rope.""" + + # The first kv_lora_rank bytes store FP8 latent values with one scale per + # 128 element tile written as float32 right after the latent payload. 
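+    # In byte terms (kv_lora_rank == 512, rope_dim == 64): the four float32
+    # tile scales occupy bytes [512, 528), i.e. float32 index kv_lora_rank // 4,
+    # and the 16-bit RoPE values start at byte 528, i.e. 16-bit index
+    # kv_lora_rank // 2 + 8, matching concat_and_cache_ds_mla_kernel.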
+ scales = cache_slice.view(torch.float32)[kv_lora_rank // + 4:kv_lora_rank // 4 + 4] + latent = torch.empty(kv_lora_rank, + dtype=torch.float16, + device=cache_slice.device) + for tile_idx in range(4): + tile_start = tile_idx * 128 + tile_end = tile_start + 128 + ops.convert_fp8(latent[tile_start:tile_end], + cache_slice[tile_start:tile_end], + float(scales[tile_idx].item()), + kv_dtype="fp8") + latent = latent.to(dtype) + + rope_offset = kv_lora_rank // 2 + 8 + rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim] + return latent, rope_vals.clone() + + +def _quantize_dequantize_fp8_ds_mla( + kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, + scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Round-trip kv_c/k_pe though the fp8_ds_mla cache layout.""" + + if kv_c.numel() == 0: + return kv_c.clone(), k_pe.clone() + + kv_lora_rank = kv_c.shape[-1] + rope_dim = k_pe.shape[-1] + num_tokens = kv_c.shape[0] + num_blocks = max(1, math.ceil(num_tokens / block_size)) + entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim + + tmp_cache = torch.zeros(num_blocks, + block_size, + entry_size, + dtype=torch.uint8, + device=kv_c.device) + slot_mapping = torch.arange(num_tokens, + dtype=torch.long, + device=kv_c.device) + + ops.concat_and_cache_mla(kv_c, + k_pe, + tmp_cache, + slot_mapping, + kv_cache_dtype="fp8_ds_mla", + scale=scale) + + dequant_kv_c = torch.empty_like(kv_c) + dequant_k_pe = torch.empty_like(k_pe) + + for token_idx in range(num_tokens): + slot = slot_mapping[token_idx].item() + block_idx = slot // block_size + block_offset = slot % block_size + cache_slice = tmp_cache[block_idx, block_offset] + latent, rope_vals = _dequantize_fp8_ds_mla_entry( + cache_slice, kv_lora_rank, rope_dim, kv_c.dtype) + dequant_kv_c[token_idx] = latent + dequant_k_pe[token_idx] = rope_vals + + return dequant_kv_c, dequant_k_pe + + +def test_sparse_backend_metadata_registration(): + backend = FlashMLASparseBackend + + assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1" + assert backend.get_metadata_cls() is FlashMLASparseMetadata + assert backend.get_impl_cls() is FlashMLASparseImpl + + dtype_list = backend.get_supported_dtypes() + assert torch.bfloat16 in dtype_list + + shape = backend.get_kv_cache_shape(num_blocks=2, + block_size=64, + num_kv_heads=1, + head_size=576) + assert shape == (2, 64, 576) + + +def test_sparse_decode_metadata_filters_prefill_indices(): + prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32) + metadata = FlashMLASparseDecodeAndContextMetadata( + scheduler_metadata=torch.tensor([[0]], dtype=torch.int32), + num_splits=torch.tensor([1, 1], dtype=torch.int32), + cache_lens=torch.tensor([10, 12], dtype=torch.int32), + prefill_context_lengths=prefill_context_lengths, + ) + + indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32) + + context_indices, new_token_indices = metadata.filter_prefill_indices( + indices) + + expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], + dtype=torch.int32) + expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], + dtype=torch.int32) + + assert torch.equal(context_indices, expected_context) + assert torch.equal(new_token_indices, expected_new_tokens) + + +def test_sparse_impl_zero_fills_when_metadata_missing(): + impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl) + dummy_layer = object() + q = torch.zeros((2, 1, 3)) + k_c = torch.zeros((2, 3)) + k_pe = torch.zeros((2, 1, 1)) + kv_cache = torch.zeros((1, 1, 1)) + output = torch.ones((2, 4)) + + result = 
FlashMLASparseImpl.forward(impl, + dummy_layer, + q, + k_c, + k_pe, + kv_cache, + attn_metadata=None, + output=output) + + assert result is output + assert torch.all(result == 0) + + +@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) +@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) +def test_sparse_backend_decode_correctness(dist_init, batch_name, + kv_cache_dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for sparse MLA decode test") + + device = torch.device("cuda") + dtype = torch.bfloat16 + + batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] + + # Model hyper-parameters (kept intentionally small for the unit test) + num_heads = 128 + kv_lora_rank = 512 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + v_head_dim = 128 + head_size = kv_lora_rank + qk_rope_head_dim + topk_tokens = 2048 + + max_seqlen = max(batch_spec.seq_lens) + total_cache_tokens = sum(batch_spec.seq_lens) + block_size = 64 + + vllm_config = create_vllm_config( + model_name="deepseek-ai/DeepSeek-V2-Lite-Chat", + max_model_len=max_seqlen, + num_gpu_blocks=max(2048, + cdiv(total_cache_tokens, block_size) + 1), + block_size=block_size) + model_config = vllm_config.model_config + model_config.hf_config = SimpleNamespace( + attn_module_list_cfg=[{ + "topk_tokens": topk_tokens + }]) + model_config.hf_text_config = SimpleNamespace( + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + model_type="deepseek_v2", + ) + model_config.dtype = dtype + model_config.get_num_attention_heads = MethodType( + lambda self, parallel_config: num_heads, model_config) + model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1, + model_config) + model_config.get_head_size = MethodType(lambda self: head_size, + model_config) + model_config.get_sliding_window = MethodType(lambda self: None, + model_config) + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + torch.manual_seed(0) + + scale = 1.0 / math.sqrt(head_size) + + # Shared MLA projection weights to keep reference and backend in sync + W_UK = torch.randn(kv_lora_rank, + num_heads, + qk_nope_head_dim, + dtype=dtype, + device=device) + W_UV = torch.randn(kv_lora_rank, + num_heads, + v_head_dim, + dtype=dtype, + device=device) + + # Build synthetic decode-only workload + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] + kv_c_contexts, k_pe_contexts = [], [] + reference_outputs = [] + + kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + for i in range(batch_spec.batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + ctx_len = s_len - q_len + + q_c = torch.rand(q_len, + num_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device) + kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device) + k_pe_full = torch.rand(s_len, + 1, + qk_rope_head_dim, + dtype=dtype, + device=device) + + kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla( + kv_c_full, + k_pe_full.squeeze(1), + block_size=vllm_config.cache_config.block_size, + scale=kv_cache_scale, + ) + + q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK) + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + + k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1) + v_mqa = 
kv_c_full.unsqueeze(1).expand(-1, num_heads, -1) + + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, ctx_len:] = causal_mask + + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out = sdpa_out.transpose(1, 2).squeeze(0) + + sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV) + reference_outputs.append(sdpa_out.flatten(start_dim=-2)) + + all_q_vllm.append(q_c) + all_kv_c_vllm.append(kv_c_full[ctx_len:]) + all_k_pe_vllm.append(k_pe_full[ctx_len:]) + kv_c_contexts.append(kv_c_full[:ctx_len + 1]) + k_pe_contexts.append(k_pe_full[:ctx_len + 1]) + + query_vllm = torch.cat(all_q_vllm, dim=0) + kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) + k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) + sdpa_reference = torch.cat(reference_outputs, dim=0) + + vllm_config.cache_config.cache_dtype = kv_cache_dtype + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + vllm_config.cache_config.block_size, + device, + arange_block_indices=True) + + kv_cache = create_and_prepopulate_kv_cache( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=vllm_config.cache_config.block_size, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=False, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + scale=kv_cache_scale, + ) + + builder_cls = FlashMLASparseBackend.get_builder_cls() + builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device) + metadata = builder.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, + dtype=np.int32) + seg_lengths = np.diff(starts) + positions = np.arange(starts[-1], dtype=np.int32) - np.repeat( + starts[:-1], seg_lengths) + seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32) + prefix_lengths = seq_lengths - seg_lengths + positions += np.repeat(prefix_lengths, seg_lengths) + + pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32) + topk = metadata.topk_tokens + debug_indices = torch.arange(topk, device=device, + dtype=torch.int32).unsqueeze(0) + token_positions = pos_gpu.unsqueeze(1) + causal_mask = (debug_indices <= token_positions) + debug_indices = torch.where(causal_mask, debug_indices, + torch.full_like(debug_indices, -1)) + + # FlashMLASparseImpl now reads top-k indices from the indexer-provided + # buffer, so emulate that contract with a simple namespace mock. 
+ debug_indices = debug_indices.expand(metadata.num_actual_tokens, + -1).clone() + mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices) + + ok, reason = flashmla.is_flashmla_supported() + if not ok: + pytest.skip(reason) + + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)) + + mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, + output_size=num_heads * + (qk_nope_head_dim + v_head_dim), + bias=False).to(device=device, + dtype=dtype) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) + + impl_cls = FlashMLASparseBackend.get_impl_cls() + impl = impl_cls(num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + indexer=mock_indexer) + + impl.process_weights_after_loading(dtype) + + layer = MockAttentionLayer(device) + out_buffer = torch.empty(metadata.num_actual_tokens, + num_heads * v_head_dim, + dtype=dtype, + device=device) + + backend_output = impl.forward(layer, + query_vllm, + kv_c_vllm, + k_pe_vllm, + kv_cache, + metadata, + output=out_buffer) + + assert backend_output.shape == sdpa_reference.shape + assert backend_output.dtype == sdpa_reference.dtype + assert torch.isfinite(backend_output).all() + + torch.testing.assert_close(backend_output, + sdpa_reference, + rtol=0.5, + atol=0.5) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index f07c6eb0ea4d..41b71e33e0c4 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -168,7 +168,6 @@ def create_standard_kv_cache_spec( vllm_config.parallel_config), head_size=vllm_config.model_config.get_head_size(), dtype=vllm_config.model_config.dtype, - use_mla=vllm_config.model_config.use_mla, sliding_window=vllm_config.model_config.get_sliding_window(), ) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 4cb7ed6ce382..452b16ef4a91 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -24,7 +24,8 @@ make_block_hash_with_group_id) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec, + KVCacheTensor, MLAAttentionSpec, + SlidingWindowSpec, UniformTypeKVCacheSpecs) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -77,13 +78,11 @@ def new_kv_cache_spec(block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, - use_mla=False, sliding_window=None): return FullAttentionSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - use_mla=use_mla, sliding_window=sliding_window) @@ -91,13 +90,11 @@ def new_sliding_window_spec(block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, - use_mla=False, sliding_window=1): return SlidingWindowSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - use_mla=use_mla, sliding_window=sliding_window) @@ -894,7 +891,6 @@ def test_merge_kv_cache_spec(): num_kv_heads=full_spec.num_kv_heads, 
head_size=full_spec.head_size, dtype=full_spec.dtype, - use_mla=full_spec.use_mla, sliding_window=1, ), ] @@ -991,7 +987,6 @@ def test_estimate_max_model_len(model_id, max_model_len, num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, ) # Estimate the maximum model length, 16384 model_len need 8GB estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, @@ -1022,7 +1017,6 @@ def test_get_max_concurrency_for_kv_cache_config(): num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, ) sliding_window_spec = SlidingWindowSpec( @@ -1030,7 +1024,6 @@ def test_get_max_concurrency_for_kv_cache_config(): num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, sliding_window=1024, ) @@ -1412,3 +1405,48 @@ def test_generate_scheduler_kv_cache_config(): KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec()) ], ) + + +def new_mla_spec(cache_dtype_str=None): + return MLAAttentionSpec(block_size=16, + num_kv_heads=16, + head_size=64, + dtype=torch.float32, + cache_dtype_str=cache_dtype_str) + + +def test_merge_mla_spec(): + kv_cache_specs = [ + new_mla_spec(), + new_mla_spec(), + ] + mla_spec = kv_cache_specs[0].merge(kv_cache_specs) + assert mla_spec == new_mla_spec() + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + ] + mla_spec = kv_cache_specs[0].merge(kv_cache_specs) + assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla") + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + new_mla_spec(cache_dtype_str=None), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) + + kv_cache_specs = [ + new_kv_cache_spec(), + new_mla_spec(), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + new_kv_cache_spec(), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3cf9d9369676..3ddfaf71a1ca 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1337,7 +1337,6 @@ def test_eagle_with_sliding_window(): head_size=1, dtype=torch.float32, sliding_window=block_size, - use_mla=False, ) manager = KVCacheManager( KVCacheConfig( diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b70850a9bcff..e1a26cfd898f 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -35,7 +35,6 @@ def test_chunked_local_attention_possible_cached_prefix(): head_size=1, dtype=torch.float32, attention_chunk_size=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -101,7 +100,6 @@ def test_sliding_window_possible_cached_prefix(): head_size=1, dtype=torch.float32, sliding_window=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -167,7 +165,6 @@ def test_chunked_local_attention_remove_skipped_blocks(): head_size=1, dtype=torch.float32, attention_chunk_size=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -219,7 +216,6 @@ def test_sliding_window_remove_skipped_blocks(): head_size=1, dtype=torch.float32, sliding_window=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -287,7 +283,6 @@ def 
test_get_num_blocks_to_allocate(): head_size=1, dtype=torch.float32, sliding_window=4, # Placeholder value, not related to test result - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -310,7 +305,6 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): head_size=1, dtype=torch.float32, attention_chunk_size=4, # Placeholder value, not related to test result - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 992c4e01386e..10adac9bab5f 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -836,8 +836,7 @@ def create_mock_executor(vllm_config): mock_spec = FullAttentionSpec(block_size=16, num_kv_heads=1, head_size=64, - dtype=torch.float16, - use_mla=False) + dtype=torch.float16) mock_executor.get_kv_cache_specs.return_value = [{ "default": mock_spec diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 8b571f95c5ec..49a7a61e1889 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -39,7 +39,6 @@ def initialize_kv_cache(runner: GPUModelRunner): runner.parallel_config), head_size=runner.model_config.get_head_size(), dtype=runner.kv_cache_dtype, - use_mla=False, ) tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS kv_cache_config = KVCacheConfig( diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index dfde67e1713c..754545e6f2d6 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -76,6 +76,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: raise NotImplementedError diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index a7d0e3afb517..7dce44489a21 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -53,6 +53,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 78c768f92d3c..25f05dac28c2 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -71,6 +71,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 789393eb39a7..3feaee438523 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -263,6 +263,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return (num_blocks, block_size, head_size) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index e630a6c6de8c..aaa12da3c67b 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -51,6 +51,7 @@ def 
get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return (1, 1, 1, 1, 1) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 9262144e37b5..5dc7790bacf9 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -82,6 +82,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: paged_attn = _get_paged_attn_module() return paged_attn.get_kv_cache_shape(num_blocks, block_size, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 302d3d7ea903..495225127fe2 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -53,6 +53,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return PagedAttention.get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 7c7e010e2af2..9c9eee24ebeb 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -13,13 +13,21 @@ if current_platform.is_cuda(): try: import vllm._flashmla_C # noqa: F401 - import vllm._flashmla_sparse_C # noqa: F401 _flashmla_C_AVAILABLE = True except ImportError: _flashmla_C_AVAILABLE = False else: _flashmla_C_AVAILABLE = False +if current_platform.is_cuda(): + try: + import vllm._flashmla_extension_C # noqa: F401 + _flashmla_extension_C_AVAILABLE = True + except ImportError: + _flashmla_extension_C_AVAILABLE = False +else: + _flashmla_extension_C_AVAILABLE = False + def is_flashmla_supported() -> Tuple[bool, Optional[str]]: """ @@ -38,24 +46,28 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]: def get_mla_metadata( - cache_seqlens: torch.Tensor, - num_heads_per_head_k: int, - num_heads_k: int, -) -> Tuple[torch.Tensor, torch.Tensor]: + cache_seqlens: torch.Tensor, + num_q_tokens_per_head_k: int, + num_heads_k: int, + num_heads_q: Optional[int] = None, + is_fp8_kvcache: bool = False, + topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: cache_seqlens: (batch_size), dtype torch.int32. - num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. - num_heads_k: num_heads_k. + num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. + num_heads_k: The number of k heads. + num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled + is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. + topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. - Return: - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - dtype torch.int32. + Returns: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. 
""" - return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, - num_heads_per_head_k, - num_heads_k) + return torch.ops._flashmla_C.get_mla_decoding_metadata( + cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, + is_fp8_kvcache, topk) def flash_mla_with_kvcache( @@ -70,6 +82,8 @@ def flash_mla_with_kvcache( causal: bool = False, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -77,121 +91,37 @@ def flash_mla_with_kvcache( k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). block_table: (batch_size, max_num_blocks_per_seq), torch.int32. cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head_dim of v. - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - torch.int32, return by get_mla_metadata. - num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(head_dim). + head_dim_v: Head dimension of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. + softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. - descale_q: (batch_size), torch.float32. Descaling factors for Q. - descale_k: (batch_size), torch.float32. Descaling factors for K. + descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. + descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. + is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md + indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. - Return: - out: (batch_size, seq_len_q, num_heads_q, head_dim_v). - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. - """ - if softmax_scale is None: - softmax_scale = q.shape[-1]**(-0.5) - out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( - q, - k_cache, - head_dim_v, - cache_seqlens, - block_table, - softmax_scale, - causal, - tile_scheduler_metadata, - num_splits, - descale_q, - descale_k, - ) - - # Note(hc): need revisit when we support DCP with decode query_len > 1. - return out.squeeze(1), softmax_lse.squeeze(-1) - - -# ------------------------ Sparse FlashMLA bindings ------------------------- - - -def get_sparse_mla_metadata( - cache_seqlens: torch.Tensor, - q_seq_per_hk: int, - num_heads_k: int, - topk: int, - q_heads_per_hk: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Arguments: - cache_seqlens: (batch_size), dtype torch.int32. - q_seq_per_hk: Equals to seq_len_q * num_heads_q // num_heads_k. - num_heads_k: num_heads_k. - topk: topk - q_heads_per_hk: equals to num_heads_q // num_heads_k. Only need to be - specified when topk is not None. - - Return: - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32. 
- """ - return torch.ops._flashmla_sparse_C.get_mla_metadata( - cache_seqlens, q_seq_per_hk, num_heads_k, topk, q_heads_per_hk) - - -def flash_mla_sparse_with_kvcache( - q: torch.Tensor, - k_cache: torch.Tensor, - cache_seqlens: torch.Tensor, - head_dim_v: int, - tile_scheduler_metadata: torch.Tensor, - num_splits: torch.Tensor, - indices_in_kvcache: torch.Tensor, - softmax_scale: Optional[float] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Arguments: - q: (batch_size, seq_len_q, num_heads_q, head_dim). - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). - cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head_dim of v. - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - torch.int32, returned by get_sparse_mla_metadata. - num_splits: (batch_size + 1), torch.int32, returned by - get_sparse_mla_metadata. - indices_in_kvcache: (batch_size, seq_len_q, topk). KV indices when - sparse attention is enabled. Note that - indices_in_kvcache[i][j][k] = - (the index of the page block where token t resides) * - page_block_size + (the offset of token t within that page block), - where t is the k-th token of the j-th q-sequence in the i-th batch. - softmax_scale: float. Scaling of QK^T before softmax. - Defaults to 1 / sqrt(head_dim). - - Explanation of K/V cache layout: - We quantize the NoPE part of each token (in 1x128 granularity), - yielding 512 float8_e4m3 values and 4 float32 scale factors. For the - RoPE part, we keep it as 64 bfloat16. Each token occupies 656 bytes: - - First 512 bytes: quantized NoPE (512 x float8_e4m3) - - Next 16 bytes: scale factors (4 x float32) - - Last 128 bytes: RoPE (64 x bfloat16) - - Return: + Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) - # Strict shape checks like the reference implementation - assert k_cache.shape[-1] == 656, ( - "The last dim of k_cache must be 656 (=512+2*16+4*4) when " - "is_fp8_kvcache is True") - assert k_cache.shape[-2] == 1, ( - "The number of K heads must be 1 when is_fp8_kvcache is True") - - out, softmax_lse = torch.ops._flashmla_sparse_C.fwd_kvcache_mla( - q, k_cache, head_dim_v, cache_seqlens, softmax_scale, - tile_scheduler_metadata, num_splits, indices_in_kvcache) + if indices is not None: + assert causal == False, "causal must be `false` if sparse attention is enabled." + assert (descale_q is None) == ( + descale_k is None + ), "descale_q and descale_k should be both None or both not None" + + if (descale_q is not None) and (descale_k is not None): + out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8( + q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, + causal, tile_scheduler_metadata, num_splits, descale_q, descale_k) + else: + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( + q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, + causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache, + indices) return out, softmax_lse @@ -203,24 +133,24 @@ def flash_mla_sparse_prefill( d_v: int = 512, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Sparse attention forward operator, for prefill + Sparse attention prefill kernel Args: q: [s_q, h_q, d_qk], bfloat16 kv: [s_kv, h_kv, d_qk], bfloat16 - indices: [s_q, h_kv, topk], int32. 
Invalid indices should be set to -1, - or to a number >= s_kv - sm_scale: float, scaling factor for the attention scores - d_v: dimension of the value, default (and only supported) is 512 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 Returns: - Returns (output, max_logits, lse) - - output: [s_q, h_q, d_v], bfloat16, the result of attention - - max_logits: [s_q, h_q], float - - lse: [s_q, h_q], float, base-2 + (output, max_logits, lse) + About the definition of output, max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp """ - results = torch.ops._flashmla_sparse_C.sparse_topk_attn_fwd( - q, kv, indices, sm_scale, d_v) + results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, + sm_scale, d_v) return results diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 4d870a45e580..539b57e41de7 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -50,6 +50,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 892bf5e09c8f..d23232ab09b5 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -81,11 +81,14 @@ def __init__( self.rotary_emb = mla_modules.rotary_emb self.o_proj = mla_modules.o_proj self.indexer = mla_modules.indexer - self.topk_tokens = mla_modules.indexer.topk_tokens \ - if self.indexer else None self.use_sparse = mla_modules.is_sparse and os.getenv( "VLLM_MLA_SPARSE_DISABLED") != "1" - self.topk_indices_buffer = mla_modules.topk_indices_buffer + + if self.indexer is not None: + assert hasattr(self.indexer, "topk_tokens") + self.topk_tokens = self.indexer.topk_tokens \ + if self.indexer else None + self.topk_indices_buffer = mla_modules.topk_indices_buffer # In the MLA backend, kv_cache includes both k_c and # pe (i.e. decoupled position embeddings). 
In particular, @@ -160,7 +163,7 @@ def forward_native( if self.indexer and self.use_sparse: _topk_indices = self.indexer(hidden_states, q_c, positions, - self.rotary_emb) + self.rotary_emb) attn_out = self.mla_attn( q, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index ce3d23763ed6..86348f8615b9 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -414,6 +414,24 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "exactly equal.", mamba_padding_pct) +class DeepseekV3ForCausalLM(VerifyAndUpdateConfig): + + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32 + """ + hf_config = vllm_config.model_config.hf_config + + is_v32 = hasattr(hf_config, "attn_module_list_cfg") \ + and "attn_index" in hf_config.attn_module_list_cfg[0] + + if is_v32: + cache_config = vllm_config.cache_config + if cache_config.cache_dtype.startswith("fp8"): + cache_config.cache_dtype = "fp8_ds_mla" + + MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, @@ -431,4 +449,5 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "MambaForCausalLM": MambaModelConfig, "Mamba2ForCausalLM": MambaModelConfig, "FalconMambaForCausalLM": MambaModelConfig, + "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, } diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 25d6340f198b..0ce54a594d44 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -74,7 +74,7 @@ is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec +from vllm.v1.kv_cache_interface import MLAAttentionSpec, KVCacheSpec from vllm.utils.deep_gemm import ( fp8_mqa_logits, get_paged_mqa_logits_metadata, @@ -514,12 +514,11 @@ def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str, compilation_config.static_forward_context[prefix] = self def get_kv_cache_spec(self) -> KVCacheSpec: - return FullAttentionSpec( + return MLAAttentionSpec( # Only has one vector instead of K + V block_size=self.cache_config.block_size, num_kv_heads=1, head_size=self.head_dim, dtype=self.dtype, - use_mla=True # Only has one vector instead of K + V ) def forward(self): diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 968bba664f0a..a38d0d58e4ec 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -189,6 +189,7 @@ "fp8_e5m2": torch.uint8, "int8": torch.int8, "fp8_inc": torch.float8_e4m3fn, + "fp8_ds_mla": torch.uint8, } TORCH_DTYPE_TO_NUMPY_DTYPE = { diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 6627164c9879..466e6320c591 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -79,6 +79,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return _get_paged_attn_impl().get_kv_cache_shape( num_blocks, block_size, num_kv_heads, head_size) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 20f1904b3be6..5d407fca1ad9 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -82,6 
+82,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cb092aa74e7f..d5a00a3afebe 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -176,6 +176,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 662d3984554a..c0e5acdd245a 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -88,6 +88,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (2, num_blocks, block_size, num_kv_heads, head_size) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index c3de294d947e..d8749aaab930 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -237,11 +237,6 @@ except ImportError: flashinfer_available = False -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from vllm.v1.attention.backends.mla.flashmla_sparse import MLASparsePrefillMetadata - def is_rocm_aiter_fp8bmm_enabled() -> bool: return current_platform.is_rocm() \ @@ -291,6 +286,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) @@ -403,8 +399,8 @@ class MLACommonMetadata(Generic[D]): decode: Optional[D] = None prefill: Optional[Union[MLACommonPrefillMetadata, - FlashInferPrefillMetadata, CudnnPrefillMetadata, - "MLASparsePrefillMetadata"]] = None + FlashInferPrefillMetadata, + CudnnPrefillMetadata]] = None def __post_init__(self): if self.head_dim is not None: @@ -412,6 +408,7 @@ def __post_init__(self): M = TypeVar("M", bound=MLACommonMetadata) +A = TypeVar("A") def use_flashinfer_prefill() -> bool: @@ -921,7 +918,9 @@ def reorg_kvcache( return reorganized_kv_c_normed, reorganized_k_pe -class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): +# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl, +# and MLACommonImpl -> MLACommonDenseImpl or somthing like that +class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -947,7 +946,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, - indexer = None, + indexer=None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -967,6 +966,126 @@ def __init__( self.kv_b_proj = kv_b_proj self.indexer = indexer + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, 
UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + if is_rocm_aiter_fp8bmm_enabled(): + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V = W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, dtype=current_platform.fp8_dtype()) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype()) + + # The kernel operates on non-padded inputs. Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. + max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device) + aiter_triton_fp8_bmm(x, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + + x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device) + aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _v_up_proj(self, x): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + x = aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + # Convert from (B, N, V) to (B, N * V) + x = x.reshape(-1, self.num_heads * self.v_head_dim) + else: + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return x + + +class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + if 
use_flashinfer_prefill(): logger.debug_once("Using FlashInfer prefill for MLA") self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi @@ -1141,116 +1260,6 @@ def _run_prefill_context_chunk_cudnn(self, True, #Indicates actual_seq_lens are on GPU or CPU. ) - def _v_up_proj(self, x): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - if is_rocm_aiter_fp8bmm_enabled(): - # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) - # Convert from (B, N, V) to (B, N * V) - x = x.reshape(-1, self.num_heads * self.v_head_dim) - else: - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return x - - def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - if hasattr(layer, attr): - return getattr(layer, attr) - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - if is_rocm_aiter_fp8bmm_enabled(): - W_K = W_UK.transpose(0, 1) # 16 512 128 - W_V = W_UV.permute(1, 2, 0) # 16 128 512 - self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype()) - self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype()) - - # The kernel operates on non-padded inputs. Hence, pre-compiling - # triton kernel to avoid runtime compilation for unseen batch sizes - # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. - # On DS-R1, this step adds roughly 50s to the model loading time. 
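Since `_v_up_proj` is now shared via `MLACommonBaseImpl`, the shape bookkeeping it performs is worth spelling out once. A small self-contained sketch of the non-quantized branch, with illustrative dimensions (not taken from a real checkpoint):

import torch

B, N, L, V = 4, 16, 512, 128   # tokens, local heads, kv_lora_rank, v_head_dim (illustrative)
attn_out = torch.randn(B, N, L)          # attention output still in the latent space
W_UV = torch.randn(N, L, V)              # per-head value up-projection, as split above

x = attn_out.view(-1, N, L).transpose(0, 1)    # (B, N, L) -> (N, B, L)
x = torch.bmm(x, W_UV)                         # (N, B, L) x (N, L, V) -> (N, B, V)
x = x.transpose(0, 1).reshape(-1, N * V)       # (B, N * V), ready for o_proj
assert x.shape == (B, N * V)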
- max_batch_size = 1024 # [ToDo] Find the optimal upper limit - pre_compilation_list = list(range(1, max_batch_size + 1)) - if is_global_first_rank(): - pre_compilation_list = tqdm( - pre_compilation_list, - desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", - total=max_batch_size, - ) - - for m in pre_compilation_list: - x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device) - aiter_triton_fp8_bmm(x, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) - - x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device) - aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) - else: - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - def _compute_prefill_context( self, q: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 1a0b91ebc544..238817350ff1 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -1,35 +1,55 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from dataclasses import dataclass -from typing import Optional, Union, ClassVar +from typing import TYPE_CHECKING, ClassVar, Optional import numpy as np import torch from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import AttentionLayer, AttentionMetadata -from vllm.attention.ops.flashmla import flash_mla_sparse_prefill +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata) +from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.flashmla import (flash_mla_sparse_prefill, + flash_mla_with_kvcache, + get_mla_metadata) from vllm.config import VllmConfig -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec -import triton -import triton.language as tl -from typing import TYPE_CHECKING + if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer logger = init_logger(__name__) +""" +NOTE: FlashMLA Sparse uses an fp8 cache with the following format + +In the "FP8 with scale" format, each token's KV cache is 656 Bytes, +structured as: +- **First 512 bytes:** The "quantized NoPE" part, containing 512 + `float8_e4m3` values. +- **Next 16 bytes:** Scale factors, containing 4 `float32` values. + The first `float32` is the scale for the first 128 `float8_e4m3` values, + the second for the next 128, and so on. +- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This + part is not quantized for accuracy. 
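A quick sanity check of the byte accounting in this layout; a sketch assuming exactly the fields described above, with the default block size of 64 used further below:

NOPE_DIM, ROPE_DIM, GROUP_SIZE = 512, 64, 128

nope_bytes = NOPE_DIM * 1                    # 512 float8_e4m3 values, 1 byte each
scale_bytes = (NOPE_DIM // GROUP_SIZE) * 4   # 4 float32 scales, one per 128-value group
rope_bytes = ROPE_DIM * 2                    # 64 bfloat16 values, 2 bytes each

bytes_per_token = nope_bytes + scale_bytes + rope_bytes
assert bytes_per_token == 656    # matches get_kv_cache_shape and MLAAttentionSpec.page_size_bytes

block_size = 64                  # FlashMLASparseMetadata default
print(block_size * bytes_per_token)   # 41984 bytes per KV-cache block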
+""" + + +def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor: + # Convert base-2 LSE to natural-log LSE + # Keep FP32 for numerical stability during the merge. + return (lse_base2.to(torch.float32) * math.log(2.0)) + +class FlashMLASparseBackend(AttentionBackend): -class FlashMLASparseBackend(MLACommonBackend): + accept_output_buffer: bool = True @staticmethod def get_name() -> str: @@ -53,33 +73,51 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: - return (num_blocks, block_size, head_size) + if cache_dtype_str == "fp8_ds_mla": + # custom storage fromat is 656 bytes + # see FlashMLA readme.md for details + return (num_blocks, block_size, 656) + else: + return (num_blocks, block_size, head_size) @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: - print("try running get_supported_dtypes") - # TODO: verify this - return [torch.float16, torch.bfloat16] + return [torch.bfloat16] @classmethod def get_supported_head_sizes(cls) -> list[int]: - # TODO: verify this return [576] +@dataclass class MLASparsePrefillMetadata: # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because # the kernel is not from flashmla - def __init__(self, block_table: torch.Tensor, - req_id_per_token: torch.Tensor): - pass - + block_table: torch.Tensor + has_context: bool = False + context_lens: Optional[torch.Tensor] = None -class FlashMLASparseDecodeMetadata(MLACommonDecodeMetadata): - def __init__(self): - pass +@dataclass +class FlashMLASparseDecodeAndContextMetadata: + scheduler_metadata: torch.Tensor = None + num_splits: torch.Tensor = None + cache_lens: torch.Tensor = None + prefill_context_lengths: Optional[torch.Tensor] = None + prefill_new_k_start_locs: Optional[torch.Tensor] = None + dummy_block_table: torch.Tensor = None + + def filter_prefill_indices( + self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.prefill_context_lengths is not None + prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1) + context_indices = torch.where(indices < prefill_context_lengths, + indices, -1) + new_token_indices = torch.where(indices >= prefill_context_lengths, + indices - prefill_context_lengths, -1) + return context_indices, new_token_indices @dataclass @@ -97,9 +135,14 @@ class FlashMLASparseMetadata: block_size: int = 64 topk_tokens: int = 2048 - # For now just create topk_indices that just attend to the first topk tokens - # always to enable development - debug_topk_indices: Optional[torch.Tensor] = None + @dataclass + class FP8KernelMetadata: + scheduler_metadata: Optional[torch.Tensor] + num_splits: torch.Tensor + dummy_block_table: torch.Tensor + cache_lens: torch.Tensor + + fp8_extra_metadata: Optional[FP8KernelMetadata] = None @triton.jit @@ -167,18 +210,21 @@ def triton_convert_req_index_to_global_index( ): """ out[token_id, indice_id] = - block_table[req_id[token_id], token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + token_indices[token_id, indice_id] % BLOCK_SIZE Only when token_indices[token_id, indice_id] == -1 do we output -1. - For safety, we also output -1 if the derived block_id would be out-of-bounds. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. 
""" assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS assert NUM_TOPK_TOKENS % BLOCK_N == 0, \ - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \ + f"BLOCK_N ({BLOCK_N})" num_tokens = req_id.shape[0] num_requests, max_num_blocks_per_req = block_table.shape @@ -220,58 +266,78 @@ def triton_convert_req_index_to_global_index( @dataclass class FlashMLASparseMetadataBuilder( - MLACommonMetadataBuilder[FlashMLASparseMetadata]): + AttentionMetadataBuilder[FlashMLASparseMetadata]): + + reorder_batch_threshold: ClassVar[int] = 128 # TODO(lucas): tune this reorder_batch_threshold: ClassVar[int] = 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - FlashMLASparseMetadata) + + cache_config = vllm_config.cache_config + self.kv_cache_spec = kv_cache_spec + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.device = device + + self.num_heads = self.model_config.get_num_attention_heads( + parallel_config) + self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config\ .attn_module_list_cfg[0]["topk_tokens"] + self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" + self.topk_tokens_tensor = torch.tensor([self.topk_tokens], + device=device, + dtype=torch.int32) + self.max_model_len_tensor = torch.tensor( + [self.model_config.max_model_len], + device=device, + dtype=torch.int32) + # this is ignored by `flash_mla_with_kvcache` if indices not None + self.dummy_block_table = torch.empty((1, 1), + dtype=torch.int32, + device=self.device) self.num_speculative_tokens = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config else 0 - ) + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0) self.reorder_batch_threshold += self.num_speculative_tokens - def _build_prefill( - self, common_attn_metadata: CommonAttentionMetadata - ) -> MLASparsePrefillMetadata: - return MLASparsePrefillMetadata() - - def _build_decode( - self, common_attn_metadata: CommonAttentionMetadata - ) -> FlashMLASparseDecodeMetadata: - return FlashMLASparseDecodeMetadata() - def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> FlashMLASparseMetadata: - num_actual_tokens = common_attn_metadata.num_actual_tokens + num_tokens = common_attn_metadata.num_actual_tokens starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) seg_lengths = np.diff(starts) req_id_per_token = np.repeat( np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) + req_id_per_token = torch.from_numpy(req_id_per_token)\ + .to(device='cuda', non_blocking=True) + + fp8_extra_metadata = None + if self.use_fp8_kv_cache: + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens=self.topk_tokens_tensor, + num_q_tokens_per_head_k=num_tokens * self.num_heads, + topk=self.topk_tokens, + num_heads_q=self.num_heads, + num_heads_k=1, + is_fp8_kvcache=True, + ) + fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + # cache_lens and block_table are basically unused in sparse case 
+ # but the decode kernel will treat -1 and indices >= cache_lens + # as invalid so we make sure cache_lens is large enough to not + # accidentally mark indices invalid, we will use -1 exclusively + # to mark invalid indices + cache_lens=self.max_model_len_tensor, + dummy_block_table=self.dummy_block_table) - # pos = np.arange(starts[-1]) - np.repeat(starts[:-1], np.diff(starts)) - # seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, - # dtype=np.int32) - # prefix_length = seq_lengths - seg_lengths - # prefix_length_per_token = np.repeat(prefix_length, seg_lengths) - # pos = pos + prefix_length_per_token - # pos_gpu = torch.as_tensor(pos, device=self.device, dtype=torch.long) - # row = torch.arange(self.topk_tokens, - # device=self.device, - # dtype=torch.int32) - # debug_topk_indices = row.repeat(num_actual_tokens, 1) - # mask = debug_topk_indices <= pos_gpu.unsqueeze(1) - # debug_topk_indices = debug_topk_indices.masked_fill(~mask, -1) - debug_topk_indices = None metadata = FlashMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, @@ -280,20 +346,15 @@ def build(self, query_start_loc=common_attn_metadata.query_start_loc, slot_mapping=common_attn_metadata.slot_mapping, block_table=common_attn_metadata.block_table_tensor, - req_id_per_token=torch.from_numpy(req_id_per_token).to( - device='cuda'), - # num_decodes=num_decodes, - # num_decode_tokens=num_decode_tokens, - # num_prefills=num_prefills, + req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, - debug_topk_indices=debug_topk_indices, + fp8_extra_metadata=fp8_extra_metadata, ) return metadata -@dataclass -class FlashMLASparseImpl(MLACommonImpl[FlashMLASparseMetadata]): +class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): def __init__( self, @@ -311,19 +372,60 @@ def __init__( topk_indice_buffer: Optional[torch.Tensor] = None, indexer: Optional["Indexer"] = None, **mla_args) -> None: - super().__init__(num_heads, - head_size, - scale, - num_kv_heads, - alibi_slopes, - sliding_window, - kv_cache_dtype, - logits_soft_cap, - attn_type, - kv_sharing_target_layer_name, - indexer=indexer, - **mla_args) - self.topk_indice_buffer = indexer.topk_indices_buffer + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + self.softmax_scale = scale + assert indexer is not None + self.topk_indices_buffer = indexer.topk_indices_buffer + + def _forward_bf16_kv( + self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + num_tokens = q.shape[0] + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( + -1, 1, kv_c_and_k_pe_cache.shape[-1]) + + # NOTE(Chen): kernel requires num_local_head to be a multiple of 64. 
+ if self.num_heads % 64 != 0: + assert 64 % self.num_heads == 0 + logger.warning_once( + "padding num_heads to 64 due to sparse attn kernel requirement" + ) + q_padded = q.new_empty((q.shape[0], 64, q.shape[2])) + q_padded[:, :self.num_heads, :] = q + q = q_padded + + topk_indices = topk_indices.view(num_tokens, 1, -1) + output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, topk_indices, + self.softmax_scale)[0] + output = output[:, :self.num_heads, :] + return output + + def _forward_fp8_kv(self, q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + + assert attn_metadata.fp8_extra_metadata is not None + extra_metadata = attn_metadata.fp8_extra_metadata + + _attn_out, _ = flash_mla_with_kvcache( + q=q.unsqueeze(0), # unsqueeze to add batch_dim + k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), + block_table=extra_metadata.dummy_block_table, + head_dim_v=512, + cache_seqlens=extra_metadata.cache_lens, + tile_scheduler_metadata=extra_metadata.scheduler_metadata, + num_splits=extra_metadata.num_splits, + is_fp8_kvcache=True, + indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim + softmax_scale=self.softmax_scale, + ) + + return _attn_out def forward( self, @@ -354,14 +456,13 @@ def forward( # same expert outputs. return output.fill_(0) - num_actual_tokens = attn_metadata.num_actual_tokens + num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs - output_padded = output - output = output[:num_actual_tokens, ...] - q = q[:num_actual_tokens, ...] - k_c_normed = k_c_normed[:num_actual_tokens, ...] - k_pe = k_pe[:num_actual_tokens, ...] + + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) @@ -372,6 +473,19 @@ def forward( # Convert from (N, B, L) to (B, N, L) ql_nope = ql_nope.transpose(0, 1) + topk_indices = self.topk_indices_buffer[:num_actual_toks] + + # TODO: handle index / kv_cache correctly + topk_indices_global = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=attn_metadata.topk_tokens, + ) + + q = torch.cat([ql_nope, q_pe], dim=-1) + # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -383,68 +497,12 @@ def forward( scale=layer._k_scale, ) - attn_out = self._forward_bf16_kv(ql_nope, q_pe, kv_cache, - attn_metadata, self.scale) - - output[:num_actual_tokens] = self._v_up_proj( - attn_out[:num_actual_tokens]) - return output_padded + if self.kv_cache_dtype != "fp8_ds_mla": + attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global, + attn_metadata) + else: + attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global, + attn_metadata) - def _forward_bf16_kv(self, ql_nope: torch.Tensor, q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, - k_scale: torch.Tensor) -> torch.Tensor: - topk_indices = self.topk_indice_buffer - num_tokens = attn_metadata.num_actual_tokens - q = torch.cat([ql_nope, q_pe], dim=-1) - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1]) - # NOTE(Chen): kernel requires num_local_head to be a multiple of 64. 
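For completeness, the query-side counterpart of the value up-projection sketched earlier: `W_UK_T` is stored as (N, P, L) in process_weights_after_loading, so the absorbed latent query is presumably formed with a batched matmul over the head dimension. A sketch with illustrative shapes (the exact call sits outside the visible hunk):

import torch

B, N, P, L = 4, 16, 128, 512   # tokens, local heads, qk_nope_head_dim, kv_lora_rank (illustrative)
q_nope = torch.randn(B, N, P)
W_UK = torch.randn(L, N, P)               # slice of kv_b_proj, as split above
W_UK_T = W_UK.permute(1, 2, 0)            # (L, N, P) -> (N, P, L)

ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T)   # (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = ql_nope.transpose(0, 1)                     # (B, N, L), matching forward() above
assert ql_nope.shape == (B, N, L)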
- if self.num_heads % 64 != 0: - assert 64 % self.num_heads == 0 - logger.warning_once( - f"padding num_heads to 64 due to sparse attn kernel requirement" - ) - q_padded = q.new_empty((q.shape[0], 64, q.shape[2])) - q_padded[:, :self.num_heads, :] = q - q = q_padded - # TODO: handle index / kv_cache correctly - topk_indices_global = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token[:num_tokens], - attn_metadata.block_table, - topk_indices[:num_tokens], - BLOCK_SIZE=attn_metadata.block_size, - NUM_TOPK_TOKENS=attn_metadata.topk_tokens, - ) - topk_indices_global = topk_indices_global.view(num_tokens, 1, -1) - output = flash_mla_sparse_prefill(q[:num_tokens], kv_c_and_k_pe_cache, - topk_indices_global, k_scale)[0] - output = output[:, :self.num_heads, :] + output[:num_actual_toks] = self._v_up_proj(attn_out) return output - - def _forward_decode( - self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, - layer: AttentionLayer, - topk_indices: Optional[torch.Tensor] = None, # sparse attn - ) -> torch.Tensor: - - topk_indices = self.topk_indices[:attn_metadata.num_decodes] - - # # assume indice of shape [num_decode_tokens, topk] - # block_id_in_req = topk_indices // self.block_size - - logger.info("called _forward_decode with topk_indices shape %s", - topk_indices.shape) - - ql_nope, q_pe = q - - attn_out = torch.zeros((ql_nope.shape[0], ql_nope.shape[1], 512), - dtype=ql_nope.dtype, - device=ql_nope.device) - lse = None #TODO - - # NOTE(Chen): shape is unsure - return attn_out, lse diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 2cefef6206a7..c2c2787c9a1a 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -33,6 +33,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: assert num_kv_heads == 1 return (num_blocks, block_size, head_size) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 26f9abf13d0e..f05c3a7e93a9 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -107,6 +107,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: padded_head_size = cdiv( head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index afb2283c44d3..1fffe4a6b191 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -360,6 +360,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 10238f36455d..14e3c57a8683 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -68,6 +68,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/attention/backends/triton_attn.py 
b/vllm/v1/attention/backends/triton_attn.py index 784912a122f6..2dbccb1d284d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -179,6 +179,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 63326d19194f..6336ba1d2629 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -694,7 +694,6 @@ def split_decodes_and_prefills( return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() - assert torch.all(query_lens[first_prefill:] > decode_threshold) assert torch.all(query_lens[:first_prefill] <= decode_threshold) num_decodes = first_prefill num_prefills = num_reqs - num_decodes diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index a6ca33491235..88ecdfcd00f5 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -106,6 +106,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 55cc7ea5a265..2ff1bb681d80 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1130,7 +1130,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, - use_mla=spec.use_mla, sliding_window=spec.sliding_window, ) elif isinstance(spec, ChunkedLocalAttentionSpec): @@ -1139,7 +1138,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, - use_mla=spec.use_mla, attention_chunk_size=spec.attention_chunk_size, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index d27239164b0d..58fe12aef0a9 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -9,6 +9,7 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, + MLAAttentionSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.request import Request @@ -656,6 +657,7 @@ def remove_skipped_blocks(self, request_id: str, spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, + MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index f72cc8f93a6c..281816653540 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -59,13 +59,10 @@ class AttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype - use_mla: bool @property def page_size_bytes(self) -> int: - # For MLA we only store a single latent vector - coef = 1 if self.use_mla else 2 - return coef * self.block_size * self.num_kv_heads * self.head_size \ + return 2 * 
self.block_size * self.num_kv_heads * self.head_size \ * get_dtype_size(self.dtype) @@ -118,12 +115,13 @@ def merge(cls, specs: list[Self]) -> Self: if spec.sliding_window is not None) attention_chunk_size = set(spec.attention_chunk_size for spec in specs if spec.attention_chunk_size is not None) + assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge") merged_spec = cls( block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, dtype=specs[0].dtype, - use_mla=specs[0].use_mla, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), ) @@ -140,6 +138,38 @@ def merge(cls, specs: list[Self]) -> Self: return merged_spec +@dataclass(frozen=True) +class MLAAttentionSpec(FullAttentionSpec): + # TODO(Lucas/Chen): less hacky way to do this + cache_dtype_str: Optional[str] = None + + @property + def page_size_bytes(self) -> int: + if self.cache_dtype_str == "fp8_ds_mla": + # See `vllm/v1/attention/backends/mla/flashmla_sparse.py` + # for details. + return self.block_size * 656 + return self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be " + "MLAAttentionSpec.") + cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) + assert len(cache_dtype_str_set) == 1, ( + "All attention layers in the same KV cache group must use the same " + "quantization method.") + return cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + dtype=specs[0].dtype, + cache_dtype_str=cache_dtype_str_set.pop(), + ) + + @dataclass(frozen=True) class ChunkedLocalAttentionSpec(AttentionSpec): attention_chunk_size: int @@ -163,9 +193,6 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: class SlidingWindowSpec(AttentionSpec): sliding_window: int - def __post_init__(self): - assert not self.use_mla, "MLA is not supported for sliding window" - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: assert vllm_config.parallel_config.decode_context_parallel_size == 1, \ "DCP not support sliding window." @@ -266,9 +293,13 @@ def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool: # Different block sizes, not uniform. 
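Aside on the hard-coded 656 in MLAAttentionSpec.page_size_bytes above (not part of the diff): it matches the fp8_ds_mla per-token layout referenced in flashmla_sparse.py if each cached token holds a 512-element latent in fp8 (512 bytes), four fp32 group scales (16 bytes), and 64 RoPE dimensions in bf16 (128 bytes). That breakdown is an inference from the surrounding code rather than something the patch states; a minimal sketch of the arithmetic, with illustrative names and defaults:

def fp8_ds_mla_bytes_per_token(kv_lora_rank=512, rope_dim=64, scale_group=128):
    fp8_latent = kv_lora_rank                        # 1 byte per fp8 value
    fp32_scales = (kv_lora_rank // scale_group) * 4  # 4 groups x 4 bytes each
    bf16_rope = rope_dim * 2                         # RoPE part kept in bf16
    return fp8_latent + fp32_scales + bf16_rope      # 512 + 16 + 128 = 656

assert fp8_ds_mla_bytes_per_token() == 656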
return False one_spec = next(iter(kv_cache_specs.values())) - if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)): + if isinstance(one_spec, FullAttentionSpec): + return all( + isinstance(spec, FullAttentionSpec) + for spec in kv_cache_specs.values()) + elif isinstance(one_spec, CrossAttentionSpec): return all( - isinstance(spec, type(one_spec)) + isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()) elif isinstance(one_spec, SlidingWindowSpec): return all( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e6a492266b53..3f4e46a5ce7e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -75,7 +75,8 @@ EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec, + MambaSpec, MLAAttentionSpec, + SlidingWindowSpec, UniformTypeKVCacheSpecs) # yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, @@ -3726,8 +3727,11 @@ def _reshape_kv_cache_tensors( if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=self.cache_config.cache_dtype) dtype = kv_cache_spec.dtype try: kv_cache_stride_order = \ @@ -3911,7 +3915,6 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: Add encoder-only layers to the KV cache config. """ block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) @@ -3921,8 +3924,7 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) + dtype=self.kv_cache_dtype) encoder_only_attn_specs[attn_spec].append(layer_name) self.runner_only_attn_layers.add(layer_name) if len(encoder_only_attn_specs) > 0: @@ -3944,6 +3946,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla + cache_dtype_str = self.vllm_config.cache_config.cache_dtype kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): @@ -3963,13 +3966,21 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # the attention backends if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: + assert not use_mla, "MLA is not supported for sliding" \ + "window" kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - use_mla=use_mla) + sliding_window=attn_module.sliding_window) + elif use_mla: + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + cache_dtype_str=cache_dtype_str) elif self.attention_chunk_size is not None \ and isinstance(attn_module, ChunkedLocalAttention): 
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( @@ -3977,22 +3988,19 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size, - use_mla=use_mla) + attention_chunk_size=self.attention_chunk_size) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) + dtype=self.kv_cache_dtype) elif attn_module.attn_type == AttentionType.ENCODER_DECODER: kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) + dtype=self.kv_cache_dtype) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 48070c1e3e7c..c7d6dcd77b2c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -528,7 +528,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=False, ) else: kv_cache_spec[layer_name] = FullAttentionSpec( @@ -536,7 +535,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=False, ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 530907012f70..6349f0e97592 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -73,7 +73,8 @@ def _allocate_kv_cache( ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) + num_blocks, self.block_size, self.num_kv_heads, self.head_size, + self.cache_config.cache_dtype) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] try: From 53df680c40659ba91c82f5baf1246db21fa6a3fc Mon Sep 17 00:00:00 2001 From: Xiaozhu Meng Date: Sun, 28 Sep 2025 19:17:21 -0700 Subject: [PATCH 42/82] Preliminary blackwell enablement (#54) * Pad flashmla_sparse to 128 on blackwell * adjust get_max_prefill_buffer_size * change comments --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 13 ++++++++----- vllm/v1/attention/backends/mla/indexer.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 238817350ff1..969d94c5c0cf 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -21,6 +21,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.platforms import current_platform if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer @@ -388,13 +389,15 @@ def _forward_bf16_kv( kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( -1, 1, kv_c_and_k_pe_cache.shape[-1]) - # NOTE(Chen): kernel requires num_local_head to be 
a multiple of 64. - if self.num_heads % 64 != 0: - assert 64 % self.num_heads == 0 + # NOTE(Chen): kernel requires num_local_head to be a multiple of + # 64 on hopper and 128 on blackwell + padding = 128 if current_platform.is_device_capability(100) else 64 + if self.num_heads % padding != 0: + assert padding % self.num_heads == 0 logger.warning_once( - "padding num_heads to 64 due to sparse attn kernel requirement" + f"padding num_heads to {padding} due to sparse attn kernel requirement" ) - q_padded = q.new_empty((q.shape[0], 64, q.shape[2])) + q_padded = q.new_empty((q.shape[0], padding, q.shape[2])) q_padded[:, :self.num_heads, :] = q q = q_padded diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index c2c2787c9a1a..5eae6baa20ea 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -148,8 +148,9 @@ def kv_spans_from_batches(start_seq_loc: torch.Tensor, def get_max_prefill_buffer_size(vllm_config: VllmConfig): max_model_len = vllm_config.model_config.max_model_len max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + max_num_seq = vllm_config.scheduler_config.max_num_seqs # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. - return max_model_len + max_num_batched_tokens + return max_model_len * max_num_seq class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): From fd63ddc34541e4971b2aa134a8dfc66e41ef0495 Mon Sep 17 00:00:00 2001 From: Xiaozhu Meng Date: Mon, 29 Sep 2025 02:48:08 +0000 Subject: [PATCH 43/82] Move the logic of determining padding amount to class __init__ --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 969d94c5c0cf..7b31335ceb84 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -380,6 +380,7 @@ def __init__( self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer = indexer.topk_indices_buffer + self.padding = 128 if current_platform.is_device_capability(100) else 64 def _forward_bf16_kv( self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, @@ -391,13 +392,12 @@ def _forward_bf16_kv( # NOTE(Chen): kernel requires num_local_head to be a multiple of # 64 on hopper and 128 on blackwell - padding = 128 if current_platform.is_device_capability(100) else 64 - if self.num_heads % padding != 0: - assert padding % self.num_heads == 0 + if self.num_heads % self.padding != 0: + assert self.padding % self.num_heads == 0 logger.warning_once( - f"padding num_heads to {padding} due to sparse attn kernel requirement" + f"padding num_heads to {self.padding} due to sparse attn kernel requirement" ) - q_padded = q.new_empty((q.shape[0], padding, q.shape[2])) + q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2])) q_padded[:, :self.num_heads, :] = q q = q_padded From 224f1ddd48e4a17820d446c07a91e462d512f69d Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Mon, 29 Sep 2025 11:02:14 +0800 Subject: [PATCH 44/82] Add indexer_k_quant_and_cache_kernel (#38) * Add indexer_k_quant_and_cache_kernel Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> * Accept 3D kv_cache buffer Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> * Address review comments 
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> --------- Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> --- csrc/cache.h | 8 ++ csrc/cache_kernels.cu | 103 ++++++++++++++++++++++ csrc/torch_bindings.cpp | 5 ++ vllm/_custom_ops.py | 9 ++ vllm/model_executor/models/deepseek_v2.py | 2 +- 5 files changed, 126 insertions(+), 1 deletion(-) diff --git a/csrc/cache.h b/csrc/cache.h index fd230bec27fc..3a4fc92a6c25 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -56,3 +56,11 @@ void cp_gather_cache( torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, std::optional seq_starts = std::nullopt); + +// Indexer K quantization and cache function +void indexer_k_quant_and_cache( + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 422f6907083f..992a1e8fc1f7 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -500,6 +500,64 @@ __global__ void concat_and_cache_ds_mla_kernel( src_val, scale_val); } +template +__global__ void indexer_k_quant_and_cache_kernel( + const scalar_t* __restrict__ k, // [num_tokens, head_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int token_stride, // stride for each token in k + const int head_dim, // dimension of each head + const int block_stride, // stride for each block in kv_cache + const int cache_token_stride, // stride for each token in kv_cache + const int cache_block_size, // num_tokens for each block in kv_cache + const int quant_block_size, // quantization block size + const bool use_ue8m0 // use ue8m0 scale format +) { + // NOTE: In each block of kv_cache, the quantized k and scales are stored separately. + // The first cache_block_size * head_dim elements are the quantized k values in FP8, + // and the last cache_block_size * head_dim * 4 / quant_block_size elements are the scales in FP32. 
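Aside (not part of the diff): per token, the kernel below takes one absolute maximum per quant_block_size-element group (one warp: 32 lanes times a 4-element vector per lane), turns it into a single fp32 scale (rounded up to a power of two when the ue8m0 scale format is requested), and writes the fp8 values and the scales into the two separate regions of the block described in the note above. A rough host-side PyTorch equivalent, offered only as a sketch; the helper name and defaults are assumptions:

import torch

def quantize_indexer_k(k_token: torch.Tensor, quant_block_size: int = 128,
                       use_ue8m0: bool = True):
    # Group the head_dim elements of one token into quantization blocks.
    groups = k_token.float().view(-1, quant_block_size)
    amax = groups.abs().amax(dim=-1).clamp(min=1e-4)
    scale = amax / 448.0                                   # fp8 e4m3 finite max
    if use_ue8m0:
        scale = torch.exp2(torch.ceil(torch.log2(scale)))  # power-of-two scale
    k_fp8 = (groups / scale[:, None]).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return k_fp8.view(-1), scale  # fp8 values plus one fp32 scale per group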
+ constexpr int VEC_SIZE = 4; + const int64_t token_idx = blockIdx.x; + const int64_t head_idx = (blockIdx.y * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + const int64_t slot_idx = slot_mapping[token_idx]; + const int64_t cache_block_idx = slot_idx / cache_block_size; + const int64_t cache_inblock_idx = slot_idx % cache_block_size; + + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0 || (head_idx >= head_dim)) { + return; + } + + float2 k_val = (reinterpret_cast(k))[(token_idx * token_stride + head_idx) / VEC_SIZE]; + scalar_t* k_val_ptr = reinterpret_cast(&k_val); + float amax = 0.0f; + for (int i = 0; i < VEC_SIZE; i++) { + amax = fmaxf(amax, fabsf(float(k_val_ptr[i]))); + } + __syncwarp(); + + // Reduced amax + for (int mask = 16; mask > 0; mask /= 2) { + amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask)); + } + __syncwarp(); + float scale = fmaxf(amax, 1e-4) / 448.0f; + if (use_ue8m0) { + scale = exp2f(ceilf(log2f(scale))); + } + + const int64_t cache_block_start_offset = cache_block_idx * block_stride; + const int64_t cache_inblock_offset = cache_inblock_idx * head_dim + head_idx; + const int64_t dst_k_vals_offset = cache_block_start_offset + cache_inblock_offset; + for (int i = 0; i < VEC_SIZE; i++) { + kv_cache[dst_k_vals_offset + i] = fp8::scaled_convert(k_val_ptr[i], scale); + } + if (threadIdx.x == 0) { + const int64_t dst_scale_offset = cache_block_start_offset + cache_block_size * head_dim + cache_inblock_offset * 4 / quant_block_size; + reinterpret_cast(kv_cache)[dst_scale_offset / 4] = scale; + } +} + } // namespace vllm // KV_T is the data type of key and value tensors. @@ -1062,3 +1120,48 @@ void cp_gather_cache( TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); } } + +// Macro to dispatch the kernel based on the data type. 
+#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::indexer_k_quant_and_cache_kernel \ + <<>>( \ + reinterpret_cast(k.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + k.stride(0), \ + k.size(1), \ + kv_cache.stride(0), \ + kv_cache.stride(1), \ + kv_cache.size(1), \ + quant_block_size, \ + use_ue8m0); + +void indexer_k_quant_and_cache( + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt) { + + int num_tokens = k.size(0); + int head_dim = k.size(1); + int cache_block_size = kv_cache.size(1); + int cache_stride = kv_cache.size(2); + bool use_ue8m0 = scale_fmt == "ue8m0"; + + TORCH_CHECK(k.device() == kv_cache.device(), + "k and kv_cache must be on the same device"); + TORCH_CHECK(k.device() == slot_mapping.device(), + "k and slot_mapping must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, + "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 4; + dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) / (quant_block_size * vec_size)); + dim3 block(32, vec_size); + const at::cuda::OptionalCUDAGuard device_guard(device_of(k)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", + CALL_INDEXER_K_QUANT_AND_CACHE); +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bc096406c51a..9e7fbeb80bb3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -713,6 +713,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); + + cache_ops.def( + "indexer_k_quant_and_cache(Tensor k, Tensor! 
kv_cache, Tensor slot_mapping, " + "int quant_block_size, str kv_cache_dtype) -> ()"); + cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, &indexer_k_quant_and_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 712295aa9288..3c06cce130f7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1671,6 +1671,15 @@ def cp_gather_cache(src_cache: torch.Tensor, cu_seq_lens, batch_size, seq_starts) +def indexer_k_quant_and_cache(k: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + quant_block_size: int, + kv_cache_dtype: str) -> None: + torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, + quant_block_size, kv_cache_dtype) + + def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 0ce54a594d44..39ecb72ac7e1 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -653,7 +653,7 @@ def sparse_attn_indexer( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - indexer_k_quant_and_cache( + ops.indexer_k_quant_and_cache( k, kv_cache, slot_mapping, From b5ef289bf9bda92dc1a5470d2379ad4819316017 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 01:07:36 -0400 Subject: [PATCH 45/82] remove tilelang dep (#57) Signed-off-by: Yongye Zhu --- vllm/model_executor/models/deepseek_v2.py | 38 +-- vllm/utils/tile_lang_kernels.py | 282 ---------------------- 2 files changed, 7 insertions(+), 313 deletions(-) delete mode 100644 vllm/utils/tile_lang_kernels.py diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 39ecb72ac7e1..1cfc645ff10a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -556,35 +556,6 @@ def rotate_activation_fake(x: torch.Tensor, ) -> torch.Tensor: dispatch_key=current_platform.dispatch_key, ) - -def tilelang_act_quant( - x: torch.Tensor, - block_size: int, - scale_fmt: Optional[str], -) -> tuple[torch.Tensor, torch.Tensor]: - from vllm.utils.tile_lang_kernels import act_quant - return act_quant(x, block_size, scale_fmt) - - -def tilelang_act_quant_fake( - x: torch.Tensor, - block_size: int, - scale_fmt: Optional[str], -) -> tuple[torch.Tensor, torch.Tensor]: - return per_token_group_quant_fp8(x, - block_size, - column_major_scales=False, - use_ue8m0=scale_fmt is not None) - - -direct_register_custom_op( - op_name="tilelang_act_quant", - op_func=tilelang_act_quant, - mutates_args=[], - fake_impl=tilelang_act_quant_fake, - dispatch_key=current_platform.dispatch_key, -) - @torch.inference_mode() def indexer_k_quant_and_cache( k, @@ -871,8 +842,13 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, ) #FIXME (siyuanf) hadacore_transform causes illegal memory access when applying to k # we only quant q here since k quant is fused with cache insertion - q_fp8, q_scale = torch.ops.vllm.tilelang_act_quant( - q, self.quant_block_size, self.scale_fmt) + q = q.view(-1, self.head_dim) + q_fp8, q_scale = per_token_group_quant_fp8(q, + self.quant_block_size, + column_major_scales=False, + use_ue8m0=self.scale_fmt is not None) + q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) + q_scale = q_scale.view(-1, self.n_head, 1) weights, _ = 
self.weights_proj(hidden_states) weights = weights.unsqueeze( diff --git a/vllm/utils/tile_lang_kernels.py b/vllm/utils/tile_lang_kernels.py deleted file mode 100644 index 5e4576fea45e..000000000000 --- a/vllm/utils/tile_lang_kernels.py +++ /dev/null @@ -1,282 +0,0 @@ -from typing import Optional, Tuple - -import tilelang -import tilelang.language as T -import torch - -tilelang.set_log_level("WARNING") - -pass_configs = { - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, -} - -FP8 = "float8_e4m3" -BF16 = "bfloat16" -FP32 = "float32" - - -def fast_log2_ceil(x): - bits_x = T.reinterpret("uint32", x) - exp_x = (bits_x >> 23) & 0xFF - man_bits = bits_x & ((1 << 23) - 1) - return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) - - -def fast_pow2(x): - bits_x = (x + 127) << 23 - return T.reinterpret("float32", bits_x) - - -def fast_round_scale(amax, fp8_max_inv): - return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) - - -@tilelang.jit(pass_configs=pass_configs) -def act_quant_kernel(N, - in_dtype=BF16, - out_dtype=FP8, - scale_dtype=FP32, - round_scale=False): - M = T.symbolic("M") - fp8_min = -448.0 - fp8_max = 448.0 - fp8_max_inv = 1 / fp8_max - num_stages = 0 if round_scale else 2 - blk_m = 32 - group_size = 128 - - @T.prim_func - def act_quant_kernel_( - X: T.Tensor[(M, N), in_dtype], - Y: T.Tensor[(M, N), out_dtype], - S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype], - ): - with T.Kernel(T.ceildiv(M, blk_m), - T.ceildiv(N, group_size), - threads=128) as ( - pid_m, - pid_n, - ): - x_shared = T.alloc_shared((blk_m, group_size), in_dtype) - x_local = T.alloc_fragment((blk_m, group_size), in_dtype) - amax_local = T.alloc_fragment((blk_m, ), scale_dtype) - s_local = T.alloc_fragment((blk_m, ), scale_dtype) - y_local = T.alloc_fragment((blk_m, group_size), out_dtype) - y_shared = T.alloc_shared((blk_m, group_size), out_dtype) - - for _ in T.Pipelined(1, num_stages=num_stages): - T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) - T.copy(x_shared, x_local) - T.reduce_absmax(x_local, amax_local, dim=1) - for i in T.Parallel(blk_m): - amax_local[i] = T.max(amax_local[i], 1e-4) - if round_scale: - s_local[i] = fast_round_scale(amax_local[i], - fp8_max_inv) - else: - s_local[i] = amax_local[i] * fp8_max_inv - for i, j in T.Parallel(blk_m, group_size): - y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], - fp8_min, fp8_max) - for i in T.Parallel(blk_m): - S[pid_m * blk_m + i, pid_n] = s_local[i] - T.copy(y_local, y_shared) - T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) - - return act_quant_kernel_ - - -def act_quant( - x: torch.Tensor, - block_size: int = 128, - scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantizes the input tensor `x` using block-wise quantization. - - Args: - x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. - block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. - scale_fmt (Optional[str], optional): The format of the scale. Default is None. - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The quantized tensor with dtype `torch.float8_e4m3fn`. - - A tensor of scaling factors with dtype `torch.float32`. 
- """ - assert x.is_contiguous(), "Input tensor must be contiguous" - assert x.size(-1) % block_size == 0, ( - f"Last dimension size must be divisible by block_size (block_size={block_size})" - ) - N = x.size(-1) - y = torch.empty_like(x, dtype=torch.float8_e4m3fn) - s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) - kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) - kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) - return y, s - - -@tilelang.jit(pass_configs=pass_configs) -def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): - assert out_dtype in [BF16, "float32"] - - M = T.symbolic("M") - group_size = 128 - block_M = 32 - block_N = 128 - block_K = 128 - - @T.prim_func - def fp8_gemm_kernel_( - A: T.Tensor[(M, K), FP8], - B: T.Tensor[(N, K), FP8], - C: T.Tensor[(M, N), out_dtype], - scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32], - scales_b: T.Tensor[(T.ceildiv(N, group_size), - T.ceildiv(K, group_size)), FP32], - ): - with T.Kernel(T.ceildiv(N, block_N), - T.ceildiv(M, block_M), - threads=128) as ( - bx, - by, - ): - A_shared = T.alloc_shared((block_M, block_K), FP8) - B_shared = T.alloc_shared((block_N, block_K), FP8) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - Scale_C_shared = T.alloc_shared((block_M), FP32) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - T.clear(C_local_accum) - K_iters = T.ceildiv(K, block_K) - for k in T.Pipelined(K_iters, num_stages=4): - # Load A into shared memory - T.copy(A[by * block_M, k * block_K], A_shared) - # Load B into shared memory - T.copy(B[bx * block_N, k * block_K], B_shared) - # Load scale into shared memory - Scale_B = scales_b[bx * block_N // group_size, k] - for i in T.Parallel(block_M): - Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B - - T.gemm(A_shared, B_shared, C_local, transpose_B=True) - # Promote to enable 2xAcc - for i, j in T.Parallel(block_M, block_N): - C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] - T.clear(C_local) - # TMA store - T.copy(C_local_accum, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return fp8_gemm_kernel_ - - -def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, - b_s: torch.Tensor) -> torch.Tensor: - """ - Perform a matrix multiplication using FP8 precision. - - Args: - a (torch.Tensor): The first input matrix, must be contiguous. - a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. - b (torch.Tensor): The second input matrix, must be contiguous. - b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. - - Returns: - torch.Tensor: The result of the matrix multiplication. 
- """ - assert a.is_contiguous() and b.is_contiguous( - ), "Input tensors must be contiguous" - assert a_s.is_contiguous() and b_s.is_contiguous(), ( - "Scaling factor tensors must be contiguous") - K = a.size(-1) - M = a.numel() // K - N = b.size(0) - c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) - kernel = fp8_gemm_kernel(N, K) - kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) - return c - - -@tilelang.jit(out_idx=[4], pass_configs=pass_configs) -def fp8_index_kernel(h: int, d: int): - b = T.symbolic("b") - m = T.symbolic("m") - n = T.symbolic("n") - - blk_n1 = 512 - blk_n2 = 128 - - @T.prim_func - def fp8_index_kernel_( - q: T.Tensor[(b, m, h, d), FP8], - q_s: T.Tensor[(b, m, h), FP32], - k: T.Tensor[(b, n, d), FP8], - k_s: T.Tensor[(b, n), FP32], - o: T.Tensor[(b, m, n), FP32], - ) -> None: - with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): - q_smem = T.alloc_shared((h, d), FP8) - T.copy(q[i_b, i_m, 0, 0], q_smem) - - q_s_frag = T.alloc_fragment(h, FP32) - T.copy(q_s[i_b, i_m, 0], q_s_frag) - - for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): - k_smem = T.alloc_shared((blk_n2, d), FP8) - T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) - - k_s_frag = T.alloc_fragment(blk_n2, FP32) - T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) - - logits = T.alloc_fragment((blk_n2, h), FP32) - T.gemm( - k_smem, - q_smem, - logits, - transpose_A=False, - transpose_B=True, - clear_accum=True, - ) - - for i_h, i3_n in T.Parallel(h, blk_n2): - logits[i3_n, - i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] - - logits_sum = T.alloc_fragment(blk_n2, FP32) - T.reduce_sum(logits, logits_sum, dim=1) - - for i3_n in T.Parallel(blk_n2): - logits_sum[i3_n] *= k_s_frag[i3_n] - - T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) - - return fp8_index_kernel_ - - -def fp8_index( - q: torch.Tensor, - q_s: torch.Tensor, - k: torch.Tensor, - k_s: torch.Tensor, -) -> torch.Tensor: - """ - Perform index score using FP8 precision. - - Args: - q (torch.Tensor): The Q tensor, must be contiguous. - q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous. - k (torch.Tensor): The K tensor, must be contiguous. - k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous. - - fp8 q @ fp8 k -> fp32 logits - relu(fp32 logits) * q_s (weights) -> fp32 logits - fp32 logits -> fp32 logits_sum - fp32 logits_sum * k_s (e8m0) -> fp32 index_score - """ - return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) From 0e12bdb8881a501e3c7d7e1a21223c3aec804dff Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 29 Sep 2025 05:22:02 +0000 Subject: [PATCH 46/82] default to fp8 Signed-off-by: Lucas Wilkinson --- vllm/config/cache.py | 20 +++++++++++++------- vllm/model_executor/models/config.py | 11 ++++++++++- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 4c4e39c37ee5..bf13a18e0e0c 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -22,7 +22,8 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] -CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +CacheDType = Literal[ + "auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @@ -52,7 +53,11 @@ class CacheConfig: cache_dtype: CacheDType = "auto" """Data type for kv cache storage. If "auto", will use model data type. 
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports - fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).""" + fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc). + Some models (namely DeepSeekV3.2) default to fp8, set to bfloat16 to use + bfloat16 instead, this is an invalid option for models that do not default + to fp8. + """ is_attention_free: bool = False """Whether the model is attention-free. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -171,11 +176,12 @@ def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass elif self.cache_dtype in get_args(CacheDType): - logger.info( - "Using fp8 data type to store kv cache. It reduces the GPU " - "memory footprint and boosts the performance. " - "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor.") + if self.cache_dtype.startswith("fp8"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor.") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 86348f8615b9..2c5f0bb64dec 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -427,9 +427,18 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: and "attn_index" in hf_config.attn_module_list_cfg[0] if is_v32: + # For DeepSeekV3.2, we use a custom fp8 format as default (i.e. + # "auto") cache_config = vllm_config.cache_config - if cache_config.cache_dtype.startswith("fp8"): + if cache_config.cache_dtype == "auto" or \ + cache_config.cache_dtype.startswith("fp8"): cache_config.cache_dtype = "fp8_ds_mla" + logger.info("Using custom fp8 format for DeepSeekV3.2") + if cache_config.cache_dtype == "bfloat16": + cache_config.cache_dtype = "auto" + logger.info( + "DeepSeekV3.2 cache kernels reject bfloat16; falling back " + "to activation dtype via auto.") MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { From 82f0fa5f918cf0040e35b95815c84f0f2c9893dd Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 29 Sep 2025 05:26:43 +0000 Subject: [PATCH 47/82] fix up prints Signed-off-by: Lucas Wilkinson --- vllm/model_executor/models/config.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 2c5f0bb64dec..15d997a9303c 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -433,12 +433,10 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cache_config.cache_dtype == "auto" or \ cache_config.cache_dtype.startswith("fp8"): cache_config.cache_dtype = "fp8_ds_mla" - logger.info("Using custom fp8 format for DeepSeekV3.2") + logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") if cache_config.cache_dtype == "bfloat16": cache_config.cache_dtype = "auto" - logger.info( - "DeepSeekV3.2 cache kernels reject bfloat16; falling back " - "to activation dtype via auto.") + logger.info("Using bfloat16 kv-cache for DeepSeekV3.2") MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { From ee3edfa41b5bf6710456e55a0566d74b547cbcf9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 29 Sep 2025 01:36:21 -0400 Subject: [PATCH 48/82] Full-CG Support (#46) * full-cg 
support Signed-off-by: Lucas Wilkinson * fix non-spec error Signed-off-by: Lucas Wilkinson * fix invalid op in capture Signed-off-by: Lucas Wilkinson * cleanup Signed-off-by: Lucas Wilkinson --------- Signed-off-by: Lucas Wilkinson --- vllm/model_executor/models/deepseek_v2.py | 10 ++-- .../attention/backends/mla/flashmla_sparse.py | 56 ++++++++++++++++--- vllm/v1/attention/backends/mla/indexer.py | 54 ++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 4 +- 4 files changed, 99 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1cfc645ff10a..e8dce07bc741 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -678,12 +678,10 @@ def sparse_attn_indexer( if has_decode: decode_metadata = attn_metadata.decode # kv_cache size requirement [num_block, block_size, n_head, head_dim], - # we only have [num_block, block_size, head_dim], - query_start_loc = attn_metadata.query_start_loc - decode_lens = query_start_loc[1:attn_metadata.num_decodes+1] - query_start_loc[:attn_metadata.num_decodes] + # we only have [num_block, block_size, head_dim], kv_cache = kv_cache.unsqueeze(-2) - require_padding = (decode_lens.max() > decode_lens.min()).item() - if require_padding: + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: # pad in edge case where we have short chunked prefill length < # decode_threshold since we unstrictly split # prefill and decode by decode_threshold (currently set to 1 + speculative tokens) @@ -721,7 +719,7 @@ def sparse_attn_indexer( # ensure we don't set indices for the top k that out of range(masked already) # this will happen if context length is shorter than K topk_indices[topk_indices > index_end_pos] = -1 - if require_padding: + if decode_metadata.requires_padding: # if padded, we need to unpack the topk indices removing padded tokens topk_indices = unpack_seq_triton(topk_indices.reshape(batch_size, -1, logits.shape[-1]), decode_lens) topk_indices_buffer[:num_decode_tokens, :topk_indices. 
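Aside (not part of the diff): with speculative decoding, requests can contribute different numbers of decode query tokens, so the flattened layout is packed into a dense [batch, max_decode_len, ...] tensor (padded, by default with -inf) before the paged MQA logits kernel, and the resulting top-k indices are unpacked back to the flattened layout afterwards, dropping the padded rows. A plain-PyTorch sketch of what the pack_seq_triton/unpack_seq_triton pair does; these function bodies are illustrative stand-ins for the Triton kernels, not their implementation:

import torch

def pack_seq(x: torch.Tensor, lens: torch.Tensor, pad_value: float = float("-inf")):
    # x: [sum(lens), ...] -> [batch, max(lens), ...] with padding.
    b, m = lens.numel(), int(lens.max())
    out = x.new_full((b, m, *x.shape[1:]), pad_value)
    starts = torch.cumsum(lens, 0) - lens
    for i in range(b):
        n, s = int(lens[i]), int(starts[i])
        out[i, :n] = x[s:s + n]
    return out

def unpack_seq(packed: torch.Tensor, lens: torch.Tensor):
    # Inverse of pack_seq: drop the padded rows and re-flatten.
    return torch.cat([packed[i, :int(lens[i])] for i in range(lens.numel())], dim=0)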
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 7b31335ceb84..7d70f3081996 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -17,8 +17,10 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.triton_utils import tl, triton +from vllm.utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.platforms import current_platform @@ -268,6 +270,8 @@ def triton_convert_req_index_to_global_index( @dataclass class FlashMLASparseMetadataBuilder( AttentionMetadataBuilder[FlashMLASparseMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: ClassVar[int] = 128 # TODO(lucas): tune this @@ -282,6 +286,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], parallel_config = vllm_config.parallel_config self.device = device + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -300,10 +307,34 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=self.device) self.num_speculative_tokens = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config else 0) + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0 + ) self.reorder_batch_threshold += self.num_speculative_tokens + # Equation taken from FlashMLA/csrc/pybind.cpp + h_q, h_k = self.num_heads, 1 + s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest + max_num_sm_parts = int( + max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)) + + self.tile_scheduler_metadata_buffer = torch.empty( + # TileSchedulerMetaDataSize = 8 + # see: FlashMLA/csrc/params.h + (max_num_sm_parts, 8), + dtype=torch.int32, + device=device) + self.num_splits_buffer = torch.empty( + # We pack all the tokens into one batch for sparse attention. + # Otherwise, we can exceed the sm of `get_mla_metadata`. 
+ (2, ), + dtype=torch.int32, + device=device) + self.req_id_per_token_buffer = torch.empty( + (vllm_config.scheduler_config.max_num_batched_tokens, ), + dtype=torch.int32, + device=device) + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, @@ -315,8 +346,11 @@ def build(self, seg_lengths = np.diff(starts) req_id_per_token = np.repeat( np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) - req_id_per_token = torch.from_numpy(req_id_per_token)\ - .to(device='cuda', non_blocking=True) + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\ + .copy_(torch.from_numpy(req_id_per_token), non_blocking=True) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] fp8_extra_metadata = None if self.use_fp8_kv_cache: @@ -328,9 +362,17 @@ def build(self, num_heads_k=1, is_fp8_kvcache=True, ) + + num_sm_parts = tile_scheduler_metadata.size(0) + # Copy to persistent buffer for full-CG support + tile_scheduler_metadata_buffer = \ + self.tile_scheduler_metadata_buffer[:num_sm_parts] + tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) + self.num_splits_buffer.copy_(num_splits) + fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( - scheduler_metadata=tile_scheduler_metadata, - num_splits=num_splits, + scheduler_metadata=tile_scheduler_metadata_buffer, + num_splits=self.num_splits_buffer, # cache_lens and block_table are basically unused in sparse case # but the decode kernel will treat -1 and indices >= cache_lens # as invalid so we make sure cache_lens is large enough to not diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 5eae6baa20ea..6abd02622ed2 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,14 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass from typing import ClassVar, Optional -from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) from vllm.config import VllmConfig -from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, get_num_sms -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, +from vllm.logger import init_logger +from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, CommonAttentionMetadata, split_decodes_and_prefills) -import torch -from vllm.logger import init_logger logger = init_logger(__name__) @@ -58,6 +63,8 @@ class DeepseekV32IndexerPrefillMetadata: class DeepSeekV32IndexerDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor + decode_lens: torch.Tensor + requires_padding: bool schedule_metadata: torch.Tensor @@ -154,12 +161,13 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: ClassVar[int] = 1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - max_model_len = self.vllm_config.model_config.max_model_len - max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens + scheduler_config = self.vllm_config.scheduler_config # 
NOTE(Chen): an estimated max size of flattened_kv. Need to double check. self.max_prefill_buffer_size = get_max_prefill_buffer_size( self.vllm_config) @@ -169,6 +177,20 @@ def __init__(self, *args, **kwargs): ) self.reorder_batch_threshold += self.num_speculative_tokens + props = torch.cuda.get_device_properties(self.device) + sm_count = props.multi_processor_count + self.num_sms = sm_count + + self.decode_lens_buffer = torch.empty( + (scheduler_config.max_num_seqs, ), + dtype=torch.int32, + device=self.device) + + # See: DeepGMM/csrc/apis/attention.hpp + self.scheduler_metadata_buffer = torch.empty( + (self.num_sms + 1, 2), dtype=torch.int32, device=self.device + ) + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, @@ -216,14 +238,26 @@ def build(self, decode_metadata = None if num_decodes > 0: + torch.diff(common_attn_metadata.query_start_loc[:num_decodes+1], + out=self.decode_lens_buffer[:num_decodes]) + decode_lens = self.decode_lens_buffer[:num_decodes] + decode_lens_cpu = torch.diff( + common_attn_metadata.query_start_loc_cpu[:num_decodes+1]) + + # Use CPU to avoid GPU sync; breaking async scheduling + requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() + seq_lens = common_attn_metadata.seq_lens[:num_decodes] - schedule_metadata = get_paged_mqa_logits_metadata( - seq_lens, self.kv_cache_spec.block_size, get_num_sms()) + + self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( + seq_lens, self.kv_cache_spec.block_size, self.num_sms) decode_metadata = DeepSeekV32IndexerDecodeMetadata( block_table=common_attn_metadata. block_table_tensor[:num_decodes, ...], seq_lens=common_attn_metadata.seq_lens[:num_decodes], - schedule_metadata=schedule_metadata, + decode_lens=decode_lens, + requires_padding=requires_padding, + schedule_metadata=self.scheduler_metadata_buffer, ) attn_metadata = DeepseekV32IndexerMetadata( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3f4e46a5ce7e..ef60866c074b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2997,7 +2997,7 @@ def _dummy_run( attn_metadata_i = (attn_group\ .get_metadata_builder(ubatch_id=ubid)\ .build_for_cudagraph_capture(common_attn_metadata)) - for layer_name in kv_cache_group_spec.layer_names: + for layer_name in attn_group.layer_names: assert type(attn_metadata) is list attn_metadata[ubid][ layer_name] = attn_metadata_i @@ -3005,7 +3005,7 @@ def _dummy_run( assert type(attn_metadata) is dict attn_metadata_i = attn_group.get_metadata_builder()\ .build_for_cudagraph_capture(common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: + for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, From d710dc8aa665b061da35fa5ac86560249875c035 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 01:37:38 -0400 Subject: [PATCH 49/82] reverse last commit of insert kernel (#60) Signed-off-by: Yongye Zhu --- csrc/cache_kernels.cu | 41 ++++++++++++++--------------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 992a1e8fc1f7..b9fb1b680806 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -505,30 +505,25 @@ __global__ void indexer_k_quant_and_cache_kernel( const scalar_t* __restrict__ k, // [num_tokens, head_dim] cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] const int64_t* 
__restrict__ slot_mapping, // [num_tokens] - const int token_stride, // stride for each token in k const int head_dim, // dimension of each head - const int block_stride, // stride for each block in kv_cache - const int cache_token_stride, // stride for each token in kv_cache - const int cache_block_size, // num_tokens for each block in kv_cache const int quant_block_size, // quantization block size + const int cache_block_size, // cache block size + const int cache_stride, // stride for each token in kv_cache const bool use_ue8m0 // use ue8m0 scale format ) { - // NOTE: In each block of kv_cache, the quantized k and scales are stored separately. - // The first cache_block_size * head_dim elements are the quantized k values in FP8, - // and the last cache_block_size * head_dim * 4 / quant_block_size elements are the scales in FP32. constexpr int VEC_SIZE = 4; const int64_t token_idx = blockIdx.x; - const int64_t head_idx = (blockIdx.y * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; const int64_t slot_idx = slot_mapping[token_idx]; - const int64_t cache_block_idx = slot_idx / cache_block_size; - const int64_t cache_inblock_idx = slot_idx % cache_block_size; + const int64_t block_idx = slot_idx / cache_block_size; + const int64_t block_offset = slot_idx % cache_block_size; // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0 || (head_idx >= head_dim)) { + if (slot_idx < 0 || (head_dim_idx >= head_dim)) { return; } - float2 k_val = (reinterpret_cast(k))[(token_idx * token_stride + head_idx) / VEC_SIZE]; + float2 k_val = (reinterpret_cast(k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE]; scalar_t* k_val_ptr = reinterpret_cast(&k_val); float amax = 0.0f; for (int i = 0; i < VEC_SIZE; i++) { @@ -546,15 +541,13 @@ __global__ void indexer_k_quant_and_cache_kernel( scale = exp2f(ceilf(log2f(scale))); } - const int64_t cache_block_start_offset = cache_block_idx * block_stride; - const int64_t cache_inblock_offset = cache_inblock_idx * head_dim + head_idx; - const int64_t dst_k_vals_offset = cache_block_start_offset + cache_inblock_offset; + const int64_t dst_offset = block_idx * cache_block_size * cache_stride + block_offset * head_dim + head_dim_idx; for (int i = 0; i < VEC_SIZE; i++) { - kv_cache[dst_k_vals_offset + i] = fp8::scaled_convert(k_val_ptr[i], scale); + kv_cache[dst_offset + i] = fp8::scaled_convert(k_val_ptr[i], scale); } if (threadIdx.x == 0) { - const int64_t dst_scale_offset = cache_block_start_offset + cache_block_size * head_dim + cache_inblock_offset * 4 / quant_block_size; - reinterpret_cast(kv_cache)[dst_scale_offset / 4] = scale; + const int64_t dst_scale_idx = block_idx * cache_block_size * cache_stride + cache_block_size * head_dim + (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size; + reinterpret_cast(kv_cache)[dst_scale_idx / 4] = scale; } } @@ -1127,14 +1120,8 @@ void cp_gather_cache( <<>>( \ reinterpret_cast(k.data_ptr()), \ reinterpret_cast(kv_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - k.stride(0), \ - k.size(1), \ - kv_cache.stride(0), \ - kv_cache.stride(1), \ - kv_cache.size(1), \ - quant_block_size, \ - use_ue8m0); + slot_mapping.data_ptr(), head_dim, quant_block_size, \ + cache_block_size, cache_stride, use_ue8m0); void indexer_k_quant_and_cache( torch::Tensor& k, // [num_tokens, head_dim] @@ -1164,4 +1151,4 @@ void indexer_k_quant_and_cache( 
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", CALL_INDEXER_K_QUANT_AND_CACHE); -} +} \ No newline at end of file From b64779aba785a4ab6f3de24bc42d20bfa4367ba4 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 04:55:36 -0400 Subject: [PATCH 50/82] Gather cache. (#61) * . Signed-off-by: Yongye Zhu * remove hadamard transform * cleanup Signed-off-by: Yongye Zhu --------- Signed-off-by: Yongye Zhu --- vllm/model_executor/models/deepseek_v2.py | 119 +++++++++++----------- 1 file changed, 57 insertions(+), 62 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index e8dce07bc741..42882e2b85b3 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -530,56 +530,55 @@ def forward(self): def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend - -def rotate_activation(x: torch.Tensor) -> torch.Tensor: - assert x.dtype == torch.bfloat16 - from fast_hadamard_transform import hadamard_transform - hidden_size = x.size(-1) - # make sure the hidden_size is expontial of 2 - return hadamard_transform(x, scale=hidden_size**-0.5) - - -def hadacore_transform(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: - assert x.dtype == torch.bfloat16 - return ops.hadacore_transform(x, inplace=inplace) - - -def rotate_activation_fake(x: torch.Tensor, ) -> torch.Tensor: - return torch.empty_like(x) - - -direct_register_custom_op( - op_name="rotate_activation", - op_func=rotate_activation, - mutates_args=["x"], - fake_impl=rotate_activation_fake, - dispatch_key=current_platform.dispatch_key, -) - @torch.inference_mode() -def indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, +def cp_gather_indexer_k_quant_cache( + kv_cache, # [num_blocks, block_size, head_dim + 1] + dst_value, # [cu_seq_lens[-1], head_dim] + dst_scale, # [cu_seq_lens[-1], 4] + block_table, # [batch_size, num_blocks] + cu_seq_lens, # [batch_size + 1, ] + batch_size, ): - _, block_size, head_dim = kv_cache.shape - k_fp8, k_scale = torch.ops.vllm.tilelang_act_quant(k, quant_block_size, - scale_fmt) - k_bytes = k_fp8.view(torch.uint8) - s_bytes = k_scale.view(torch.uint8) - - packed = torch.cat([k_bytes, s_bytes], dim=-1) - - block_idx = torch.div(slot_mapping, block_size, rounding_mode='floor') - inblock = slot_mapping - block_idx * block_size - linear = block_idx * block_size + inblock - - kv_cache_flat = kv_cache.view(-1, head_dim) - # kv_cache_flat.shape: torch.Size([22326528, 132]), packed.shape: torch.Size([96, 132]), kv_cache.shape: torch.Size([348852, 64, 132]), linear.shape: torch.Size([91]) - - kv_cache_flat.index_copy_(0, linear, packed[:len(linear)]) + num_blocks, block_size, _ = kv_cache.shape + head_dim = dst_value.shape[-1] + kv_cache = kv_cache.view(num_blocks, -1) + + expected_value = [] + expected_scale = [] + for b in range(batch_size): + s = cu_seq_lens[b + 1] - cu_seq_lens[b] + if s == 0: + continue + tot = cdiv(s, block_size) + blocks = block_table[b, :tot] + + value = [] + scale = [] + full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32) + # print(f"full_blocks: {blocks[full_block]}") + non_remaining_value = kv_cache[blocks[full_block], : block_size * head_dim].view(-1, head_dim) + non_remaining_scale = kv_cache[blocks[full_block], block_size * head_dim:].view(-1, 4) + + # for i in range(tot - 1): + # value.append(kv_cache[blocks[i], :block_size * head_dim]) + # scale.append(kv_cache[blocks[i], block_size * 
head_dim:]) + + remaining = s - (tot - 1) * block_size + # value.append(kv_cache[blocks[-1], :remaining * head_dim]) + # scale.append(kv_cache[blocks[-1], block_size * head_dim: block_size * head_dim + remaining * 4]) + + value = torch.cat([non_remaining_value, kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)], dim=0) + scale = torch.cat([non_remaining_scale, kv_cache[blocks[-1], block_size * head_dim: block_size * head_dim + remaining * 4].view(-1, 4)], dim=0) + + expected_value.append(value) + expected_scale.append(scale) + + gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) + gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) + gather_value = gather_value.view(torch.float8_e4m3fn) + gather_scale = gather_scale.view(torch.float32) + dst_value.copy_(gather_value) + dst_scale.copy_(gather_scale) def sparse_attn_indexer( @@ -636,21 +635,22 @@ def sparse_attn_indexer( if has_prefill: prefill_metadata = attn_metadata.prefill num_prefills = attn_metadata.num_prefills - flattened_kv = torch.empty( - [prefill_metadata.total_seq_lens, head_dim + 4], + k_fp8 = torch.empty( + [prefill_metadata.total_seq_lens, head_dim], + device=k.device, + dtype=torch.float8_e4m3fn) + k_scale = torch.empty( + [prefill_metadata.total_seq_lens, 1], device=k.device, - dtype=torch.uint8) - ops.cp_gather_cache( + dtype=torch.float32) + cp_gather_indexer_k_quant_cache( kv_cache, - flattened_kv, + k_fp8, + k_scale, prefill_metadata.block_table, prefill_metadata.cu_seq_lens, num_prefills, ) - # TODO: the memory footprint here can be optimized - k_fp8 = flattened_kv[..., :head_dim].view( - torch.float8_e4m3fn).contiguous() - k_scale = flattened_kv[..., head_dim:].view(torch.float32).contiguous() cu_seqlen_ks = prefill_metadata.cu_seqlen_ks cu_seqlen_ke = prefill_metadata.cu_seqlen_ke num_tokens = attn_metadata.num_actual_tokens @@ -833,11 +833,6 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) q = torch.cat([q_pe, q_nope], dim=-1) k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) - # logger.info_once(f'q.shape: {q.shape}, k.shape: {k.shape}') - q = torch.ops.vllm.rotate_activation(q) - k = torch.ops.vllm.rotate_activation( - k - ) #FIXME (siyuanf) hadacore_transform causes illegal memory access when applying to k # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim) From 80d834c264e53c47b396291b9c724f75d321b3be Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 29 Sep 2025 02:19:13 -0700 Subject: [PATCH 51/82] fix basic.py (#63) Signed-off-by: Chen Zhang --- examples/offline_inference/basic/basic.py | 41 +++++------------------ 1 file changed, 8 insertions(+), 33 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index cd73022ba6ee..909fc9e4df66 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -2,50 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import LLM, SamplingParams -from vllm.inputs.data import TokensPrompt # Sample prompts. prompts = [ - "hello, can you tell me the answer of 1 + 1?", - + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", ] # Create a sampling params object. 
-sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=50) +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -prompt_token_ids = [ - TokensPrompt( - prompt_token_ids=[0, 128803, 33310, 14, 588, 440, 4575, 678, 270, 3287, 294, 223, 19, 940, 223, 19, 33, 128804, 128799], - ), # hello, can you tell me the answer of 1 + 1? - TokensPrompt( - prompt_token_ids=[0, 128803, 33310, 14, 1205, 344, 223, 20, 940, 223, 20, 33, 128804, 128799], - ), # hello, what is 2 + 2? - TokensPrompt( - prompt_token_ids=[0, 128803, 9602, 344, 223, 21, 940, 223, 21, 33, 8033, 1801, 678, 16, 128804, 128799], - ), # what is 3 + 3? please show me. -] - -""" -Prompt: hello, can you tell me the answer of 1 + 1? -Output: Hello! The answer to 1 + 1 is **2**. \n\nIf you have any more questions, feel free to ask! 😊 -""" - -""" -Prompt: hello, what is 2 + 2? -Output: Hello! 2 + 2 equals 4. 😊 -""" - -""" -Prompt: what is 3 + 3? please show me. -Output: Let's add 3 and 3 together:\n\n3 + 3 = 6\n\nSo, 3 plus 3 equals 6." -""" def main(): # Create an LLM. - llm = LLM(model="/home/vllm-dsv32/DeepSeek-V3.2-Preview-Fix", tensor_parallel_size=8) + llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompt_token_ids, sampling_params) + outputs = llm.generate(prompts, sampling_params) # Print the outputs. print("\nGenerated Outputs:\n" + "-" * 60) for output in outputs: @@ -57,4 +32,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file From aeee9295b6636dcdc025f3c3e2912bb71e1c9809 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 29 Sep 2025 02:28:43 -0700 Subject: [PATCH 52/82] fix flashmla Signed-off-by: Chen Zhang --- cmake/external_projects/flashmla.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 1e15cd168489..946e1d86fbb4 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR) else() FetchContent_Declare( flashmla - GIT_REPOSITORY https://github.com/vllm-model-0920/FlashMLA - GIT_TAG c2726ac45add214249698c7d7053851b9f3e54a4 + GIT_REPOSITORY https://github.com/vllm-project/FlashMLA + GIT_TAG 9140b54f8ca80a32b69972b46a68bfd0de4501b8 GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" From f14265482928356ca6d57ede3c26949105862a11 Mon Sep 17 00:00:00 2001 From: Lucia Fang <116399278+luccafong@users.noreply.github.com> Date: Mon, 29 Sep 2025 02:35:44 -0700 Subject: [PATCH 53/82] fix unpack kernel (#64) * fix unpack kernel * increase atol/rtol in test case --------- Co-authored-by: Lucia Fang --- .../attention/test_pack_unpack_triton.py | 190 ++---------------- vllm/attention/ops/common.py | 78 +++---- 2 files changed, 53 insertions(+), 215 deletions(-) diff --git a/tests/kernels/attention/test_pack_unpack_triton.py b/tests/kernels/attention/test_pack_unpack_triton.py index a44c49829612..59a9b64eebff 100644 --- a/tests/kernels/attention/test_pack_unpack_triton.py +++ b/tests/kernels/attention/test_pack_unpack_triton.py @@ -8,7 +8,7 @@ from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton -def test_pack_decode_query_basic_fp8(): +def test_pack_seq_basic_fp8(): """Test basic functionality of pack_seq_triton with fp8 and 3D tensors.""" device = "cuda" dtype = torch.float8_e4m3fn @@ -46,7 
+46,7 @@ def test_pack_decode_query_basic_fp8(): assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) -def test_pack_decode_query_custom_padding_fp8(): +def test_pack_seq_custom_padding_fp8(): """Test pack_seq_triton with custom padding values for fp8.""" device = "cuda" dtype = torch.float8_e4m3fn @@ -77,7 +77,7 @@ def test_pack_decode_query_custom_padding_fp8(): assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2) -def test_pack_decode_query_default_negative_inf_padding_fp8(): +def test_pack_seq_default_negative_inf_padding_fp8(): """Test that pack_seq_triton uses -inf padding by default for fp8.""" device = "cuda" dtype = torch.float8_e4m3fn @@ -93,7 +93,7 @@ def test_pack_decode_query_default_negative_inf_padding_fp8(): assert torch.all(padded_data < -100) # fp8 -inf is represented as large negative number -def test_pack_decode_query_edge_cases_fp8(): +def test_pack_seq_edge_cases_fp8(): """Test pack_seq_triton with edge cases for fp8.""" device = "cuda" dtype = torch.float8_e4m3fn @@ -120,7 +120,7 @@ def test_pack_decode_query_edge_cases_fp8(): assert result.shape == (3, 7, 8, 16) -def test_pack_decode_query_different_block_sizes_fp8(): +def test_pack_seq_different_block_sizes_fp8(): """Test pack_seq_triton with different block sizes for fp8.""" device = "cuda" dtype = torch.float8_e4m3fn @@ -144,7 +144,7 @@ def test_pack_decode_query_different_block_sizes_fp8(): assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) -def test_pack_decode_query_shape_consistency(): +def test_pack_seq_shape_consistency(): """Test that pack_seq_triton maintains shape consistency.""" device = "cuda" dtype = torch.float8_e4m3fn @@ -191,13 +191,11 @@ def test_pack_unpack_roundtrip_fp8(): assert unpacked.shape == x.shape x_f32 = x.to(torch.float32) unpacked_f32 = unpacked.to(torch.float32) - assert_close(x_f32, unpacked_f32, rtol=1e-1, atol=1e-2) + assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3) - # Test with query_start_loc - query_start_loc = torch.cat([torch.zeros(1, device=device, dtype=lengths.dtype), - lengths.cumsum(0)[:-1]]) - unpacked_with_loc = unpack_seq_triton(packed, lengths, query_start_loc) - assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-1, atol=1e-2) + # Unpack without explicit start locations (computed in kernel) + unpacked_with_loc = unpack_seq_triton(packed, lengths) + assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-3, atol=1e-2) def test_unpack_seq_triton_edge_cases_fp8(): @@ -223,174 +221,10 @@ def test_unpack_seq_triton_edge_cases_fp8(): # Only compare the first 3 elements that were actually packed assert_close(x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) - # Test with query_start_loc x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) lengths = torch.tensor([5, 7, 3], device=device) - query_start_loc = torch.tensor([0, 5, 12], device=device) packed = pack_seq_triton(x, lengths) - unpacked = unpack_seq_triton(packed, lengths, query_start_loc) + unpacked = unpack_seq_triton(packed, lengths) assert unpacked.shape == x.shape - assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) - - -def test_masked_topk_basic(): - """Test basic functionality of masked_topk function.""" - device = "cuda" - - # Test case 1: Simple example - seq_lens = torch.tensor([2, 1], device=device) # 2 batches: lengths 2,1 - starting_pos = torch.tensor([3, 7], device=device) # starting positions - N = seq_lens.sum().item() # 3 total 
positions - vocab_size, k = 20, 2 - - scores = torch.randn(N, vocab_size, device=device) - - indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k) - - # Check output shapes - assert indices.shape == (N, k) - assert top_scores.shape == (N, k) - - # Verify masking constraints - # Positions 0,1 (batch 0): should only use indices < 3 - assert torch.all(indices[0] < 3) - assert torch.all(indices[1] < 3) - # Position 2 (batch 1): should only use indices < 7 - assert torch.all(indices[2] < 7) - - -def test_masked_topk_complex(): - """Test masked_topk with more complex sequences.""" - device = "cuda" - - # Test case: 4 batches with different lengths - seq_lens = torch.tensor([3, 1, 1, 1], device=device) # lengths: 3,1,1,1 - starting_pos = torch.tensor([4, 12, 33, 50], device=device) # starting positions - N = seq_lens.sum().item() # 6 total positions - vocab_size, k = 100, 3 - - scores = torch.randn(N, vocab_size, device=device) - - indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k) - - # Check output shapes - assert indices.shape == (N, k) - assert top_scores.shape == (N, k) - - # Verify masking constraints for each batch - pos_idx = 0 - for b in range(len(seq_lens)): - seq_len = seq_lens[b].item() - start_pos = starting_pos[b].item() - - # Check all positions in this batch - for i in range(seq_len): - assert torch.all(indices[pos_idx] < start_pos), f"Position {pos_idx} should only use indices < {start_pos}" - pos_idx += 1 - - -def test_masked_topk_edge_cases(): - """Test masked_topk with edge cases.""" - device = "cuda" - - # Test case 1: Single batch - seq_lens = torch.tensor([5], device=device) - starting_pos = torch.tensor([10], device=device) - scores = torch.randn(5, 50, device=device) - - indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=3) - assert indices.shape == (5, 3) - assert torch.all(indices < 10) # All positions should use indices < 10 - - # Test case 2: Very small starting positions - seq_lens = torch.tensor([2, 1], device=device) - starting_pos = torch.tensor([1, 2], device=device) - scores = torch.randn(3, 20, device=device) - - indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=1) - assert indices.shape == (3, 1) - assert torch.all(indices[0] < 1) # First position can only use index 0 - assert torch.all(indices[1] < 1) # Second position can only use index 0 - assert torch.all(indices[2] < 2) # Third position can use indices 0,1 - - # Test case 3: Large starting positions - seq_lens = torch.tensor([2], device=device) - starting_pos = torch.tensor([95], device=device) - scores = torch.randn(2, 100, device=device) - - indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=5) - assert indices.shape == (2, 5) - assert torch.all(indices < 95) - - -def test_masked_topk_different_k_values(): - """Test masked_topk with different k values.""" - device = "cuda" - - seq_lens = torch.tensor([2, 1], device=device) - starting_pos = torch.tensor([5, 10], device=device) - scores = torch.randn(3, 20, device=device) - - # Test different k values - for k in [1, 3, 5, 10]: - indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k) - - assert indices.shape == (3, k) - assert top_scores.shape == (3, k) - - # Verify masking constraints - assert torch.all(indices[0] < 5) - assert torch.all(indices[1] < 5) - assert torch.all(indices[2] < 10) - - -def test_masked_topk_fp8(): - """Test masked_topk with fp8 dtype.""" - device = "cuda" - dtype = torch.float8_e4m3fn - - seq_lens = torch.tensor([2, 1], device=device) 
- starting_pos = torch.tensor([5, 10], device=device) - - # Create fp8 scores - scores_f32 = torch.randn(3, 20, device=device) * 0.1 - scores = scores_f32.to(dtype) - - indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=3) - - # Check output shapes - assert indices.shape == (3, 3) - assert top_scores.shape == (3, 3) - assert top_scores.dtype == dtype - - # Verify masking constraints - assert torch.all(indices[0] < 5) - assert torch.all(indices[1] < 5) - assert torch.all(indices[2] < 10) - - # Check that top scores are reasonable (not all -inf) - assert not torch.all(torch.isinf(top_scores.to(torch.float32))) - - -def test_masked_topk_consistency(): - """Test that masked_topk produces consistent results.""" - device = "cuda" - - seq_lens = torch.tensor([2, 1], device=device) - starting_pos = torch.tensor([5, 10], device=device) - - # Use deterministic scores for consistency testing - torch.manual_seed(42) - scores = torch.randn(3, 20, device=device) - - # Run multiple times - results = [] - for _ in range(3): - indices, top_scores = masked_topk(scores, seq_lens, starting_pos, k=3) - results.append((indices.clone(), top_scores.clone())) - - # Check that all runs produce identical results - for i in range(1, len(results)): - assert torch.equal(results[0][0], results[i][0]), "Indices should be consistent" - assert_close(results[0][1], results[i][1], rtol=1e-5, atol=1e-5), "Scores should be consistent" + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) \ No newline at end of file diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index cfd5c54c711e..6f4a695637d3 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -140,9 +140,9 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, @triton.jit def _pack_seq_kernel( - x_ptr, # *fp8, [N, D] - out_ptr, # *fp8, [B, Lmax, D] - starts_ptr, # *i32, [B] + x_ptr, # [N, D] + out_ptr, # [B, Lmax, D] + lengths_ptr, # *i32, [B] N: tl.constexpr, D: tl.constexpr, Lmax: tl.constexpr, PAD_VALUE: tl.constexpr, BLOCK_T: tl.constexpr, # timesteps per program @@ -154,15 +154,11 @@ def _pack_seq_kernel( off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] - # bounds - in_start = tl.load(starts_ptr + pid_b) - - # Calculate sequence length from starts - if pid_b < tl.num_programs(0) - 1: - next_start = tl.load(starts_ptr + pid_b + 1) - seq_len = next_start - in_start - else: - seq_len = N - in_start + # Compute start index and sequence length from cumulative lengths + in_start = 0 + for i in range(pid_b): + in_start += tl.load(lengths_ptr + i) + seq_len = tl.load(lengths_ptr + pid_b) # valid time positions for this block t_mask = off_t < Lmax @@ -178,42 +174,50 @@ def _pack_seq_kernel( # out_ptr: row-major [B, Lmax, D] out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] - # Initialize with PAD - # (write pad for all t in this block) + # Initialize with PAD (cast will occur as needed based on out_ptr dtype) d_mask = off_d[None, :] < D - tl.store(out_row_ptr, tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32), mask=t_mask[:, None] & d_mask) + pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32) + tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask) # Load & write only where within seq_len x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask) tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask) -def pack_seq_triton(x, starts, 
pad_value=-float('inf'), block_t=64, block_d=64): +def pack_seq_triton(x, lengths, pad_value=-float('inf'), block_t=64, block_d=64): + """ + Pack sequences of different lengths into a batched tensor. + + Args: + x: [N, ...] - input tensor where N is total number of tokens + lengths: [B] - sequence lengths for each batch + pad_value: value to use for padding + block_t: block size for time dimension + block_d: block size for feature dimension + + Returns: + packed: [B, Lmax, ...] - packed tensor + """ # Handle multi-dimensional input by reshaping to (N, -1) original_shape = x.shape if len(original_shape) > 2: N = original_shape[0] x_reshaped = x.reshape(N, -1) - D = x_reshaped.shape[1] # Get the actual feature dimension + D = x_reshaped.shape[1] else: N, D = x.shape x_reshaped = x - B = starts.numel() - # Calculate Lmax from starts without creating lengths tensor - if B == 1: - Lmax = N - starts[0].item() - else: - # Calculate max length from consecutive starts - lengths = starts[1:] - starts[:-1] - last_length = N - starts[-1].item() - Lmax = max(int(lengths.max().item()), int(last_length)) + B = lengths.numel() + Lmax = int(lengths.max().item()) + + # Starts are computed inside the kernel from lengths out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype) grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) _pack_seq_kernel[grid]( - x_reshaped, out, starts.int(), + x_reshaped, out, lengths.int(), N, D, Lmax, PAD_VALUE=float(pad_value), BLOCK_T=block_t, BLOCK_D=block_d, @@ -230,9 +234,8 @@ def pack_seq_triton(x, starts, pad_value=-float('inf'), block_t=64, block_d=64): @triton.jit def _unpack_seq_triton_kernel( - packed_ptr, # *fp8, [B, Lmax, D] - out_ptr, # *fp8, [N, D] - starts_ptr, # *i32, [B] + packed_ptr, # [B, Lmax, D] + out_ptr, # [N, D] lengths_ptr, # *i32, [B] B: tl.constexpr, Lmax: tl.constexpr, D: tl.constexpr, BLOCK_T: tl.constexpr, # timesteps per program @@ -244,8 +247,10 @@ def _unpack_seq_triton_kernel( off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] - # bounds - in_start = tl.load(starts_ptr + pid_b) + # bounds: compute start from cumulative lengths + in_start = 0 + for i in range(pid_b): + in_start += tl.load(lengths_ptr + i) seq_len = tl.load(lengths_ptr + pid_b) # valid time positions for this block @@ -268,15 +273,14 @@ def _unpack_seq_triton_kernel( tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask) -def unpack_seq_triton(packed_tensor, starts, lengths, block_t=64, block_d=64): +def unpack_seq_triton(packed_tensor, lengths, block_t=64, block_d=64): """ Unpack a packed decode query tensor back to the original format. Efficient Triton implementation. Args: packed_tensor: [B, Lmax, ...] 
- packed tensor from pack_seq_triton - starts: [B] - start locations for each batch - lengths: [B] - sequence lengths for each batch (needed to calculate total N) + lengths: [B] - sequence lengths for each batch block_t: block size for time dimension block_d: block size for feature dimension @@ -289,7 +293,7 @@ def unpack_seq_triton(packed_tensor, starts, lengths, block_t=64, block_d=64): if len(original_shape) > 3: B, Lmax = original_shape[:2] packed_reshaped = packed_tensor.reshape(B, Lmax, -1) - D = packed_reshaped.shape[2] # Get the actual feature dimension + D = packed_reshaped.shape[2] else: B, Lmax, D = packed_tensor.shape packed_reshaped = packed_tensor @@ -301,7 +305,7 @@ def unpack_seq_triton(packed_tensor, starts, lengths, block_t=64, block_d=64): grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) _unpack_seq_triton_kernel[grid]( - packed_reshaped, out, starts.int(), lengths.int(), + packed_reshaped, out, lengths.int(), B, Lmax, D, BLOCK_T=block_t, BLOCK_D=block_d, num_warps=4, num_stages=2 From 093b0c0435769e0160d7fcc80c2dc3e3d599cfd1 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 11:35:55 +0000 Subject: [PATCH 54/82] partial configs --- vllm/config/model.py | 4 +- vllm/model_executor/models/config.py | 5 +- vllm/model_executor/models/deepseek_mtp.py | 4 +- vllm/model_executor/models/deepseek_v2.py | 20 +- vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + .../transformers_utils/configs/deepseek_v3.py | 193 ++++++++++++++++++ 8 files changed, 214 insertions(+), 16 deletions(-) create mode 100644 vllm/transformers_utils/configs/deepseek_v3.py diff --git a/vllm/config/model.py b/vllm/config/model.py index 921322bb475c..baf577b49ef5 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1198,13 +1198,13 @@ def is_deepseek_mla(self) -> bool: if not hasattr(self.hf_text_config, "model_type"): return False elif self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'): + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2', 'deepseek_v32'): return self.hf_text_config.kv_lora_rank is not None elif self.hf_text_config.model_type == 'eagle': # if the model is an EAGLE module, check for the # underlying architecture return self.hf_text_config.model.model_type in \ - ('deepseek_v2', 'deepseek_v3') \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \ and self.hf_text_config.kv_lora_rank is not None return False diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 15d997a9303c..494403492c2a 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -423,8 +423,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ hf_config = vllm_config.model_config.hf_config - is_v32 = hasattr(hf_config, "attn_module_list_cfg") \ - and "attn_index" in hf_config.attn_module_list_cfg[0] + is_v32 = hasattr( + hf_config, "index_topk" + ) if is_v32: # For DeepSeekV3.2, we use a custom fp8 format as default (i.e. 
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 42e7fe5b6f59..9d299eac4214 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -56,8 +56,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: bias=False) self.is_v32 = hasattr( - config, "attn_module_list_cfg" - ) and "attn_index" in config.attn_module_list_cfg[0] + config, "index_topk" + ) if self.is_v32: topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 42882e2b85b3..10a77375f308 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -779,11 +779,11 @@ def __init__(self, super().__init__() self.vllm_config = vllm_config self.config = config - self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] - self.topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] - self.n_head = self.indexer_cfg["n_head"] # 64 - self.head_dim = self.indexer_cfg["head_dim"] # 128 - self.rope_dim = self.indexer_cfg["rope_dim"] # 64 + # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] + self.topk_tokens = config.index_topk + self.n_head = config.index_n_heads # 64 + self.head_dim = config.index_head_dim # 128 + self.rope_dim = config.qk_rope_head_dim # 64 self.q_lora_rank = q_lora_rank # 1536 # no tensor parallel, just replicated self.wq_b = ReplicatedLinear(self.q_lora_rank, @@ -962,8 +962,8 @@ def __init__( self.scaling = self.scaling * mscale * mscale self.is_v32 = hasattr( - config, "attn_module_list_cfg" - ) and "attn_index" in config.attn_module_list_cfg[0] + config, "index_topk" + ) if self.is_v32: self.indexer = Indexer(vllm_config, config, hidden_size, @@ -1141,10 +1141,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size = config.vocab_size self.is_v32 = hasattr( - config, "attn_module_list_cfg" - ) and "attn_index" in config.attn_module_list_cfg[0] + config, "index_topk" + ) if self.is_v32: - topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] + topk_tokens = config.index_topk topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, dtype=torch.int32, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5dc5d545bb9c..371222af9e62 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -69,6 +69,7 @@ "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), + "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"), "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index cafc43f6b767..6bbfaad5fa1a 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -64,6 +64,7 @@ def __getitem__(self, key): _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( chatglm="ChatGLMConfig", deepseek_vl_v2="DeepseekVLV2Config", + deepseek_v32="DeepseekV3Config", kimi_vl="KimiVLConfig", 
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 91bfeb8c55ee..efb249da2e87 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -32,10 +32,12 @@ Step3VisionEncoderConfig, Step3VLConfig) from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config __all__ = [ "ChatGLMConfig", "DeepseekVLV2Config", + "DeepseekV3Config", "EAGLEConfig", "RWConfig", "JAISConfig", diff --git a/vllm/transformers_utils/configs/deepseek_v3.py b/vllm/transformers_utils/configs/deepseek_v3.py new file mode 100644 index 000000000000..235b7b0fd33c --- /dev/null +++ b/vllm/transformers_utils/configs/deepseek_v3.py @@ -0,0 +1,193 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. 
+ first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
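+        index_n_heads / index_head_dim / index_topk (`int`, *optional*):
+            DeepSeek-V3.2 sparse-indexer sizes (number of indexer heads, their
+            head dimension, and how many tokens the indexer selects). They are
+            not declared in `__init__` below; when a checkpoint's `config.json`
+            contains them they reach the instance through `**kwargs` (standard
+            `PretrainedConfig` behaviour) and are read by the model code as
+            `config.index_topk`, `config.index_n_heads` and
+            `config.index_head_dim`.
+
+    A minimal sketch of that detection path (the field values below are
+    illustrative; the real numbers come from the checkpoint, not from this
+    class):
+    ```python
+    >>> cfg = DeepseekV3Config(index_topk=2048, index_n_heads=64, index_head_dim=128)
+    >>> hasattr(cfg, "index_topk")  # the check used to enable the sparse indexer
+    True
+    ```
+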
+ ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size = 2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts = 1, + n_routed_experts = 256, + ep_size = 1, + routed_scaling_factor = 2.5, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'noaux_tc', + n_group = 8, + topk_group = 4, + num_experts_per_tok = 8, + moe_layer_freq = 1, + first_k_dense_replace = 3, + norm_topk_prob = True, + scoring_func = 'sigmoid', + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file From 3e530a59cdd6adf4be66f9e9bdf6dd5deff5e7b3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 29 Sep 2025 04:42:28 -0700 Subject: [PATCH 55/82] fix blackwell Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 7d70f3081996..745807e5a0b2 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -317,7 +317,8 @@ def 
__init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest max_num_sm_parts = int( max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)) - + if current_platform.is_device_capability(100): + max_num_sm_parts *= 2 self.tile_scheduler_metadata_buffer = torch.empty( # TileSchedulerMetaDataSize = 8 # see: FlashMLA/csrc/params.h From 88ef733ff5125b1b29e7aaae13063b2a109d5e8f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 29 Sep 2025 21:21:06 +0800 Subject: [PATCH 56/82] update config Signed-off-by: youkaichao --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 745807e5a0b2..a118b29e9895 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -292,8 +292,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) - self.topk_tokens = vllm_config.model_config.hf_config\ - .attn_module_list_cfg[0]["topk_tokens"] + self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" self.topk_tokens_tensor = torch.tensor([self.topk_tokens], device=device, From 98e0a0ffcf0a1e279556d1cdf80062d8efe2e94e Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 17:51:03 +0000 Subject: [PATCH 57/82] small fix Signed-off-by: Yongye Zhu --- vllm/model_executor/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index c5557a89ead0..c2fe06f3977d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -80,6 +80,7 @@ fp8_paged_mqa_logits, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 +from vllm.platforms import current_platform if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops From 3683a69b0b77b3a0d206a8a8e26dbd279551dc01 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 29 Sep 2025 13:11:09 -0700 Subject: [PATCH 58/82] update to support 12.8 Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 946e1d86fbb4..c9e7aec880b9 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA - GIT_TAG 9140b54f8ca80a32b69972b46a68bfd0de4501b8 + GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -38,7 +38,7 @@ set(SUPPORT_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3) list(APPEND SUPPORT_ARCHS 9.0a) endif() -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.9) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8) list(APPEND SUPPORT_ARCHS 10.0a) endif() From b7de53e19193faa8ac111754a23b5197ac862261 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 20:46:46 +0000 Subject: [PATCH 59/82] format --- csrc/cache.h | 8 +- 
csrc/cache_kernels.cu | 68 ++++--- csrc/torch_bindings.cpp | 8 +- .../attention/test_deepgemm_attention.py | 142 ++++++------- .../kernels/attention/test_flashmla_sparse.py | 35 ++-- tests/kernels/attention/test_indexer.py | 170 ++++++++-------- .../attention/test_pack_unpack_triton.py | 99 +++++---- tests/v1/attention/test_mla_backends.py | 16 +- vllm/_custom_ops.py | 6 +- vllm/attention/ops/common.py | 128 +++++++----- vllm/config/cache.py | 4 +- vllm/model_executor/models/config.py | 7 +- vllm/model_executor/models/deepseek_mtp.py | 18 +- vllm/model_executor/models/deepseek_v2.py | 188 ++++++++++-------- vllm/transformers_utils/configs/__init__.py | 2 +- .../transformers_utils/configs/deepseek_v3.py | 136 +++---------- vllm/utils/deep_gemm.py | 48 ++--- vllm/v1/attention/backends/mla/common.py | 2 +- .../attention/backends/mla/flashmla_sparse.py | 20 +- vllm/v1/attention/backends/mla/indexer.py | 55 ++--- vllm/v1/core/single_type_kv_cache_manager.py | 3 +- vllm/v1/spec_decode/eagle.py | 36 ++-- 22 files changed, 580 insertions(+), 619 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 3a4fc92a6c25..427bd0d54fac 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -59,8 +59,8 @@ void cp_gather_cache( // Indexer K quantization and cache function void indexer_k_quant_and_cache( - torch::Tensor& k, // [num_tokens, head_dim] - torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] - torch::Tensor& slot_mapping, // [num_tokens] - int64_t quant_block_size, // quantization block size + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size const std::string& scale_fmt); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b9fb1b680806..b014fd27a8d6 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -502,28 +502,31 @@ __global__ void concat_and_cache_ds_mla_kernel( template __global__ void indexer_k_quant_and_cache_kernel( - const scalar_t* __restrict__ k, // [num_tokens, head_dim] - cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int head_dim, // dimension of each head - const int quant_block_size, // quantization block size - const int cache_block_size, // cache block size - const int cache_stride, // stride for each token in kv_cache - const bool use_ue8m0 // use ue8m0 scale format + const scalar_t* __restrict__ k, // [num_tokens, head_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int head_dim, // dimension of each head + const int quant_block_size, // quantization block size + const int cache_block_size, // cache block size + const int cache_stride, // stride for each token in kv_cache + const bool use_ue8m0 // use ue8m0 scale format ) { constexpr int VEC_SIZE = 4; const int64_t token_idx = blockIdx.x; - const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x) * + VEC_SIZE; const int64_t slot_idx = slot_mapping[token_idx]; const int64_t block_idx = slot_idx / cache_block_size; const int64_t block_offset = slot_idx % cache_block_size; - + // NOTE: slot_idx can be -1 if the token is padded if (slot_idx < 0 || (head_dim_idx >= 
head_dim)) { return; } - - float2 k_val = (reinterpret_cast(k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE]; + + float2 k_val = (reinterpret_cast( + k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE]; scalar_t* k_val_ptr = reinterpret_cast(&k_val); float amax = 0.0f; for (int i = 0; i < VEC_SIZE; i++) { @@ -541,12 +544,17 @@ __global__ void indexer_k_quant_and_cache_kernel( scale = exp2f(ceilf(log2f(scale))); } - const int64_t dst_offset = block_idx * cache_block_size * cache_stride + block_offset * head_dim + head_dim_idx; + const int64_t dst_offset = block_idx * cache_block_size * cache_stride + + block_offset * head_dim + head_dim_idx; for (int i = 0; i < VEC_SIZE; i++) { - kv_cache[dst_offset + i] = fp8::scaled_convert(k_val_ptr[i], scale); + kv_cache[dst_offset + i] = + fp8::scaled_convert(k_val_ptr[i], scale); } if (threadIdx.x == 0) { - const int64_t dst_scale_idx = block_idx * cache_block_size * cache_stride + cache_block_size * head_dim + (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size; + const int64_t dst_scale_idx = + block_idx * cache_block_size * cache_stride + + cache_block_size * head_dim + + (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size; reinterpret_cast(kv_cache)[dst_scale_idx / 4] = scale; } } @@ -1115,40 +1123,40 @@ void cp_gather_cache( } // Macro to dispatch the kernel based on the data type. -#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ - vllm::indexer_k_quant_and_cache_kernel \ - <<>>( \ - reinterpret_cast(k.data_ptr()), \ - reinterpret_cast(kv_cache.data_ptr()), \ - slot_mapping.data_ptr(), head_dim, quant_block_size, \ +#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::indexer_k_quant_and_cache_kernel \ + <<>>( \ + reinterpret_cast(k.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), head_dim, quant_block_size, \ cache_block_size, cache_stride, use_ue8m0); void indexer_k_quant_and_cache( - torch::Tensor& k, // [num_tokens, head_dim] - torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] - torch::Tensor& slot_mapping, // [num_tokens] - int64_t quant_block_size, // quantization block size + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size const std::string& scale_fmt) { - int num_tokens = k.size(0); int head_dim = k.size(1); int cache_block_size = kv_cache.size(1); int cache_stride = kv_cache.size(2); bool use_ue8m0 = scale_fmt == "ue8m0"; - + TORCH_CHECK(k.device() == kv_cache.device(), "k and kv_cache must be on the same device"); TORCH_CHECK(k.device() == slot_mapping.device(), "k and slot_mapping must be on the same device"); TORCH_CHECK(head_dim % quant_block_size == 0, "head_dim must be divisible by quant_block_size"); - + constexpr int vec_size = 4; - dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) / (quant_block_size * vec_size)); + dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) / + (quant_block_size * vec_size)); dim3 block(32, vec_size); const at::cuda::OptionalCUDAGuard device_guard(device_of(k)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - + DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", CALL_INDEXER_K_QUANT_AND_CACHE); } \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9e7fbeb80bb3..ebd28e735088 100644 --- a/csrc/torch_bindings.cpp +++ 
b/csrc/torch_bindings.cpp
@@ -715,9 +715,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
   cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
 
   cache_ops.def(
-      "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor slot_mapping, "
-      "int quant_block_size, str kv_cache_dtype) -> ()");
-  cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, &indexer_k_quant_and_cache);
+      "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
+      "slot_mapping, "
+      "int quant_block_size, str kv_cache_dtype) -> ()");
+  cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
+                 &indexer_k_quant_and_cache);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py
index 50c547b84be6..03cc6b930c94 100644
--- a/tests/kernels/attention/test_deepgemm_attention.py
+++ b/tests/kernels/attention/test_deepgemm_attention.py
@@ -1,17 +1,15 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import random
+
 import pytest
 import torch
 
 from vllm.platforms import current_platform
-from vllm.utils import has_deep_gemm, cdiv
-from vllm.utils.deep_gemm import (
-    _ceil_to_ue8m0,
-    fp8_mqa_logits,
-    calc_diff,
-    get_paged_mqa_logits_metadata,
-    fp8_paged_mqa_logits,
-    get_num_sms,
-)
+from vllm.utils import cdiv, has_deep_gemm
+from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
+                                  fp8_paged_mqa_logits, get_num_sms,
+                                  get_paged_mqa_logits_metadata)
 
 
 def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
@@ -26,18 +24,17 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
         device=x.device,
         dtype=torch.uint8,
     )
-    x_fp8[:, : block_size * head_dim] = x_scaled.view(
-        num_blocks, block_size * head_dim
-    ).view(dtype=torch.uint8)
-    x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view(
-        dtype=torch.uint8
-    )
+    x_fp8[:, :block_size * head_dim] = x_scaled.view(
+        num_blocks, block_size * head_dim).view(dtype=torch.uint8)
+    x_fp8[:,
+          block_size * head_dim:] = sf.view(num_blocks,
+                                            block_size).view(dtype=torch.uint8)
     return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
 
 
 def per_custom_dims_cast_to_fp8(
-    x: torch.Tensor, dims: tuple, use_ue8m0: bool
-) -> tuple[torch.Tensor, torch.Tensor]:
+        x: torch.Tensor, dims: tuple,
+        use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
     excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
     x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
     sf = x_amax / 448.0
@@ -72,17 +69,13 @@ def _ref_fp8_mqa_logits(
     q = q.float()
     k = k.float()
 
-    mask_lo = (
-        torch.arange(0, seq_len_kv, device="cuda")[None, :]
-        >= cu_seqlen_ks[:, None]
-    )
-    mask_hi = (
-        torch.arange(0, seq_len_kv, device="cuda")[None, :]
-        < cu_seqlen_ke[:, None]
-    )
+    mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
+               >= cu_seqlen_ks[:, None])
+    mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
+               < cu_seqlen_ke[:, None])
     mask = mask_lo & mask_hi
 
-    score = torch.einsum("mhd,nd->hmn", q, k)
+    score = torch.einsum("mhd,nd->hmn", q, k)
     logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
     logits = logits.masked_fill(~mask, float("-inf"))
 
@@ -95,8 +88,8 @@ def test_deepgemm_fp8_mqa_logits():
     torch.manual_seed(0)
     random.seed(0)
     num_heads, head_dim = 32, 128
-    for seq_len in (512,):
-        for seq_len_kv in (1024,):
+    for seq_len in 
(512, ): + for seq_len_kv in (1024, ): for disable_cp in (False, True): q = torch.randn( seq_len, @@ -105,23 +98,24 @@ def test_deepgemm_fp8_mqa_logits(): device="cuda", dtype=torch.bfloat16, ) - kv = torch.randn( - seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16 - ) - weights = torch.randn( - seq_len, num_heads, device="cuda", dtype=torch.float32 - ) + kv = torch.randn(seq_len_kv, + head_dim, + device="cuda", + dtype=torch.bfloat16) + weights = torch.randn(seq_len, + num_heads, + device="cuda", + dtype=torch.float32) if disable_cp: ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") - ke = torch.arange( - seq_len, dtype=torch.int, device="cuda" - ) + (seq_len_kv - seq_len) + ke = torch.arange(seq_len, dtype=torch.int, + device="cuda") + (seq_len_kv - seq_len) else: ks, ke = _generate_cp_test_data(seq_len, seq_len_kv) q_fp8 = q.to(torch.float8_e4m3fn) - kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False) logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) ref_logits = _ref_fp8_mqa_logits( @@ -161,14 +155,11 @@ def _ref_fp8_paged_mqa_logits( context_lens_list = context_lens.tolist() for i in range(batch_size): context_len = context_lens_list[i] - q_offsets = torch.arange( - context_len - next_n, context_len, device="cuda" - ) - weight_slice = ( - weights[i * next_n : (i + 1) * next_n, :] - .transpose(0, 1) - .contiguous() - ) + q_offsets = torch.arange(context_len - next_n, + context_len, + device="cuda") + weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose( + 0, 1).contiguous()) for block_rk in range(cdiv(context_len, block_size)): block_idx = block_tables[i][block_rk] qx, kx = q[i], kv_cache[block_idx] @@ -177,24 +168,21 @@ def _ref_fp8_paged_mqa_logits( (block_rk + 1) * block_size, device="cuda", ) - mask = (k_offsets[None, :] < context_len) & ( - k_offsets[None, :] <= q_offsets[:, None] - ) + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] + <= q_offsets[:, None]) s = torch.where( mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( - logits.dtype - ), + logits.dtype), float("-inf"), ) s = torch.relu(s) * weight_slice[..., None] s = s.sum(dim=0) logits[ - i * next_n : (i + 1) * next_n, - block_rk * block_size : (block_rk + 1) * block_size, - ] = torch.where( - k_offsets[None, :] <= q_offsets[:, None], s, float("-inf") - ) + i * next_n:(i + 1) * next_n, + block_rk * block_size:(block_rk + 1) * block_size, + ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, + float("-inf")) return logits @@ -207,7 +195,7 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len = 4096 for batch_size, next_n in [(4, 1), (2, 2)]: for heads, index_dim in [(32, 128)]: - for avg_kv in (2048,): + for avg_kv in (2048, ): num_blocks, blocksize = max_model_len * 2, 64 q = torch.randn( @@ -226,18 +214,12 @@ def test_deepgemm_fp8_paged_mqa_logits(): dtype=torch.float32, ) - context_lens = ( - torch.randint( - int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,) - ) - .cuda() - .to(torch.int32) - ) - max_block_len = ( - (context_lens.max().item() + blocksize - 1) - // blocksize - * blocksize - ) + context_lens = (torch.randint(int(0.8 * avg_kv), + int(1.2 * avg_kv), + (batch_size, )).cuda().to( + torch.int32)) + max_block_len = ((context_lens.max().item() + blocksize - 1) // + blocksize * blocksize) block_tables = torch.zeros( (batch_size, max_block_len), device="cuda", @@ -257,8 +239,7 @@ def test_deepgemm_fp8_paged_mqa_logits(): kv_cache_fp8 = 
kv_cache_cast_to_fp8(kv_cache) schedule_metadata = get_paged_mqa_logits_metadata( - context_lens, blocksize, get_num_sms() - ) + context_lens, blocksize, get_num_sms()) logits = fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, @@ -278,20 +259,15 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len, ) - positions = ( - torch.arange(max_model_len, device="cuda") - .unsqueeze(0) - .expand(batch_size * next_n, -1) - ) + positions = (torch.arange(max_model_len, + device="cuda").unsqueeze(0).expand( + batch_size * next_n, -1)) row_indices = ( - torch.arange(batch_size * next_n, device="cuda") // next_n - ) + torch.arange(batch_size * next_n, device="cuda") // next_n) next_n_offset = ( - torch.arange(batch_size * next_n, device="cuda") % next_n - ) - mask = positions <= ( - context_lens[row_indices] - next_n + next_n_offset - ).unsqueeze(1) + torch.arange(batch_size * next_n, device="cuda") % next_n) + mask = positions <= (context_lens[row_indices] - next_n + + next_n_offset).unsqueeze(1) logits = logits.masked_fill(~mask, 0) ref_logits = ref_logits.masked_fill(~mask, 0) diff --git a/tests/kernels/attention/test_flashmla_sparse.py b/tests/kernels/attention/test_flashmla_sparse.py index 62ff7f65a0a2..9036e4e7800b 100644 --- a/tests/kernels/attention/test_flashmla_sparse.py +++ b/tests/kernels/attention/test_flashmla_sparse.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch @@ -54,15 +56,14 @@ def test_sparse_flashmla_decode_smoke(): # Metadata q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k - q_heads_per_hk = num_heads_q // num_heads_k + # q_heads_per_hk = num_heads_q // num_heads_k cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) tile_md, num_splits = fm.get_mla_metadata(cache_seqlens, - - q_seq_per_hk, - num_heads_k, - num_heads_q=num_heads_q, - topk=topk, - is_fp8_kvcache=True) + q_seq_per_hk, + num_heads_k, + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True) # Inputs q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k), @@ -75,10 +76,18 @@ def test_sparse_flashmla_decode_smoke(): dtype=torch.int32, device=device) - block_table = torch.zeros((batch_size, 128), dtype=torch.int32, device=device) - out, lse = fm.flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens, - head_dim_v, tile_md, - num_splits, indices=indices, is_fp8_kvcache=True) + block_table = torch.zeros((batch_size, 128), + dtype=torch.int32, + device=device) + out, lse = fm.flash_mla_with_kvcache(q, + k_cache, + block_table, + cache_seqlens, + head_dim_v, + tile_md, + num_splits, + indices=indices, + is_fp8_kvcache=True) assert out.shape[0] == batch_size assert out.shape[-1] == head_dim_v assert lse.shape[0] == batch_size @@ -103,8 +112,8 @@ def test_sparse_flashmla_prefill_smoke(): kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device) indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device) - out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v) + out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, + d_v) assert out.shape == (s_q, h_q, d_v) assert max_logits.shape == (s_q, h_q) assert lse.shape == (s_q, h_q) - diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py index 5ed6c212e528..6abb02d92aa4 100644 --- a/tests/kernels/attention/test_indexer.py +++ b/tests/kernels/attention/test_indexer.py @@ -1,19 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random import torch +from vllm import _custom_ops as ops from vllm.utils import cdiv -from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches +from vllm.utils.deep_gemm import (calc_diff, fp8_mqa_logits, + fp8_paged_mqa_logits, get_num_sms, + get_paged_mqa_logits_metadata) from vllm.utils.tile_lang_kernels import act_quant, fp8_index -from vllm import _custom_ops as ops -from vllm.model_executor.models.deepseek_v2 import indexer_k_quant_and_cache -from vllm.utils.deep_gemm import ( - fp8_mqa_logits, - calc_diff, - get_paged_mqa_logits_metadata, - fp8_paged_mqa_logits, - get_num_sms, -) +from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches + def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: num_blocks, block_size, num_heads, head_dim = x.shape @@ -31,47 +29,58 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: block_size).view(dtype=torch.uint8) return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + def ref_compute_logits_fp8(q, kv, weights, mask, block_size): q_fp8, q_scale = act_quant(q, block_size, "ue8m0") k_fp8, k_scale = act_quant(kv, block_size, "ue8m0") - + weights = weights.unsqueeze(-1) * q_scale weights = weights * (128**(-0.5)) - index_score = fp8_index( - q_fp8.contiguous(), weights, - k_fp8.contiguous(), - k_scale.contiguous()) + index_score = fp8_index(q_fp8.contiguous(), weights, k_fp8.contiguous(), + k_scale.contiguous()) if mask is not None: index_score += mask return index_score + def ref_indexer(seq_len, q, kv, weights, block_size, topk): B = seq_len.shape[0] total_seqlen = torch.sum(seq_len) - varlen_logits = torch.full((total_seqlen, total_seqlen), float("-inf"), device="cuda") - + varlen_logits = torch.full((total_seqlen, total_seqlen), + float("-inf"), + device="cuda") + current_context_ptr = 0 for i in range(B): S = seq_len[i] q_s = q[i][:S].contiguous().unsqueeze(0) kv_s = kv[i][:S].contiguous().unsqueeze(0) weights_s = weights[i][:S].contiguous().unsqueeze(0) - mask = torch.full( - (S, S), float("-inf"), - device="cuda").triu_(1) + mask = torch.full((S, S), float("-inf"), device="cuda").triu_(1) logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask, block_size) logits = logits.squeeze(0) - - varlen_logits[current_context_ptr:current_context_ptr + S, current_context_ptr: current_context_ptr + S] = logits + + varlen_logits[current_context_ptr:current_context_ptr + S, + current_context_ptr:current_context_ptr + S] = logits current_context_ptr += S return varlen_logits -def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, topk, is_kv_batched=True): + +def deepgemm_mqa_indexer( + seq_len, + query_seq_len, + q, + kv, + weights, + block_size, + topk, + is_kv_batched=True, +): B = seq_len.shape[0] concat_q = [] concat_kv = [] concat_weights = [] - + for i in range(B): S = seq_len[i] q_s = q[i][:S].contiguous() @@ -82,29 +91,25 @@ def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, top if is_kv_batched: concat_kv.append(kv_s) concat_weights.append(weight_s) - + q = torch.cat(concat_q, dim=0) if is_kv_batched: kv = torch.cat(concat_kv, dim=0) weights = torch.cat(concat_weights, dim=0) q_fp8, q_scale = act_quant(q, block_size, "ue8m0") kv_fp8, kv_scale = act_quant(kv, block_size, "ue8m0") - + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale weights = weights.squeeze(-1) query_start_loc = torch.empty((B + 1), device="cuda") query_start_loc[0] = 0 query_start_loc[1:] = 
query_seq_len.cumsum(dim=0).to(dtype=torch.int32) - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_len) + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, + seq_len) - logits = fp8_mqa_logits( - q_fp8, - (kv_fp8, kv_scale), - weights, - cu_seqlen_ks, - cu_seqlen_ke - ) + logits = fp8_mqa_logits(q_fp8, (kv_fp8, kv_scale), weights, cu_seqlen_ks, + cu_seqlen_ke) topk_indices = logits.topk(topk, dim=-1)[1] mask_lo = topk_indices >= cu_seqlen_ks[:, None] mask_hi = topk_indices < cu_seqlen_ke[:, None] @@ -112,26 +117,27 @@ def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, top topk_indices = topk_indices.masked_fill(~mask, -1) return logits + def test_prefill_indexer(): B = 3 S = 128 SKV = S H = 64 - HKV = 1 + # HKV = 1 D = 128 block_size = 128 topk = 64 device = "cuda" - seq_len = torch.randint(low=64, high=S, size=(B,)) - - q = torch.randn(B, S, H, D, device="cuda", - dtype=torch.bfloat16) - kv = torch.randn(B, SKV, D, device="cuda", - dtype=torch.bfloat16) - weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5 + seq_len = torch.randint(low=64, high=S, size=(B, )) + + q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(B, SKV, D, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(B, S, H, device=device, + dtype=torch.float32) * H**-0.5 ref_logits = ref_indexer(seq_len, q, kv, weights, block_size, topk) - deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, block_size, topk) + deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, + block_size, topk) torch.testing.assert_close(ref_logits, deepgemm_logits) @@ -139,33 +145,30 @@ def test_decode_paged_indexer(): num_blocks, blocksize = 111 * 3000, 64 B = 3 S = 128 - SKV = S + # SKV = S H = 64 - HKV = 1 + # HKV = 1 D = 128 block_size = 128 topk = 64 device = "cuda" - seq_len = torch.randint(low=64, high=S, size=(B,), device="cuda") + seq_len = torch.randint(low=64, high=S, size=(B, ), device="cuda") query_seq_len = torch.ones(B, device="cuda") - q = torch.randn((B, 1, H, D), - device='cuda', - dtype=torch.bfloat16) + q = torch.randn((B, 1, H, D), device='cuda', dtype=torch.bfloat16) kv_cache = torch.randn((num_blocks, blocksize, 1, D), - device='cuda', - dtype=torch.bfloat16) - weights = torch.randn((B * 1, H), - device='cuda', - dtype=torch.float32) * H**-0.5 + device='cuda', + dtype=torch.bfloat16) + weights = torch.randn( + (B * 1, H), device='cuda', dtype=torch.float32) * H**-0.5 max_block_len = (seq_len.max().item() + blocksize - - 1) // blocksize * blocksize - + 1) // blocksize * blocksize + block_tables = torch.zeros((B, max_block_len), - device='cuda', - dtype=torch.int32) - + device='cuda', + dtype=torch.int32) + counter = 0 block_idx_pool = list(range(num_blocks)) random.shuffle(block_idx_pool) @@ -174,51 +177,58 @@ def test_decode_paged_indexer(): for j in range(cdiv(ctx_len, blocksize)): block_tables[i][j] = block_idx_pool[counter] counter += 1 - - flatten_kv = torch.empty( - [seq_len.sum(), D], device="cuda", dtype=torch.bfloat16 - ) + + flatten_kv = torch.empty([seq_len.sum(), D], + device="cuda", + dtype=torch.bfloat16) cu_seq_lens = torch.cat([ - torch.zeros(1, dtype=torch.int32, device=device), - seq_len.cumsum(dim=0) - ]).to(torch.int32).cuda() + torch.zeros(1, dtype=torch.int32, device=device), + seq_len.cumsum(dim=0) + ]).to(torch.int32).cuda() ops.cp_gather_cache( - kv_cache, + kv_cache, flatten_kv, block_tables, cu_seq_lens, B, ) - - ref_logits = 
deepgemm_mqa_indexer(seq_len, query_seq_len, q, flatten_kv, weights, block_size, topk, is_kv_batched=False) + + ref_logits = deepgemm_mqa_indexer(seq_len, + query_seq_len, + q, + flatten_kv, + weights, + block_size, + topk, + is_kv_batched=False) q_fp8, q_scale = act_quant(q, block_size, "ue8m0") kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - schedule_metadata = get_paged_mqa_logits_metadata( - seq_len.int(), blocksize, get_num_sms()) - + schedule_metadata = get_paged_mqa_logits_metadata(seq_len.int(), blocksize, + get_num_sms()) + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale.squeeze(1) weights = weights.squeeze(-1) - - logits = fp8_paged_mqa_logits( - q_fp8, kv_cache_fp8, weights, seq_len.int(), block_tables, - schedule_metadata, 4096) - + + logits = fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, seq_len.int(), + block_tables, schedule_metadata, 4096) + concat_logit = [] context = 0 for i in range(B): per_seq_logits = torch.zeros(4096, device="cuda") S = seq_len[i] - per_seq_logits[:S] = ref_logits[i][context: context + S] + per_seq_logits[:S] = ref_logits[i][context:context + S] concat_logit.append(per_seq_logits) context += S ref_logits = torch.stack(concat_logit, dim=0) logits[logits == float("-inf")] = 0 diff = calc_diff(logits, ref_logits) assert diff < 1e-3, f"{diff=}" - + + if __name__ == "__main__": test_prefill_indexer() - test_decode_paged_indexer() \ No newline at end of file + test_decode_paged_indexer() diff --git a/tests/kernels/attention/test_pack_unpack_triton.py b/tests/kernels/attention/test_pack_unpack_triton.py index 59a9b64eebff..20c0b262b479 100644 --- a/tests/kernels/attention/test_pack_unpack_triton.py +++ b/tests/kernels/attention/test_pack_unpack_triton.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest import torch from torch.testing import assert_close @@ -12,37 +11,37 @@ def test_pack_seq_basic_fp8(): """Test basic functionality of pack_seq_triton with fp8 and 3D tensors.""" device = "cuda" dtype = torch.float8_e4m3fn - + # Test cases with 3D tensors (N, H, D) test_cases = [ - (6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4) + (6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4) (10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8) (20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32) ] - + for N, H, D, B, lengths_list in test_cases: # Create input tensor with small values for fp8 x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) lengths = torch.tensor(lengths_list, device=device) - + # Pack the data packed = pack_seq_triton(x, lengths) - + # Check output shape and properties expected_shape = (B, max(lengths_list), H, D) assert packed.shape == expected_shape assert packed.dtype == dtype assert packed.device == x.device - + # Check that valid data is preserved (within fp8 precision) for b in range(B): start_idx = sum(lengths_list[:b]) seq_len = lengths_list[b] - + expected_data = x[start_idx:start_idx + seq_len].to(torch.float32) actual_data = packed[b, :seq_len].to(torch.float32) - + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) @@ -52,66 +51,70 @@ def test_pack_seq_custom_padding_fp8(): dtype = torch.float8_e4m3fn N, H, D, B = 20, 8, 16, 2 lengths = torch.tensor([10, 10], device=device) - + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) - + # Test with different padding values for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]: result = 
pack_seq_triton(x, lengths, pad_value=pad_value) - + # Check valid data for b in range(B): start_idx = b * 10 expected_data = x[start_idx:start_idx + 10].to(torch.float32) actual_data = result[b, :10].to(torch.float32) assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) - + # Check padding (fp8 has limited range, so check for large values) padded_data = result[:, 10:].to(torch.float32) if pad_value < 0: assert torch.all(padded_data < -50) # Large negative values elif pad_value > 0: - assert torch.all(padded_data > 50) # Large positive values + assert torch.all(padded_data > 50) # Large positive values else: - assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2) + assert torch.allclose(padded_data, + torch.zeros_like(padded_data), + atol=1e-2) def test_pack_seq_default_negative_inf_padding_fp8(): """Test that pack_seq_triton uses -inf padding by default for fp8.""" device = "cuda" dtype = torch.float8_e4m3fn - N, H, D, B = 20, 8, 16, 2 + # B = 2 + N, H, D = 20, 8, 16 lengths = torch.tensor([10, 10], device=device) - + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) result = pack_seq_triton(x, lengths) - + # Check that padding is large negative values (fp8 representation of -inf) padded_data = result[:, 10:].to(torch.float32) - assert torch.all(padded_data < -100) # fp8 -inf is represented as large negative number + assert torch.all( + padded_data < -100) # fp8 -inf is represented as large negative number def test_pack_seq_edge_cases_fp8(): """Test pack_seq_triton with edge cases for fp8.""" device = "cuda" dtype = torch.float8_e4m3fn - + # Test with single batch element x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) lengths = torch.tensor([10], device=device) result = pack_seq_triton(x, lengths) assert result.shape == (1, 10, 8, 16) - + # Test with very short sequences x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) lengths = torch.tensor([1, 1, 1], device=device) result = pack_seq_triton(x, lengths) assert result.shape == (3, 1, 4, 8) - + # Test with different sequence lengths x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) @@ -126,16 +129,16 @@ def test_pack_seq_different_block_sizes_fp8(): dtype = torch.float8_e4m3fn N, H, D, B = 100, 16, 32, 4 lengths = torch.tensor([25, 25, 25, 25], device=device) - + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) - + # Test different block sizes for block_t, block_d in [(32, 32), (64, 64), (128, 128)]: result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d) - + assert result.shape == (B, 25, H, D) - + # Check that valid data is preserved (within fp8 precision) for b in range(B): start_idx = b * 25 @@ -150,12 +153,12 @@ def test_pack_seq_shape_consistency(): dtype = torch.float8_e4m3fn N, H, D, B = 20, 8, 16, 2 lengths = torch.tensor([10, 10], device=device) - + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) - + result = pack_seq_triton(x, lengths) - + # Check shape consistency assert result.shape[0] == B # Batch dimension assert result.shape[1] == lengths.max().item() # Max sequence length @@ -166,7 +169,7 @@ def test_pack_unpack_roundtrip_fp8(): """Test that pack -> unpack gives us back the original data for fp8.""" device = "cuda" dtype = torch.float8_e4m3fn - + # Test cases with 3D tensors test_cases = [ (6, 8, 4, 2, [3, 3]), @@ -174,35 +177,38 @@ 
def test_pack_unpack_roundtrip_fp8(): (20, 16, 32, 4, [5, 5, 5, 5]), (15, 8, 16, 3, [7, 5, 3]), ] - + for N, H, D, B, lengths_list in test_cases: # Create input tensor with small values for fp8 x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) lengths = torch.tensor(lengths_list, device=device) - + # Pack the data packed = pack_seq_triton(x, lengths) - + # Unpack the data unpacked = unpack_seq_triton(packed, lengths) - + # Check that we get back the original data (within fp8 precision) assert unpacked.shape == x.shape x_f32 = x.to(torch.float32) unpacked_f32 = unpacked.to(torch.float32) assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3) - + # Unpack without explicit start locations (computed in kernel) unpacked_with_loc = unpack_seq_triton(packed, lengths) - assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-3, atol=1e-2) + assert_close(x_f32, + unpacked_with_loc.to(torch.float32), + rtol=1e-3, + atol=1e-2) def test_unpack_seq_triton_edge_cases_fp8(): """Test unpack function with edge cases for fp8.""" device = "cuda" dtype = torch.float8_e4m3fn - + # Test with single batch element x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) @@ -210,8 +216,11 @@ def test_unpack_seq_triton_edge_cases_fp8(): packed = pack_seq_triton(x, lengths) unpacked = unpack_seq_triton(packed, lengths) assert unpacked.shape == x.shape - assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) - + assert_close(x.to(torch.float32), + unpacked.to(torch.float32), + rtol=1e-1, + atol=1e-2) + # Test with very short sequences x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) @@ -219,12 +228,18 @@ def test_unpack_seq_triton_edge_cases_fp8(): packed = pack_seq_triton(x, lengths) unpacked = unpack_seq_triton(packed, lengths) # Only compare the first 3 elements that were actually packed - assert_close(x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) - + assert_close(x[:3].to(torch.float32), + unpacked.to(torch.float32), + rtol=1e-1, + atol=1e-2) + x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) lengths = torch.tensor([5, 7, 3], device=device) packed = pack_seq_triton(x, lengths) unpacked = unpack_seq_triton(packed, lengths) assert unpacked.shape == x.shape - assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) \ No newline at end of file + assert_close(x.to(torch.float32), + unpacked.to(torch.float32), + rtol=1e-1, + atol=1e-2) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 392c0ab3eeca..0154237bc04a 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -5,13 +5,12 @@ import pytest import torch -from vllm import _custom_ops as ops - from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) +from vllm import _custom_ops as ops from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -128,9 +127,9 @@ def create_and_prepopulate_kv_cache( entry_size, dtype=torch.uint8, device=device) - scale_tensor = (scale if isinstance(scale, torch.Tensor) else - torch.tensor(scale, dtype=torch.float32, - device=device)) + scale_tensor = (scale + if 
isinstance(scale, torch.Tensor) else torch.tensor( + scale, dtype=torch.float32, device=device)) scale_tensor = scale_tensor.to(device=device, dtype=torch.float32) else: # Create MLA KV cache: (num_blocks, block_size, head_size) @@ -154,8 +153,7 @@ def create_and_prepopulate_kv_cache( start = start_block_idx * block_size if use_fp8_ds_mla: - slots = torch.arange(context_len, - device=device, + slots = torch.arange(context_len, device=device, dtype=torch.long) + start ops.concat_and_cache_mla( kv_c_context, @@ -166,8 +164,8 @@ def create_and_prepopulate_kv_cache( scale=scale_tensor, ) else: - kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], - dim=-1) + kv_context = torch.cat( + [kv_c_context, k_pe_context.squeeze(1)], dim=-1) end = start + kv_context.shape[0] kv_cache_flat[start:end, ...] = kv_context diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e5b37bf9acf9..f07fa1e4e7be 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1678,13 +1678,13 @@ def cp_gather_cache(src_cache: torch.Tensor, cu_seq_lens, batch_size, seq_starts) -def indexer_k_quant_and_cache(k: torch.Tensor, - kv_cache: torch.Tensor, +def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, quant_block_size: int, kv_cache_dtype: str) -> None: torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, - quant_block_size, kv_cache_dtype) + quant_block_size, + kv_cache_dtype) def get_device_attribute(attribute: int, device: int) -> int: diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 2136ed9d2593..eb6d11c141c7 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -139,21 +139,24 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, out = cp_group.reduce_scatter(out, dim=1) return out + @triton.jit def _pack_seq_kernel( - x_ptr, # [N, D] - out_ptr, # [B, Lmax, D] - lengths_ptr, # *i32, [B] - N: tl.constexpr, D: tl.constexpr, Lmax: tl.constexpr, - PAD_VALUE: tl.constexpr, - BLOCK_T: tl.constexpr, # timesteps per program - BLOCK_D: tl.constexpr # features per program + x_ptr, # [N, D] + out_ptr, # [B, Lmax, D] + lengths_ptr, # *i32, [B] + N: tl.constexpr, + D: tl.constexpr, + Lmax: tl.constexpr, + PAD_VALUE: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr # features per program ): - pid_b = tl.program_id(0) # batch id - pid_t = tl.program_id(1) # block over time dimension - pid_d = tl.program_id(2) # block over feature dimension - off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] - off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] + pid_b = tl.program_id(0) # batch id + pid_t = tl.program_id(1) # block over time dimension + pid_d = tl.program_id(2) # block over feature dimension + off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] + off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] # Compute start index and sequence length from cumulative lengths in_start = 0 @@ -173,7 +176,8 @@ def _pack_seq_kernel( x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :] # out_ptr: row-major [B, Lmax, D] - out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] + out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, + None] * D + off_d[None, :] # Initialize with PAD (cast will occur as needed based on out_ptr dtype) d_mask = off_d[None, :] < D @@ -184,7 +188,12 @@ def _pack_seq_kernel( x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask) tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & 
d_mask) -def pack_seq_triton(x, lengths, pad_value=-float('inf'), block_t=64, block_d=64): + +def pack_seq_triton(x, + lengths, + pad_value=-float('inf'), + block_t=64, + block_d=64): """ Pack sequences of different lengths into a batched tensor. @@ -198,7 +207,7 @@ def pack_seq_triton(x, lengths, pad_value=-float('inf'), block_t=64, block_d=64) Returns: packed: [B, Lmax, ...] - packed tensor """ - + # Handle multi-dimensional input by reshaping to (N, -1) original_shape = x.shape if len(original_shape) > 2: @@ -208,45 +217,51 @@ def pack_seq_triton(x, lengths, pad_value=-float('inf'), block_t=64, block_d=64) else: N, D = x.shape x_reshaped = x - + B = lengths.numel() Lmax = int(lengths.max().item()) - + # Starts are computed inside the kernel from lengths out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype) grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) - _pack_seq_kernel[grid]( - x_reshaped, out, lengths.int(), - N, D, Lmax, - PAD_VALUE=float(pad_value), - BLOCK_T=block_t, BLOCK_D=block_d, - num_warps=4, num_stages=2 - ) - + _pack_seq_kernel[grid](x_reshaped, + out, + lengths.int(), + N, + D, + Lmax, + PAD_VALUE=float(pad_value), + BLOCK_T=block_t, + BLOCK_D=block_d, + num_warps=4, + num_stages=2) + # Reshape output back to original dimensions (except first dimension) if len(original_shape) > 2: output_shape = (B, Lmax) + original_shape[1:] out = out.reshape(output_shape) - + return out @triton.jit def _unpack_seq_triton_kernel( - packed_ptr, # [B, Lmax, D] - out_ptr, # [N, D] - lengths_ptr, # *i32, [B] - B: tl.constexpr, Lmax: tl.constexpr, D: tl.constexpr, - BLOCK_T: tl.constexpr, # timesteps per program - BLOCK_D: tl.constexpr # features per program + packed_ptr, # [B, Lmax, D] + out_ptr, # [N, D] + lengths_ptr, # *i32, [B] + B: tl.constexpr, + Lmax: tl.constexpr, + D: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr # features per program ): - pid_b = tl.program_id(0) # batch id - pid_t = tl.program_id(1) # block over time dimension - pid_d = tl.program_id(2) # block over feature dimension - off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] - off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] + pid_b = tl.program_id(0) # batch id + pid_t = tl.program_id(1) # block over time dimension + pid_d = tl.program_id(2) # block over feature dimension + off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] + off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] # bounds: compute start from cumulative lengths in_start = 0 @@ -263,7 +278,8 @@ def _unpack_seq_triton_kernel( # Pointers # packed_ptr: row-major [B, Lmax, D] - packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] + packed_row_ptr = packed_ptr + (pid_b * Lmax + + off_t)[:, None] * D + off_d[None, :] # out_ptr: row-major [N, D] out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :] @@ -288,7 +304,7 @@ def unpack_seq_triton(packed_tensor, lengths, block_t=64, block_d=64): Returns: unpacked_tensor: [N, ...] 
where N = sum(lengths) """ - + # Handle multi-dimensional input by reshaping to (B, Lmax, -1) original_shape = packed_tensor.shape if len(original_shape) > 3: @@ -298,23 +314,29 @@ def unpack_seq_triton(packed_tensor, lengths, block_t=64, block_d=64): else: B, Lmax, D = packed_tensor.shape packed_reshaped = packed_tensor - + # Calculate total number of elements N = int(lengths.sum().item()) - - out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype) - + + out = torch.empty((N, D), + device=packed_tensor.device, + dtype=packed_tensor.dtype) + grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) - _unpack_seq_triton_kernel[grid]( - packed_reshaped, out, lengths.int(), - B, Lmax, D, - BLOCK_T=block_t, BLOCK_D=block_d, - num_warps=4, num_stages=2 - ) - + _unpack_seq_triton_kernel[grid](packed_reshaped, + out, + lengths.int(), + B, + Lmax, + D, + BLOCK_T=block_t, + BLOCK_D=block_d, + num_warps=4, + num_stages=2) + # Reshape output back to original dimensions (except first dimension) if len(original_shape) > 3: - output_shape = (N,) + original_shape[2:] + output_shape = (N, ) + original_shape[2:] out = out.reshape(output_shape) - - return out \ No newline at end of file + + return out diff --git a/vllm/config/cache.py b/vllm/config/cache.py index bf13a18e0e0c..58770649a8af 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -22,8 +22,8 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] -CacheDType = Literal[ - "auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", + "fp8_inc"] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 7bbb9cf6c34c..d381268f78c4 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -410,9 +410,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ hf_config = vllm_config.model_config.hf_config - is_v32 = hasattr( - hf_config, "index_topk" - ) + is_v32 = hasattr(hf_config, "index_topk") if is_v32: # For DeepSeekV3.2, we use a custom fp8 format as default (i.e. 
@@ -421,7 +419,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cache_config.cache_dtype == "auto" or \ cache_config.cache_dtype.startswith("fp8"): cache_config.cache_dtype = "fp8_ds_mla" - logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") + logger.info( + "Using custom fp8 kv-cache format for DeepSeekV3.2") if cache_config.cache_dtype == "bfloat16": cache_config.cache_dtype = "auto" logger.info("Using bfloat16 kv-cache for DeepSeekV3.2") diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 1a3d8ea6efb7..7187915b2db9 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -53,20 +53,20 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) - - self.is_v32 = hasattr( - config, "index_topk" - ) + + self.is_v32 = hasattr(config, "index_topk") if self.is_v32: topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] - topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, - topk_tokens, - dtype=torch.int32, - device="cuda") + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda") else: topk_indices_buffer = None self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, topk_indices_buffer) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer) def forward( self, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index c2fe06f3977d..55ed8462050c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -31,20 +31,19 @@ import torch from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config -import torch.distributed as dist -from vllm.attention.backends.abstract import AttentionBackend -from vllm.logger import init_logger -from vllm.config.compilation import CompilationConfig -import vllm.envs as envs from vllm.attention import Attention +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config +from vllm.config import (CacheConfig, ParallelConfig, VllmConfig, + get_current_vllm_config) from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) from vllm.forward_context import get_forward_context +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import FusedMoE @@ -56,6 +55,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.shared_fused_moe import 
SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -63,24 +64,18 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import cdiv, direct_register_custom_op -from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata -from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits +from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend, + DeepseekV32IndexerMetadata) +from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, +from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.v1.kv_cache_interface import MLAAttentionSpec, KVCacheSpec -from vllm.utils.deep_gemm import ( - fp8_mqa_logits, - get_paged_mqa_logits_metadata, - fp8_paged_mqa_logits, -) -from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 -from vllm.platforms import current_platform if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops @@ -461,7 +456,7 @@ def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str, compilation_config.static_forward_context[prefix] = self def get_kv_cache_spec(self) -> KVCacheSpec: - return MLAAttentionSpec( # Only has one vector instead of K + V + return MLAAttentionSpec( # Only has one vector instead of K + V block_size=self.cache_config.block_size, num_kv_heads=1, head_size=self.head_dim, @@ -469,21 +464,19 @@ def get_kv_cache_spec(self) -> KVCacheSpec: ) def forward(self): - attn_metadata = get_forward_context().attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.prefix] - logger.info(f"attn_metadata {attn_metadata}") + ... 
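+        # no-op: this module only exists to register the indexer KV cache
+        # (see get_kv_cache_spec); the actual top-k token selection runs
+        # inside torch.ops.vllm.sparse_attn_indexer.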
def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend + @torch.inference_mode() def cp_gather_indexer_k_quant_cache( - kv_cache, # [num_blocks, block_size, head_dim + 1] - dst_value, # [cu_seq_lens[-1], head_dim] - dst_scale, # [cu_seq_lens[-1], 4] - block_table, # [batch_size, num_blocks] - cu_seq_lens, # [batch_size + 1, ] + kv_cache, # [num_blocks, block_size, head_dim + 1] + dst_value, # [cu_seq_lens[-1], head_dim] + dst_scale, # [cu_seq_lens[-1], 4] + block_table, # [batch_size, num_blocks] + cu_seq_lens, # [batch_size + 1, ] batch_size, ): num_blocks, block_size, _ = kv_cache.shape @@ -501,21 +494,27 @@ def cp_gather_indexer_k_quant_cache( value = [] scale = [] - full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32) - # print(f"full_blocks: {blocks[full_block]}") - non_remaining_value = kv_cache[blocks[full_block], : block_size * head_dim].view(-1, head_dim) - non_remaining_scale = kv_cache[blocks[full_block], block_size * head_dim:].view(-1, 4) - - # for i in range(tot - 1): - # value.append(kv_cache[blocks[i], :block_size * head_dim]) - # scale.append(kv_cache[blocks[i], block_size * head_dim:]) + full_block = torch.arange(tot - 1, + device=kv_cache.device, + dtype=torch.int32) + non_remaining_value = kv_cache[blocks[full_block], :block_size * + head_dim].view(-1, head_dim) + non_remaining_scale = kv_cache[blocks[full_block], + block_size * head_dim:].view(-1, 4) remaining = s - (tot - 1) * block_size - # value.append(kv_cache[blocks[-1], :remaining * head_dim]) - # scale.append(kv_cache[blocks[-1], block_size * head_dim: block_size * head_dim + remaining * 4]) - value = torch.cat([non_remaining_value, kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)], dim=0) - scale = torch.cat([non_remaining_scale, kv_cache[blocks[-1], block_size * head_dim: block_size * head_dim + remaining * 4].view(-1, 4)], dim=0) + value = torch.cat([ + non_remaining_value, + kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim) + ], + dim=0) + scale = torch.cat([ + non_remaining_scale, + kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim + + remaining * 4].view(-1, 4) + ], + dim=0) expected_value.append(value) expected_scale.append(scale) @@ -582,14 +581,12 @@ def sparse_attn_indexer( if has_prefill: prefill_metadata = attn_metadata.prefill num_prefills = attn_metadata.num_prefills - k_fp8 = torch.empty( - [prefill_metadata.total_seq_lens, head_dim], - device=k.device, - dtype=torch.float8_e4m3fn) - k_scale = torch.empty( - [prefill_metadata.total_seq_lens, 1], - device=k.device, - dtype=torch.float32) + k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim], + device=k.device, + dtype=torch.float8_e4m3fn) + k_scale = torch.empty([prefill_metadata.total_seq_lens, 1], + device=k.device, + dtype=torch.float32) cp_gather_indexer_k_quant_cache( kv_cache, k_fp8, @@ -625,16 +622,19 @@ def sparse_attn_indexer( if has_decode: decode_metadata = attn_metadata.decode # kv_cache size requirement [num_block, block_size, n_head, head_dim], - # we only have [num_block, block_size, head_dim], + # we only have [num_block, block_size, head_dim], kv_cache = kv_cache.unsqueeze(-2) decode_lens = decode_metadata.decode_lens if decode_metadata.requires_padding: - # pad in edge case where we have short chunked prefill length < - # decode_threshold since we unstrictly split - # prefill and decode by decode_threshold (currently set to 1 + speculative tokens) - padded_q_fp8_decode_tokens = pack_seq_triton(q_fp8[:num_decode_tokens], 
decode_lens) + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens) else: - padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(decode_lens.shape[0], -1, *q_fp8.shape[1:]) + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:]) # TODO: move and optimize below logic with triton kernels batch_size = padded_q_fp8_decode_tokens.shape[0] assert batch_size == decode_metadata.seq_lens.shape[0] @@ -652,27 +652,35 @@ def sparse_attn_indexer( # padded query len current_device = padded_q_fp8_decode_tokens.device padded_num_tokens = batch_size * next_n - positions = torch.arange(max_model_len, device=current_device).unsqueeze(0).expand( - batch_size * next_n, -1) - row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n - next_n_offset = torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device) % next_n - index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1) + positions = torch.arange(max_model_len, + device=current_device).unsqueeze(0).expand( + batch_size * next_n, -1) + row_indices = torch.arange(padded_num_tokens, + device=current_device) // next_n + next_n_offset = torch.arange( + padded_num_tokens, + device=padded_q_fp8_decode_tokens.device) % next_n + index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + + next_n_offset).unsqueeze(1) # index_end_pos: [B * N, 1] mask = positions <= index_end_pos # mask: [B * N, L] logits = logits.masked_fill(~mask, float('-inf')) - topk_indices = logits.topk(topk_tokens, dim=-1)[1].to( - torch.int32) # [B * N, K] - # ensure we don't set indices for the top k that out of range(masked already) + topk_indices = logits.topk(topk_tokens, + dim=-1)[1].to(torch.int32) # [B * N, K] + # ensure we don't set indices for the top k + # that is out of range(masked already) # this will happen if context length is shorter than K topk_indices[topk_indices > index_end_pos] = -1 if decode_metadata.requires_padding: - # if padded, we need to unpack the topk indices removing padded tokens - topk_indices = unpack_seq_triton(topk_indices.reshape(batch_size, -1, logits.shape[-1]), decode_lens) + # if padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, logits.shape[-1]), + decode_lens) topk_indices_buffer[:num_decode_tokens, :topk_indices. 
- shape[-1]] = topk_indices.to( - dtype=torch.int32) - + shape[-1]] = topk_indices.to(dtype=torch.int32) + return topk_indices_buffer @@ -761,7 +769,8 @@ def __init__(self, cache_config=cache_config) self.max_model_len = vllm_config.model_config.max_model_len self.prefix = prefix - from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size + from vllm.v1.attention.backends.mla.indexer import ( + get_max_prefill_buffer_size) self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, @@ -776,7 +785,6 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, k_pe, k_nope = torch.split( k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) - #FIXME (zyongye) this will cause OOM when using full sequence forward on 8xH200 q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) q = torch.cat([q_pe, q_nope], dim=-1) k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) @@ -784,9 +792,10 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim) q_fp8, q_scale = per_token_group_quant_fp8(q, - self.quant_block_size, - column_major_scales=False, - use_ue8m0=self.scale_fmt is not None) + self.quant_block_size, + column_major_scales=False, + use_ue8m0=self.scale_fmt + is not None) q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) q_scale = q_scale.view(-1, self.n_head, 1) @@ -796,10 +805,20 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, weights = weights.squeeze(-1) return torch.ops.vllm.sparse_attn_indexer( - hidden_states, self.k_cache.prefix, self.k_cache.kv_cache[0], - q_fp8, k, weights, self.quant_block_size, self.scale_fmt, - self.topk_tokens, self.head_dim, self.max_model_len, - self.max_total_seq_len, self.topk_indices_buffer) + hidden_states, + self.k_cache.prefix, + self.k_cache.kv_cache[0], + q_fp8, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) class DeepseekV2MLAAttention(nn.Module): @@ -909,9 +928,7 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.is_v32 = hasattr( - config, "index_topk" - ) + self.is_v32 = hasattr(config, "index_topk") if self.is_v32: self.indexer = Indexer(vllm_config, config, hidden_size, @@ -1088,15 +1105,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.vocab_size = config.vocab_size - self.is_v32 = hasattr( - config, "index_topk" - ) + self.is_v32 = hasattr(config, "index_topk") if self.is_v32: topk_tokens = config.index_topk - topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, - topk_tokens, - dtype=torch.int32, - device="cuda") + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda") else: topk_indices_buffer = None diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index ddec5dd64cc4..1b33b5e70e0b 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -8,6 +8,7 @@ """ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig +from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config 
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig @@ -33,7 +34,6 @@ Step3VisionEncoderConfig, Step3VLConfig) from vllm.transformers_utils.configs.ultravox import UltravoxConfig -from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config __all__ = [ "ChatGLMConfig", diff --git a/vllm/transformers_utils/configs/deepseek_v3.py b/vllm/transformers_utils/configs/deepseek_v3.py index 235b7b0fd33c..209ba08feb13 100644 --- a/vllm/transformers_utils/configs/deepseek_v3.py +++ b/vllm/transformers_utils/configs/deepseek_v3.py @@ -1,104 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + class DeepseekV3Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the DeepSeek-V3. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 129280): - Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`DeepseekV3Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - moe_intermediate_size (`int`, *optional*, defaults to 1407): - Dimension of the MoE representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_nextn_predict_layers (`int`, *optional*, defaults to 1): - Number of nextn predict layers in the DeepSeekV3 Model. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - n_shared_experts (`int`, *optional*, defaults to None): - Number of shared experts, None means dense model. - n_routed_experts (`int`, *optional*, defaults to None): - Number of routed experts, None means dense model. - routed_scaling_factor (`float`, *optional*, defaults to 1.0): - Scaling factor or routed experts. - topk_method (`str`, *optional*, defaults to `gready`): - Topk method used in routed gate. - n_group (`int`, *optional*, defaults to None): - Number of groups for routed experts. - topk_group (`int`, *optional*, defaults to None): - Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). - num_experts_per_tok (`int`, *optional*, defaults to None): - Number of selected experts, None means dense model. - moe_layer_freq (`int`, *optional*, defaults to 1): - The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. - first_k_dense_replace (`int`, *optional*, defaults to 0): - Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). 
- \--k dense layers--/ - norm_topk_prob (`bool`, *optional*, defaults to False): - Whether to normalize the weights of the routed experts. - scoring_func (`str`, *optional*, defaults to 'softmax'): - Method of computing expert weights. - aux_loss_alpha (`float`, *optional*, defaults to 0.001): - Auxiliary loss weight coefficient. - seq_aux = (`bool`, *optional*, defaults to True): - Whether to compute the auxiliary loss for each individual sample. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is - `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- ```python - >>> from transformers import DeepseekV3Model, DeepseekV3Config - >>> # Initializing a Deepseek-V3 style configuration - >>> configuration = DeepseekV3Config() - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" model_type = "deepseek_v3" keys_to_ignore_at_inference = ["past_key_values"] @@ -108,28 +18,28 @@ def __init__( vocab_size=129280, hidden_size=7168, intermediate_size=18432, - moe_intermediate_size = 2048, + moe_intermediate_size=2048, num_hidden_layers=61, num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, - n_shared_experts = 1, - n_routed_experts = 256, - ep_size = 1, - routed_scaling_factor = 2.5, - kv_lora_rank = 512, - q_lora_rank = 1536, - qk_rope_head_dim = 64, - v_head_dim = 128, - qk_nope_head_dim = 128, - topk_method = 'noaux_tc', - n_group = 8, - topk_group = 4, - num_experts_per_tok = 8, - moe_layer_freq = 1, - first_k_dense_replace = 3, - norm_topk_prob = True, - scoring_func = 'sigmoid', + n_shared_experts=1, + n_routed_experts=256, + ep_size=1, + routed_scaling_factor=2.5, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method='noaux_tc', + n_group=8, + topk_group=4, + num_experts_per_tok=8, + moe_layer_freq=1, + first_k_dense_replace=3, + norm_topk_prob=True, + scoring_func='sigmoid', hidden_act="silu", max_position_embeddings=4096, initializer_range=0.02, @@ -190,4 +100,4 @@ def __init__( eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, - ) \ No newline at end of file + ) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index cdd9cb3c3b1a..2b5402bc5f6b 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -84,14 +84,11 @@ def _lazy_init() -> None: global _get_mn_major_tma_aligned_tensor_impl # fast path - if ( - _fp8_gemm_nt_impl is not None - or _grouped_impl is not None - or _grouped_masked_impl is not None - or _fp8_mqa_logits_impl is not None - or _fp8_paged_mqa_logits_impl is not None - or _get_paged_mqa_logits_metadata_impl is not None - ): + if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or _get_paged_mqa_logits_metadata_impl is not None): return if not has_deep_gemm(): @@ -111,8 +108,7 @@ def _lazy_init() -> None: _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) _get_paged_mqa_logits_metadata_impl = getattr( - _dg, "get_paged_mqa_logits_metadata", None - ) + _dg, "get_paged_mqa_logits_metadata", None) _get_mn_major_tma_aligned_tensor_impl = getattr( _dg, "get_mn_major_tma_aligned_tensor", None) @@ -187,10 +183,8 @@ def fp8_mqa_logits( return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) - -def get_paged_mqa_logits_metadata( - context_lens: torch.Tensor, block_size: int, num_sms: int -) -> torch.Tensor: +def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int, + num_sms: int) -> torch.Tensor: """Build scheduling metadata for paged MQA logits. 
Args: @@ -206,9 +200,8 @@ def get_paged_mqa_logits_metadata( _lazy_init() if _get_paged_mqa_logits_metadata_impl is None: return _missing() - return _get_paged_mqa_logits_metadata_impl( - context_lens, block_size, num_sms - ) + return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, + num_sms) def fp8_paged_mqa_logits( @@ -244,17 +237,14 @@ def fp8_paged_mqa_logits( _lazy_init() if _fp8_paged_mqa_logits_impl is None: return _missing() - return _fp8_paged_mqa_logits_impl( - q_fp8, - kv_cache_fp8, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=True - ) - + return _fp8_paged_mqa_logits_impl(q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True) def _ceil_to_ue8m0(x: torch.Tensor): @@ -326,4 +316,4 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, "get_num_sms", "should_use_deepgemm_for_fp8_linear", "get_col_major_tma_aligned_tensor", -] \ No newline at end of file +] diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index fd96a0fdf8b0..3511f0c2aac4 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1101,7 +1101,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - + if use_flashinfer_prefill(): logger.debug_once("Using FlashInfer prefill for MLA") self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index a118b29e9895..5ba2d05fdd7b 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -16,6 +16,7 @@ get_mla_metadata) from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl @@ -23,7 +24,6 @@ AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.platforms import current_platform if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer @@ -306,9 +306,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=self.device) self.num_speculative_tokens = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config else 0 - ) + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0) self.reorder_batch_threshold += self.num_speculative_tokens # Equation taken from FlashMLA/csrc/pybind.cpp @@ -327,7 +326,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.num_splits_buffer = torch.empty( # We pack all the tokens into one batch for sparse attention. # Otherwise, we can exceed the sm of `get_mla_metadata`. 
-            (2, ),
+            (
+                2, ),
             dtype=torch.int32,
             device=device)
         self.req_id_per_token_buffer = torch.empty(
@@ -422,7 +422,8 @@ def __init__(
         self.softmax_scale = scale
         assert indexer is not None
         self.topk_indices_buffer = indexer.topk_indices_buffer
-        self.padding = 128 if current_platform.is_device_capability(100) else 64
+        self.padding = 128 if current_platform.is_device_capability(
+            100) else 64

     def _forward_bf16_kv(
             self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
@@ -432,13 +433,12 @@ def _forward_bf16_kv(
         kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
             -1, 1, kv_c_and_k_pe_cache.shape[-1])

-        # NOTE(Chen): kernel requires num_local_head to be a multiple of 
+        # NOTE(Chen): kernel requires num_local_head to be a multiple of
         # 64 on hopper and 128 on blackwell
         if self.num_heads % self.padding != 0:
             assert self.padding % self.num_heads == 0
-            logger.warning_once(
-                f"padding num_heads to {self.padding} due to sparse attn kernel requirement"
-            )
+            logger.warning_once(f"padding num_heads to {self.padding} "
+                                "due to sparse attn kernel requirement")
             q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
             q_padded[:, :self.num_heads, :] = q
             q = q_padded
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 6abd02622ed2..d105e1c0222f 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -71,7 +71,8 @@ class DeepSeekV32IndexerDecodeMetadata:

 @dataclass
 class DeepseekV32IndexerMetadata:
-    #FIXME (zyongye) hacky way to access the data now, need to be in chunked meta
+    # FIXME (zyongye)
+    # hacky way to access the data now, need to be in chunked meta

     seq_lens: torch.Tensor

     num_reqs: int
@@ -100,18 +101,24 @@ def kv_spans_from_batches(start_seq_loc: torch.Tensor,
                           seq_len_per_batch: torch.Tensor):
     """
     Args:
-      start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch.
-          Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total.
-      seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch.
+      start_seq_loc: 1D long tensor [B+1], cumulative counts of
+        selected tokens per batch.
+        Example: [0, 2, 4, 7] ->
+            batch sizes (selected) [2, 2, 3], N=7 tokens total.
+      seq_len_per_batch: 1D long tensor [B],
+        full sequence length (KV length) of each batch.
           Example: [5, 9, 4].

     Returns:
-      start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch.
-      end_location: 1D long tensor [N], **exclusive** end = start + token's local position.
+      start_tensor: 1D long tensor [N], start offset in the
+        concatenated KV cache for each token's batch.
+      end_location: 1D long tensor [N],
+        **exclusive** end = start + token's local position.
           (So the attended KV slice is kv[start:end].)

-    Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the KV cache, and
-    the selected tokens within a batch are the **last** `counts[i]` positions of that sequence.
+    Assumes each batch contributes its full `seq_len_per_batch[i]`
+    keys to the KV cache, and the selected tokens within a batch
+    are the **last** `counts[i]` positions of that sequence.
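+
+    Worked example (illustrative, derived from the numbers above): with
+    start_seq_loc = [0, 2, 4, 7] and seq_len_per_batch = [5, 9, 4], the
+    concatenated KV start offsets are [0, 5, 14]; the two selected tokens
+    of batch 0 attend to kv[0:4] and kv[0:5], and the three selected
+    tokens of batch 2 attend to kv[14:16], kv[14:17] and kv[14:18].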
""" q = start_seq_loc.to(dtype=torch.long) L = seq_len_per_batch.to(dtype=torch.long, device=q.device) @@ -154,7 +161,8 @@ def kv_spans_from_batches(start_seq_loc: torch.Tensor, def get_max_prefill_buffer_size(vllm_config: VllmConfig): max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + # max_num_batched_tokens = \ + # vllm_config.scheduler_config.max_num_batched_tokens max_num_seq = vllm_config.scheduler_config.max_num_seqs # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. return max_model_len * max_num_seq @@ -165,31 +173,31 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: ClassVar[int] = 1 + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) scheduler_config = self.vllm_config.scheduler_config - # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. + #NOTE(Chen):an estimated max size of flattened_kv. Need to double check. self.max_prefill_buffer_size = get_max_prefill_buffer_size( self.vllm_config) self.num_speculative_tokens = ( - self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0 - ) + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0) self.reorder_batch_threshold += self.num_speculative_tokens props = torch.cuda.get_device_properties(self.device) sm_count = props.multi_processor_count - self.num_sms = sm_count + self.num_sms = sm_count self.decode_lens_buffer = torch.empty( (scheduler_config.max_num_seqs, ), dtype=torch.int32, device=self.device) - # See: DeepGMM/csrc/apis/attention.hpp - self.scheduler_metadata_buffer = torch.empty( - (self.num_sms + 1, 2), dtype=torch.int32, device=self.device - ) + # See: DeepGMM/csrc/apis/attention.hpp + self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2), + dtype=torch.int32, + device=self.device) def build(self, common_prefix_len: int, @@ -238,14 +246,15 @@ def build(self, decode_metadata = None if num_decodes > 0: - torch.diff(common_attn_metadata.query_start_loc[:num_decodes+1], + torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1], out=self.decode_lens_buffer[:num_decodes]) decode_lens = self.decode_lens_buffer[:num_decodes] decode_lens_cpu = torch.diff( - common_attn_metadata.query_start_loc_cpu[:num_decodes+1]) - - # Use CPU to avoid GPU sync; breaking async scheduling - requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() + common_attn_metadata.query_start_loc_cpu[:num_decodes + 1]) + + # Use CPU to avoid GPU sync; breaking async scheduling + requires_padding = (decode_lens_cpu.max() + > decode_lens_cpu.min()).item() seq_lens = common_attn_metadata.seq_lens[:num_decodes] diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 58fe12aef0a9..e889f7804e84 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -9,9 +9,8 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, - MLAAttentionSpec, KVCacheSpec, MambaSpec, - SlidingWindowSpec) + MLAAttentionSpec, SlidingWindowSpec) from vllm.v1.request import Request diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1928ab008e8b..9d3216012e01 100644 --- 
a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -16,8 +16,8 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models import supports_multimodal +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform @@ -210,15 +210,13 @@ def propose( self.runner.attn_groups[0][0].metadata_builders[ubatch_id] attn_metadata = attn_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0) - # FIXME: support hybrid kv for draft model (remove separate indexer) + # FIXME: support hybrid kv for draft model (remove separate indexer) if self.draft_indexer_metadata_builder: draft_indexer_metadata = ( - self.draft_indexer_metadata_builder - .build_for_drafting( + self.draft_indexer_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0, - ) - ) + )) else: draft_indexer_metadata = None # At this moment, we assume all eagle layers belong to the same KV @@ -862,8 +860,8 @@ def load_model(self, target_model: nn.Module) -> None: get_layers_from_vllm_config(self.vllm_config, Attention).keys()) # FIXME: support hybrid kv for draft model target_indexer_layer_names = set( - get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache).keys()) + get_layers_from_vllm_config(self.vllm_config, + DeepseekV32IndexerCache).keys()) from vllm.compilation.backends import set_model_tag with set_model_tag("eagle_head"): @@ -873,23 +871,23 @@ def load_model(self, target_model: nn.Module) -> None: draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) - indexer_layers = get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache) - draft_indexer_layer_names = (indexer_layers.keys() - target_indexer_layer_names) + indexer_layers = get_layers_from_vllm_config(self.vllm_config, + DeepseekV32IndexerCache) + draft_indexer_layer_names = (indexer_layers.keys() - + target_indexer_layer_names) self.attn_layer_names = list(draft_attn_layer_names) self.indexer_layer_names = list(draft_indexer_layer_names) if self.indexer_layer_names: first_layer = self.indexer_layer_names[0] self.draft_indexer_metadata_builder = ( - indexer_layers[first_layer] - .get_attn_backend() - .get_builder_cls()( - indexer_layers[first_layer].get_kv_cache_spec(), - self.indexer_layer_names, - self.vllm_config, - self.device, - ) - ) + indexer_layers[first_layer].get_attn_backend().get_builder_cls( + )( + indexer_layers[first_layer].get_kv_cache_spec(), + self.indexer_layer_names, + self.vllm_config, + self.device, + )) else: self.draft_indexer_metadata_builder = None From 53a3b94bd4f1c6385e3836303a40b82f597ad67e Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 20:54:53 +0000 Subject: [PATCH 60/82] fixing pre-commit Signed-off-by: Yongye Zhu --- vllm/v1/spec_decode/eagle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 9d3216012e01..bb11a543fd8b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -33,6 +33,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.utils 
import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -382,7 +383,7 @@ def propose( exceeds_max_model_len, PADDING_SLOT_ID) # Rebuild attention metadata - attn_metadata = builder.build_for_drafting( # type: ignore + attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore common_attn_metadata=common_attn_metadata, draft_index=token_index + 1) for layer_name in self.attn_layer_names: From 69fcaa2ab2e3a9f7f965e4322567d49e24f17d97 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 21:04:41 +0000 Subject: [PATCH 61/82] fixing pre-commit Signed-off-by: Yongye Zhu --- vllm/model_executor/models/deepseek_v2.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 55ed8462050c..7f365c5d8697 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -762,11 +762,15 @@ def __init__(self, self.quant_block_size = 128 # TODO: get from config self.topk_indices_buffer = topk_indices_buffer - #TODO (zyongye) change dim to fp8 later to (self.head_dim + 4) - self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim + 4, - dtype=torch.uint8, - prefix=f"{prefix}.k_cache", - cache_config=cache_config) + # NOTE: (zyongye) we use fp8 naive cache, + # where we store value in fp8 and scale in fp32 + # per self.quant_block_size element + self.k_cache = DeepseekV32IndexerCache( + head_dim=self.head_dim + + self.head_dim // self.quant_block_size * 4, + dtype=torch.uint8, + prefix=f"{prefix}.k_cache", + cache_config=cache_config) self.max_model_len = vllm_config.model_config.max_model_len self.prefix = prefix from vllm.v1.attention.backends.mla.indexer import ( @@ -975,7 +979,6 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - # self.indexer(torch.tensor([]), torch.tensor([])) return self.mla_attn(positions, hidden_states) From 1dfc501d6b580eaf6cd6eb1169573873e1315583 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 21:16:48 +0000 Subject: [PATCH 62/82] delete envs Signed-off-by: Yongye Zhu --- vllm/model_executor/layers/mla.py | 4 +--- vllm/platforms/cuda.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index c83c0e12d26c..5298354c6027 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os from dataclasses import dataclass from typing import Optional @@ -81,8 +80,7 @@ def __init__( self.rotary_emb = mla_modules.rotary_emb self.o_proj = mla_modules.o_proj self.indexer = mla_modules.indexer - self.use_sparse = mla_modules.is_sparse and os.getenv( - "VLLM_MLA_SPARSE_DISABLED") != "1" + self.use_sparse = mla_modules.is_sparse if self.indexer is not None: assert hasattr(self.indexer, "topk_tokens") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e7033fe03bbe..31849e59968d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -129,7 +129,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing if model_config is not None and 
model_config.use_mla: - use_sparse = os.getenv("VLLM_MLA_SPARSE_DISABLED") != "1" + use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the From 66ebc852f90756cbca0d4b4395e528bf2d9fcda9 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 21:36:12 +0000 Subject: [PATCH 63/82] fix basic.py Signed-off-by: Yongye Zhu --- examples/offline_inference/basic/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 909fc9e4df66..78bfda9bcf4e 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -32,4 +32,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 684658d4a75fc54478d4cb399f0703c222db9a15 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 29 Sep 2025 14:34:10 -0700 Subject: [PATCH 64/82] fix pre-commit Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 49 ++++++------------- .../attention/backends/mla/flashmla_sparse.py | 10 +--- 2 files changed, 16 insertions(+), 43 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 3511f0c2aac4..561793b6a377 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1073,7 +1073,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0) - def _v_up_proj(self, x): + def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) if is_rocm_aiter_fp8bmm_enabled(): @@ -1085,12 +1085,23 @@ def _v_up_proj(self, x): transpose_bm=True) # Convert from (B, N, V) to (B, N * V) x = x.reshape(-1, self.num_heads * self.v_head_dim) + # Copy result + out.copy_(x) else: + # Convert from (B, N * V) to (N, B, V) + out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) + torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" + # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return x + out_new = out.transpose(0, 1).reshape( + -1, self.num_heads * self.v_head_dim) + + # Adjust output buffer shape back to the original (B, N * V) + N, B, V = out.shape + out.resize_((B, N * V)) + out.copy_(out_new) # Copy result class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): @@ -1280,36 +1291,6 @@ def _run_prefill_context_chunk_cudnn(self, True, #Indicates actual_seq_lens are on GPU or CPU. 
) - def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - if is_rocm_aiter_fp8bmm_enabled(): - # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) - # Convert from (B, N, V) to (B, N * V) - x = x.reshape(-1, self.num_heads * self.v_head_dim) - # Copy result - out.copy_(x) - else: - # Convert from (B, N * V) to (N, B, V) - out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) - - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" - - # Convert from (N, B, V) to (B, N * V) - out_new = out.transpose(0, 1).reshape( - -1, self.num_heads * self.v_head_dim) - - # Adjust output buffer shape back to the original (B, N * V) - N, B, V = out.shape - out.resize_((B, N * V)) - out.copy_(out_new) # Copy result - def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 5ba2d05fdd7b..2ebc1ce38b6a 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -273,10 +273,6 @@ class FlashMLASparseMetadataBuilder( cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_BATCH - reorder_batch_threshold: ClassVar[int] = 128 # TODO(lucas): tune this - - reorder_batch_threshold: ClassVar[int] = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -305,10 +301,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.dummy_block_table = torch.empty((1, 1), dtype=torch.int32, device=self.device) - self.num_speculative_tokens = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config else 0) - self.reorder_batch_threshold += self.num_speculative_tokens # Equation taken from FlashMLA/csrc/pybind.cpp h_q, h_k = self.num_heads, 1 @@ -549,5 +541,5 @@ def forward( attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global, attn_metadata) - output[:num_actual_toks] = self._v_up_proj(attn_out) + self._v_up_proj(attn_out, out=output[:num_actual_toks]) return output From ae30e228553a11b548c6c52070739c73234d0088 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 21:45:21 +0000 Subject: [PATCH 65/82] pre-commit Signed-off-by: Yongye Zhu --- vllm/attention/ops/flashmla.py | 90 +++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 34 deletions(-) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 9c9eee24ebeb..1e49ed2db7dc 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -54,16 +54,22 @@ def get_mla_metadata( topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: - cache_seqlens: (batch_size), dtype torch.int32. - num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. - num_heads_k: The number of k heads. - num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. 
- topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. + - cache_seqlens: (batch_size), dtype torch.int32. + - num_q_tokens_per_head_k: + Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. + - num_heads_k: The number of k heads. + - num_heads_q: + The number of q heads. + This argument is optional when sparse attention is not enabled + - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. + - topk: If not None, sparse attention will be enabled, + and only tokens in the `indices` array + passed to `flash_mla_with_kvcache_sm90` will be attended to. Returns: - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32. + - tile_scheduler_metadata: + (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + - num_splits: (batch_size + 1), dtype torch.int32. """ return torch.ops._flashmla_C.get_mla_decoding_metadata( cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, @@ -87,28 +93,42 @@ def flash_mla_with_kvcache( ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: - q: (batch_size, seq_len_q, num_heads_q, head_dim). - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). - block_table: (batch_size, max_num_blocks_per_seq), torch.int32. - cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head dimension of v. - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. - num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. - softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). - causal: bool. Whether to apply causal attention mask. - descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. - descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. - is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md - indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. + - q: (batch_size, seq_len_q, num_heads_q, head_dim). + - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + - block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + - cache_seqlens: (batch_size), torch.int32. + - head_dim_v: Head dimension of v. + - tile_scheduler_metadata: + (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, + returned by get_mla_metadata. + - num_splits: + (batch_size + 1), torch.int32, returned by get_mla_metadata. + - softmax_scale: float. + The scale of QK^T before applying softmax. + Default to 1 / sqrt(head_dim). + - causal: bool. Whether to apply causal attention mask. + - descale_q: (batch_size), + torch.float32. Descaling factors for Q, used for fp8 quantization. + - descale_k: (batch_size), + torch.float32. Descaling factors for K, used for fp8 quantization. + - is_fp8_kvcache: bool. + Whether the k_cache and v_cache are in fp8 format. + For the format of FP8 KV cache, please refer to README.md + - indices: (batch_size, seq_len_q, topk), torch.int32. 
+ If not None, sparse attention will be enabled, + and only tokens in the `indices` array will be attended to. + Invalid indices should be set to -1 or numbers >= total_seq_len_kv. + For details about how to set up `indices`, please refer to README.md. Returns: - out: (batch_size, seq_len_q, num_heads_q, head_dim_v). - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + - out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) if indices is not None: - assert causal == False, "causal must be `false` if sparse attention is enabled." + assert not causal, \ + "causal must be `false` if sparse attention is enabled." assert (descale_q is None) == ( descale_k is None ), "descale_q and descale_k should be both None or both not None" @@ -136,18 +156,20 @@ def flash_mla_sparse_prefill( Sparse attention prefill kernel Args: - q: [s_q, h_q, d_qk], bfloat16 - kv: [s_kv, h_kv, d_qk], bfloat16 - indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv - sm_scale: float - d_v: The dimension of value vectors. Can only be 512 + - q: [s_q, h_q, d_qk], bfloat16 + - kv: [s_kv, h_kv, d_qk], bfloat16 + - indices: [s_q, h_kv, topk], int32. + Invalid indices should be set to -1 or numbers >= s_kv + - sm_scale: float + - d_v: The dimension of value vectors. Can only be 512 Returns: - (output, max_logits, lse) - About the definition of output, max_logits and lse, please refer to README.md - - output: [s_q, h_q, d_v], bfloat16 - - max_logits: [s_q, h_q], float - - lse: [s_q, h_q], float, 2-based log-sum-exp + - (output, max_logits, lse) + About the definition of output, + max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp """ results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v) From 148f43a13fa3501c17029128552f301d8f8349d9 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 21:59:42 +0000 Subject: [PATCH 66/82] fix more pre-commit Signed-off-by: Yongye Zhu --- tests/v1/attention/test_mla_backends.py | 5 +++-- vllm/platforms/tpu.py | 3 +-- vllm/transformers_utils/configs/deepseek_v3.py | 2 -- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 0154237bc04a..228551573ba8 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for v1 MLA backends without GPUModelRunner dependency.""" +from typing import Optional, Union import pytest import torch @@ -80,8 +81,8 @@ def create_and_prepopulate_kv_cache( num_blocks: int, common_attn_metadata: CommonAttentionMetadata, randomize_blocks: bool = True, - kv_cache_dtype: str | None = None, - scale: float | torch.Tensor = 1.0) -> torch.Tensor: + kv_cache_dtype: Optional[str] = None, + scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor: """Create and prepopulate an MLA KV cache with context data. 
Args: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 85066335dc6d..4a4931f7f009 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -53,8 +53,7 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, if use_sparse: raise NotImplementedError( "Sparse Attention is not supported on TPU.") - if (selected_backend != _Backend.PALLAS - and selected_backend != _Backend.PALLAS_VLLM_V1): + if selected_backend != _Backend.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) if not use_v1: diff --git a/vllm/transformers_utils/configs/deepseek_v3.py b/vllm/transformers_utils/configs/deepseek_v3.py index 209ba08feb13..4b26cdfd94b5 100644 --- a/vllm/transformers_utils/configs/deepseek_v3.py +++ b/vllm/transformers_utils/configs/deepseek_v3.py @@ -5,8 +5,6 @@ logger = logging.get_logger(__name__) -DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - class DeepseekV3Config(PretrainedConfig): From 9033b4e165efbbd6b5e9e65d7b4c9c060708621e Mon Sep 17 00:00:00 2001 From: Lucia Fang <116399278+luccafong@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:14:20 -0700 Subject: [PATCH 67/82] fix mtp config (#1) fix the num tokens Signed-off-by: Lucia Fang Co-authored-by: Lucia Fang --- vllm/config/speculative.py | 2 +- vllm/model_executor/models/deepseek_mtp.py | 2 +- vllm/model_executor/models/deepseek_v2.py | 8 ++++---- vllm/v1/attention/backends/mla/indexer.py | 5 +++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index cb4f0ae2cee0..f684e4e4ccd4 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -145,7 +145,7 @@ def compute_hash(self) -> str: @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: - if hf_config.model_type == "deepseek_v3": + if hf_config.model_type in ("deepseek_v3", "deepseek_v32"): hf_config.model_type = "deepseek_mtp" if hf_config.model_type == "deepseek_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 7187915b2db9..788e561ac394 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -56,7 +56,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: self.is_v32 = hasattr(config, "index_topk") if self.is_v32: - topk_tokens = config.attn_module_list_cfg[0]["topk_tokens"] + topk_tokens = config.index_topk topk_indices_buffer = torch.empty( vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 7f365c5d8697..b093748af8ca 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -637,18 +637,18 @@ def sparse_attn_indexer( decode_lens.shape[0], -1, *q_fp8.shape[1:]) # TODO: move and optimize below logic with triton kernels batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n logits = fp8_paged_mqa_logits( padded_q_fp8_decode_tokens, kv_cache, - weights[:num_decode_tokens], + weights[:num_padded_tokens], decode_metadata.seq_lens, decode_metadata.block_table, decode_metadata.schedule_metadata, max_model_len=max_model_len, ) - # [B, N, L] - next_n = padded_q_fp8_decode_tokens.shape[1] # padded query len current_device = 
padded_q_fp8_decode_tokens.device padded_num_tokens = batch_size * next_n @@ -676,7 +676,7 @@ def sparse_attn_indexer( # if padded, we need to unpack # the topk indices removing padded tokens topk_indices = unpack_seq_triton( - topk_indices.reshape(batch_size, -1, logits.shape[-1]), + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens) topk_indices_buffer[:num_decode_tokens, :topk_indices. shape[-1]] = topk_indices.to(dtype=torch.int32) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index d105e1c0222f..59e2958880ee 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -172,7 +172,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_BATCH - reorder_batch_threshold: ClassVar[int] = 1 + reorder_batch_threshold: int = 1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -183,7 +183,8 @@ def __init__(self, *args, **kwargs): self.num_speculative_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens if self.vllm_config.speculative_config else 0) - self.reorder_batch_threshold += self.num_speculative_tokens + # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2 + self.reorder_batch_threshold += min(self.num_speculative_tokens, 1) props = torch.cuda.get_device_properties(self.device) sm_count = props.multi_processor_count From 1fd2ceff796e907a673256cf5d40a73b8fb11675 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 29 Sep 2025 23:45:14 +0000 Subject: [PATCH 68/82] add tilelang kernel and skip if not installed Signed-off-by: Yongye Zhu --- tests/kernels/attention/test_indexer.py | 12 +- vllm/utils/__init__.py | 6 + vllm/utils/tilelang_kernels.py | 497 ++++++++++++++++++++++++ 3 files changed, 513 insertions(+), 2 deletions(-) create mode 100644 vllm/utils/tilelang_kernels.py diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py index 6abb02d92aa4..7fb8ff8a69e6 100644 --- a/tests/kernels/attention/test_indexer.py +++ b/tests/kernels/attention/test_indexer.py @@ -2,14 +2,22 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random +import pytest import torch from vllm import _custom_ops as ops -from vllm.utils import cdiv +from vllm.utils import cdiv, has_tilelang from vllm.utils.deep_gemm import (calc_diff, fp8_mqa_logits, fp8_paged_mqa_logits, get_num_sms, get_paged_mqa_logits_metadata) -from vllm.utils.tile_lang_kernels import act_quant, fp8_index + +if not has_tilelang(): + pytest.skip( + "tilelang not found, skipping all related tests", + allow_module_level=True, + ) + +from vllm.utils.tilelang_kernels import act_quant, fp8_index from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index d48fb78b1215..11d6686009b2 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3434,6 +3434,12 @@ def has_triton_kernels() -> bool: return _has_module("triton_kernels") +def has_tilelang() -> bool: + """Whether the optional `tilelang` package is available.""" + + return _has_module("tilelang") + + def set_process_title(name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: diff --git a/vllm/utils/tilelang_kernels.py b/vllm/utils/tilelang_kernels.py new file mode 100644 index 000000000000..257f95891461 --- /dev/null +++ 
b/vllm/utils/tilelang_kernels.py @@ -0,0 +1,497 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import tilelang +import tilelang.language as T +import torch + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, +} + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" + + +def fast_log2_ceil(x): + bits_x = T.reinterpret("uint32", x) + exp_x = (bits_x >> 23) & 0xFF + man_bits = bits_x & ((1 << 23) - 1) + return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + + +def fast_pow2(x): + bits_x = (x + 127) << 23 + return T.reinterpret("float32", bits_x) + + +def fast_round_scale(amax, fp8_max_inv): + return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) + + +@tilelang.jit(pass_configs=pass_configs) +def act_quant_kernel(N, + in_dtype=BF16, + out_dtype=FP8, + scale_dtype=FP32, + round_scale=False): + M = T.symbolic("M") + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1 / fp8_max + num_stages = 0 if round_scale else 2 + blk_m = 32 + group_size = 128 + + @T.prim_func + def act_quant_kernel_( + X: T.Tensor[(M, N), in_dtype], + Y: T.Tensor[(M, N), out_dtype], + S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype], + ): + with T.Kernel(T.ceildiv(M, blk_m), + T.ceildiv(N, group_size), + threads=128) as ( + pid_m, + pid_n, + ): + x_shared = T.alloc_shared((blk_m, group_size), in_dtype) + x_local = T.alloc_fragment((blk_m, group_size), in_dtype) + amax_local = T.alloc_fragment((blk_m, ), scale_dtype) + s_local = T.alloc_fragment((blk_m, ), scale_dtype) + y_local = T.alloc_fragment((blk_m, group_size), out_dtype) + y_shared = T.alloc_shared((blk_m, group_size), out_dtype) + + for _ in T.Pipelined(1, num_stages=num_stages): + T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) + T.copy(x_shared, x_local) + T.reduce_absmax(x_local, amax_local, dim=1) + for i in T.Parallel(blk_m): + amax_local[i] = T.max(amax_local[i], 1e-4) + if round_scale: + s_local[i] = fast_round_scale(amax_local[i], + fp8_max_inv) + else: + s_local[i] = amax_local[i] * fp8_max_inv + for i, j in T.Parallel(blk_m, group_size): + y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], + fp8_min, fp8_max) + for i in T.Parallel(blk_m): + S[pid_m * blk_m + i, pid_n] = s_local[i] + T.copy(y_local, y_shared) + T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) + + return act_quant_kernel_ + + +def act_quant( + x: torch.Tensor, + block_size: int = 128, + scale_fmt: Optional[str] = None) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + - x (torch.Tensor): + The input tensor to be quantized. + Must be contiguous and its last dimension size + must be divisible by `block_size`. + - block_size (int, optional): + The size of the blocks to be used for quantization. Default is 128. + - scale_fmt (Optional[str], optional): + The format of the scale. Default is None. + Returns: + - tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. 
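As a shape-level illustration of the contract documented above, a hypothetical usage looks like the sketch below; it assumes tilelang and a CUDA device are available, and the tensor sizes are made up:

    import torch

    x = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")
    y, s = act_quant(x, block_size=128, scale_fmt="ue8m0")
    assert y.shape == x.shape and y.dtype == torch.float8_e4m3fn
    assert s.shape == (4, 7168 // 128) and s.dtype == torch.float32
    # Dequantization is approximately y * s, with each scale broadcast over its 128-wide group:
    x_approx = (y.float().view(4, -1, 128) * s.unsqueeze(-1)).view_as(x).to(x.dtype)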
+ """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert x.size(-1) % block_size == 0, ( + f"Last dimension must be divisible by block_size={block_size}") + N = x.size(-1) + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) + kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) + return y, s + + +@tilelang.jit(pass_configs=pass_configs) +def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): + assert out_dtype in [BF16, "float32"] + + M = T.symbolic("M") + group_size = 128 + block_M = 32 + block_N = 128 + block_K = 128 + + @T.prim_func + def fp8_gemm_kernel_( + A: T.Tensor[(M, K), FP8], + B: T.Tensor[(N, K), FP8], + C: T.Tensor[(M, N), out_dtype], + scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32], + scales_b: T.Tensor[(T.ceildiv(N, group_size), + T.ceildiv(K, group_size)), FP32], + ): + with T.Kernel(T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=128) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), FP8) + B_shared = T.alloc_shared((block_N, block_K), FP8) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + Scale_C_shared = T.alloc_shared((block_M), FP32) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + # Load A into shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # Load B into shared memory + T.copy(B[bx * block_N, k * block_K], B_shared) + # Load scale into shared memory + Scale_B = scales_b[bx * block_N // group_size, k] + for i in T.Parallel(block_M): + Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + T.clear(C_local) + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return fp8_gemm_kernel_ + + +def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, + b_s: torch.Tensor) -> torch.Tensor: + """ + Perform a matrix multiplication using FP8 precision. + + Args: + - a (torch.Tensor): + The first input matrix, must be contiguous. + - a_s (torch.Tensor): + The scaling factor for the first input matrix, must be contiguous. + - b (torch.Tensor): + The second input matrix, must be contiguous. + - b_s (torch.Tensor): + The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. 
+ """ + assert a.is_contiguous() and b.is_contiguous( + ), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), ( + "Scaling factor tensors must be contiguous") + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + kernel = fp8_gemm_kernel(N, K) + kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) + return c + + +@tilelang.jit(out_idx=[4], pass_configs=pass_configs) +def fp8_index_kernel(h: int, d: int): + b = T.symbolic("b") + m = T.symbolic("m") + n = T.symbolic("n") + + blk_n1 = 512 + blk_n2 = 128 + + @T.prim_func + def fp8_index_kernel_( + q: T.Tensor[(b, m, h, d), FP8], + q_s: T.Tensor[(b, m, h), FP32], + k: T.Tensor[(b, n, d), FP8], + k_s: T.Tensor[(b, n), FP32], + o: T.Tensor[(b, m, n), FP32], + ) -> None: + with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): + q_smem = T.alloc_shared((h, d), FP8) + T.copy(q[i_b, i_m, 0, 0], q_smem) + + q_s_frag = T.alloc_fragment(h, FP32) + T.copy(q_s[i_b, i_m, 0], q_s_frag) + + for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): + k_smem = T.alloc_shared((blk_n2, d), FP8) + T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) + + k_s_frag = T.alloc_fragment(blk_n2, FP32) + T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) + + logits = T.alloc_fragment((blk_n2, h), FP32) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + + for i_h, i3_n in T.Parallel(h, blk_n2): + logits[i3_n, + i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] + + logits_sum = T.alloc_fragment(blk_n2, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + + for i3_n in T.Parallel(blk_n2): + logits_sum[i3_n] *= k_s_frag[i3_n] + + T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) + + return fp8_index_kernel_ + + +def fp8_index( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """ + Perform index score using FP8 precision. + + Args: + - q (torch.Tensor): + The Q tensor, must be contiguous. + - q_s (torch.Tensor): + The scaling factor for Q (float), must be contiguous. + - k (torch.Tensor): + The K tensor, must be contiguous. + - k_s (torch.Tensor): + The scaling factor for K (e8m0 here), must be contiguous. 
+ + fp8 q @ fp8 k -> fp32 logits + relu(fp32 logits) * q_s (weights) -> fp32 logits + fp32 logits -> fp32 logits_sum + fp32 logits_sum * k_s (e8m0) -> fp32 index_score + """ + return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) + + +def convert_to_uint16(x): + hval = T.Cast("float16", x) + bits_uint = T.reinterpret("uint16", hval) + bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), + bits_uint | (0x8000)) + return bits_uint >> 8 + + +def convert_to_uint32(x): + bits_uint = T.reinterpret("uint32", x) + bits_uint = T.if_then_else( + x < 0, + ~bits_uint & T.Cast("uint32", (0xFFFFFFFF)), + bits_uint | T.Cast("uint32", (0x80000000)), + ) + return bits_uint + + +@tilelang.jit(pass_configs=pass_configs) +def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): + batch = T.symbolic("batch") + seq_len = T.symbolic("seq_len") + RADIX = 1 << 8 + BLOCK_SIZE = 1024 + # assume the threshold bucket size after first pass is less than 4K + SMEM_INPUT_SIZE = 4096 + + @T.prim_func + def tl_topk_kernel( + input: T.Tensor[(batch, seq_len), in_dtype], + index: T.Tensor[(batch, topk), out_dtype], + starts: T.Tensor[(batch), "int32"], # noqa: F821 + ends: T.Tensor[(batch), "int32"], # noqa: F821 + ): + with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): + tx = T.get_thread_binding() + + s_threshold_bin_id = T.alloc_shared([1], "int32") + s_histogram = T.alloc_shared([RADIX + 1], "int32") + s_num_input = T.alloc_shared([2], "int32") + s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32") + + l_threshold_bin_id = T.alloc_var("int32") + l_new_topk = T.alloc_var("int32") + l_num_input = T.alloc_var("int32") + l_bin_id32 = T.alloc_var("int32") + l_val = T.alloc_var("int32") + l_start_pos = T.alloc_var("int32") + l_start_idx = T.alloc_var("int32") + l_end_idx = T.alloc_var("int32") + l_out_pos = T.alloc_var("int32") + + l_new_topk = topk + l_start_idx = starts[bx] + l_end_idx = ends[bx] + + # stage 1: use 8bit to do quick topk + T.fill(s_histogram, 0) + T.fill(s_num_input[0], 0) + + T.sync_threads() + for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + input_idx = s * BLOCK_SIZE + tx + if (input_idx < l_end_idx and input_idx >= l_start_idx + and input_idx < seq_len): + inval_int16 = convert_to_uint16(input[bx, input_idx]) + T.atomic_add(s_histogram[inval_int16], 1) + T.sync_threads() + + # cumsum + if tx < RADIX: + for i in T.serial(8): + offset = 1 << i + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + l_val = s_histogram[tx] + s_histogram[tx + offset] + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + s_histogram[tx] = l_val + + # find threshold bin id + T.sync_threads(3, RADIX) + if s_histogram[tx] > l_new_topk and s_histogram[ + tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx + T.sync_threads() + l_threshold_bin_id = s_threshold_bin_id[0] + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] + T.sync_threads() + + # collect all elements with exponent ≥ threshold + for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): + T.sync_threads() + input_idx = s * BLOCK_SIZE + tx + if (input_idx < l_end_idx and input_idx >= l_start_idx + and input_idx < seq_len): + bin_id = convert_to_uint16(input[bx, input_idx]) + l_bin_id32 = T.Cast("int32", bin_id) + if l_bin_id32 > l_threshold_bin_id: + # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], + 1, + return_prev=True) + index[bx, pos] = input_idx + + elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: + # pos = s_num_input[0] + pos = T.atomic_add(s_num_input[0], 1, 
return_prev=True) + s_input_idx[0, pos] = input_idx + + # stage 2: tail pass + for round in T.serial(4): + if l_new_topk <= 0: + T.loop_break() + + r_idx = round % 2 + l_start_pos = topk - l_new_topk + + T.sync_threads() + T.fill(s_histogram, 0) + if tx == 0: + s_num_input[r_idx ^ 1] = 0 + T.sync_threads() + + l_num_input = s_num_input[r_idx] + for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): + if s * BLOCK_SIZE + tx < l_num_input: + l_bin_id32 = T.Cast( + "int32", ((convert_to_uint32(input[bx, s_input_idx[ + r_idx, s * BLOCK_SIZE + tx]]) >> + (24 - round * 8)) & 0xFF)) + T.atomic_add(s_histogram[l_bin_id32], 1) + T.sync_threads() + # cumsum + if tx < RADIX: + for i in T.serial(8): + offset = 1 << i + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + l_val = s_histogram[tx] + s_histogram[tx + offset] + T.sync_threads(3, RADIX) + if tx < RADIX - offset: + s_histogram[tx] = l_val + + # find threshold bin id + T.sync_threads(3, RADIX) + if s_histogram[tx] > l_new_topk and s_histogram[ + tx + 1] <= l_new_topk: + s_threshold_bin_id[0] = tx + T.sync_threads() + + l_threshold_bin_id = s_threshold_bin_id[0] + l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] + T.sync_threads() + + for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): + T.sync_threads() + if s * BLOCK_SIZE + tx < l_num_input: + l_bin_id32 = T.Cast( + "int32", ((convert_to_uint32(input[bx, s_input_idx[ + r_idx, s * BLOCK_SIZE + tx]]) >> + (24 - round * 8)) & 0xFF)) + if l_bin_id32 > l_threshold_bin_id: + pos = T.atomic_add(s_histogram[l_bin_id32 + 1], + 1, + return_prev=True) + l_start_pos + index[bx, pos] = s_input_idx[r_idx, + s * BLOCK_SIZE + tx] + elif (l_bin_id32 == l_threshold_bin_id + and l_new_topk > 0): + if round == 3: + l_out_pos = T.atomic_add( + s_histogram[l_bin_id32 + 1], + 1, + return_prev=True) + l_start_pos + if l_out_pos < topk: + index[bx, + l_out_pos] = s_input_idx[r_idx, s * + BLOCK_SIZE + + tx] + else: + pos = T.atomic_add(s_num_input[r_idx ^ 1], + 1, + return_prev=True) + s_input_idx[r_idx ^ 1, + pos] = s_input_idx[r_idx, + s * BLOCK_SIZE + + tx] + + return tl_topk_kernel + + +def tl_topk(input, starts, ends, topk): + batch, seq_len = input.shape + indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device) + kernel = tl_topk_impl(topk) + kernel(input, indexes, starts, ends) + return indexes From cd77644f1feeae7101671f56e001b6e8eafc20ee Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 29 Sep 2025 17:01:56 -0700 Subject: [PATCH 69/82] [ci fix] DeepseekV2DecoderLayer.topk_indices_buffer Signed-off-by: Chen Zhang --- vllm/model_executor/models/deepseek_v2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b093748af8ca..03c43654d68f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -311,6 +311,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + topk_indices_buffer: Optional[torch.Tensor] = None, prefix: str = "", ) -> None: super().__init__() @@ -328,6 +329,8 @@ def __init__( self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + assert topk_indices_buffer is None, "topk_indices_buffer is not \ + supported for DeepseekV2Attention" if self.q_lora_rank is not None: self.q_a_proj = ReplicatedLinear(self.hidden_size, @@ -984,8 +987,10 @@ 
def forward( class DeepseekV2DecoderLayer(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str, - topk_indices_buffer: Optional[torch.Tensor]) -> None: + def __init__(self, + vllm_config: VllmConfig, + prefix: str, + topk_indices_buffer: Optional[torch.Tensor] = None) -> None: super().__init__() config = vllm_config.model_config.hf_config From 01e46c3c065af466ba66cf3313a69e6e1d1b72ae Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 29 Sep 2025 17:12:01 -0700 Subject: [PATCH 70/82] [ci fix] models/test_registry.py::test_registry_imports[DeepseekV32ForCausalLM] Signed-off-by: Chen Zhang --- tests/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 37ee474d3ecb..b7a2514d8bc0 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -207,6 +207,7 @@ def check_available_online( trust_remote_code=True), "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501 trust_remote_code=True), + "DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"), "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54"), "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", From 1419ff17256506027254af0b56faecad0b6232d9 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 29 Sep 2025 17:19:08 -0700 Subject: [PATCH 71/82] [ci fix] AttentionSpec.use_mla related Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 6 ++---- vllm/model_executor/models/config.py | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index b2d75aa955ff..5769099e0af1 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -76,7 +76,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: kv_cache_groups=[ KVCacheGroupSpec( ["layer"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, 1, 1, torch.float32), ) ], ) @@ -90,7 +90,7 @@ def make_kv_cache_config_hybrid_model(block_size: int, kv_cache_groups=[ KVCacheGroupSpec( ["layer1"], - FullAttentionSpec(block_size, 1, 1, torch.float32, False), + FullAttentionSpec(block_size, 1, 1, torch.float32), ), KVCacheGroupSpec( ["layer2"], @@ -98,7 +98,6 @@ def make_kv_cache_config_hybrid_model(block_size: int, 1, 1, torch.float32, - False, sliding_window=2 * block_size), ), KVCacheGroupSpec( @@ -107,7 +106,6 @@ def make_kv_cache_config_hybrid_model(block_size: int, 1, 1, torch.float32, - False, sliding_window=2 * block_size), ), ], diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index d381268f78c4..589ca0069034 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -346,8 +346,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=1, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - use_mla=model_config.use_mla).page_size_bytes + dtype=kv_cache_dtype).page_size_bytes model_cls, _ = ModelRegistry.resolve_model_cls( model_config.architecture, From 8fbefb424e53068815fa99dc71bdcb2a6650d6e7 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Tue, 30 Sep 2025 00:43:51 +0000 Subject: [PATCH 72/82] address review comment Signed-off-by: Yongye Zhu --- vllm/attention/ops/flashmla.py | 3 +++ 
vllm/model_executor/layers/mla.py | 7 +++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 1e49ed2db7dc..3cc0e4adfa0a 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -127,6 +127,9 @@ def flash_mla_with_kvcache( if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) if indices is not None: + # NOTE (zyongye): sparse attention is also causal + # since it only attend to the tokens before + # but here `causal` should not be specified assert not causal, \ "causal must be `false` if sparse attention is enabled." assert (descale_q is None) == ( diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 5298354c6027..66bf3823e191 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -80,12 +80,11 @@ def __init__( self.rotary_emb = mla_modules.rotary_emb self.o_proj = mla_modules.o_proj self.indexer = mla_modules.indexer - self.use_sparse = mla_modules.is_sparse + self.is_sparse = mla_modules.is_sparse if self.indexer is not None: assert hasattr(self.indexer, "topk_tokens") - self.topk_tokens = self.indexer.topk_tokens \ - if self.indexer else None + self.topk_tokens = self.indexer.topk_tokens self.topk_indices_buffer = mla_modules.topk_indices_buffer # In the MLA backend, kv_cache includes both k_c and @@ -158,7 +157,7 @@ def forward_native( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim:], k_pe) - if self.indexer and self.use_sparse: + if self.indexer and self.is_sparse: _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb) From abaa8ccebdb94b390c6ee688bc9056cf4bd2b59a Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Tue, 30 Sep 2025 02:05:57 +0000 Subject: [PATCH 73/82] [ci] skip if not on sm90+, add vllm_config on longcat model Signed-off-by: Yongye Zhu --- tests/kernels/attention/test_deepgemm_attention.py | 4 ++++ vllm/model_executor/models/longcat_flash.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index 03cc6b930c94..2d901e408b27 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -84,6 +84,8 @@ def _ref_fp8_mqa_logits( @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +@pytest.mark.skipif(not current_platform.has_device_capability(90), + reason="SM90 and SM100 only") def test_deepgemm_fp8_mqa_logits(): torch.manual_seed(0) random.seed(0) @@ -188,6 +190,8 @@ def _ref_fp8_paged_mqa_logits( @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +@pytest.mark.skipif(not current_platform.has_device_capability(90), + reason="SM90 and SM100 only") def test_deepgemm_fp8_paged_mqa_logits(): torch.manual_seed(0) random.seed(0) diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index 1a7a64bfd1a4..78e6e3d4b535 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -308,6 +308,7 @@ class FlashDecoderLayer(nn.Module): def __init__( self, + vllm_config: VllmConfig, config: FlashConfig, cache_config: Optional[CacheConfig] = None, quant_config: 
Optional[QuantizationConfig] = None, @@ -329,6 +330,7 @@ def __init__( # Dual attention structure self.self_attn = nn.ModuleList([ DeepseekV2MLAAttention( + vllm_config=vllm_config, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -454,6 +456,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: FlashDecoderLayer( + vllm_config, config, cache_config=cache_config, quant_config=quant_config, From de7f7cb7ca05bb81b581dfaf02f3461ad55bc4a3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 29 Sep 2025 20:45:09 -0700 Subject: [PATCH 74/82] [ci fix] test_can_initialize_large_subset[DeepseekV32ForCausalLM] Signed-off-by: Chen Zhang --- tests/models/test_initialization.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index e818b908e8a8..1db0dc3da922 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -8,7 +8,8 @@ from vllm import LLM from vllm.utils import GiB_bytes -from vllm.v1.core.kv_cache_utils import get_kv_cache_configs +from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config, + get_kv_cache_configs) from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test @@ -62,11 +63,13 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, # Avoid calling model.forward() def _initialize_kv_caches_v1(self, vllm_config): kv_cache_specs = self.model_executor.get_kv_cache_specs() - scheduler_kv_cache_config = get_kv_cache_configs( + kv_cache_configs = get_kv_cache_configs( vllm_config, kv_cache_specs, [10 * GiB_bytes], - )[0] + ) + scheduler_kv_cache_config = generate_scheduler_kv_cache_config( + kv_cache_configs) # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config return 1, 0, scheduler_kv_cache_config From 24fc3e7d34947781c7d398464ae28649ce417cf9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 30 Sep 2025 14:48:51 +0800 Subject: [PATCH 75/82] Update vllm/transformers_utils/config.py Signed-off-by: youkaichao --- vllm/transformers_utils/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4eee53989279..c3384aec82b1 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -66,6 +66,7 @@ def __getitem__(self, key): _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( chatglm="ChatGLMConfig", deepseek_vl_v2="DeepseekVLV2Config", + deepseek_v3="DeepseekV3Config", deepseek_v32="DeepseekV3Config", kimi_vl="KimiVLConfig", Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", From 9b1b762a713f4e6cd6d154729bac3aed4b5b9cee Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 30 Sep 2025 14:49:51 +0800 Subject: [PATCH 76/82] Update vllm/v1/attention/backends/mla/flashmla_sparse.py Signed-off-by: youkaichao --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 2ebc1ce38b6a..36c3c188042c 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -477,8 +477,7 @@ def forward( output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: # NOTE(lucas): 
for the sparse FlashMLA kernels the kernels want to use - # MQA 576/512 approach for both prefill and decode (see: - # https://vllm-dev.slack.com/archives/C09GKA1D4LR/p1758506094148479) + # MQA 576/512 approach for both prefill and decode assert output is not None, "Output tensor must be provided." From eb5d3317d2d6c5bb8abd406c2d729fe44c265a0b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 30 Sep 2025 14:52:55 +0800 Subject: [PATCH 77/82] rm files Signed-off-by: youkaichao --- tests/kernels/attention/test_indexer.py | 242 ------------------------ 1 file changed, 242 deletions(-) delete mode 100644 tests/kernels/attention/test_indexer.py diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py deleted file mode 100644 index 7fb8ff8a69e6..000000000000 --- a/tests/kernels/attention/test_indexer.py +++ /dev/null @@ -1,242 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import random - -import pytest -import torch - -from vllm import _custom_ops as ops -from vllm.utils import cdiv, has_tilelang -from vllm.utils.deep_gemm import (calc_diff, fp8_mqa_logits, - fp8_paged_mqa_logits, get_num_sms, - get_paged_mqa_logits_metadata) - -if not has_tilelang(): - pytest.skip( - "tilelang not found, skipping all related tests", - allow_module_level=True, - ) - -from vllm.utils.tilelang_kernels import act_quant, fp8_index -from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches - - -def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: - num_blocks, block_size, num_heads, head_dim = x.shape - assert num_heads == 1 - x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), - device=x.device, - dtype=torch.uint8) - x_fp8[:, :block_size * head_dim] = x_scaled.view( - num_blocks, block_size * head_dim).view(dtype=torch.uint8) - x_fp8[:, - block_size * head_dim:] = sf.view(num_blocks, - block_size).view(dtype=torch.uint8) - return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) - - -def ref_compute_logits_fp8(q, kv, weights, mask, block_size): - q_fp8, q_scale = act_quant(q, block_size, "ue8m0") - k_fp8, k_scale = act_quant(kv, block_size, "ue8m0") - - weights = weights.unsqueeze(-1) * q_scale - weights = weights * (128**(-0.5)) - index_score = fp8_index(q_fp8.contiguous(), weights, k_fp8.contiguous(), - k_scale.contiguous()) - if mask is not None: - index_score += mask - return index_score - - -def ref_indexer(seq_len, q, kv, weights, block_size, topk): - B = seq_len.shape[0] - total_seqlen = torch.sum(seq_len) - varlen_logits = torch.full((total_seqlen, total_seqlen), - float("-inf"), - device="cuda") - - current_context_ptr = 0 - for i in range(B): - S = seq_len[i] - q_s = q[i][:S].contiguous().unsqueeze(0) - kv_s = kv[i][:S].contiguous().unsqueeze(0) - weights_s = weights[i][:S].contiguous().unsqueeze(0) - mask = torch.full((S, S), float("-inf"), device="cuda").triu_(1) - logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask, block_size) - logits = logits.squeeze(0) - - varlen_logits[current_context_ptr:current_context_ptr + S, - current_context_ptr:current_context_ptr + S] = logits - current_context_ptr += S - return varlen_logits - - -def deepgemm_mqa_indexer( - seq_len, - query_seq_len, - q, - kv, - weights, - block_size, - topk, - is_kv_batched=True, -): - B = seq_len.shape[0] - concat_q = [] - concat_kv = [] - 
concat_weights = [] - - for i in range(B): - S = seq_len[i] - q_s = q[i][:S].contiguous() - if is_kv_batched: - kv_s = kv[i][:S].contiguous() - weight_s = weights[i][:S].contiguous() - concat_q.append(q_s) - if is_kv_batched: - concat_kv.append(kv_s) - concat_weights.append(weight_s) - - q = torch.cat(concat_q, dim=0) - if is_kv_batched: - kv = torch.cat(concat_kv, dim=0) - weights = torch.cat(concat_weights, dim=0) - q_fp8, q_scale = act_quant(q, block_size, "ue8m0") - kv_fp8, kv_scale = act_quant(kv, block_size, "ue8m0") - - weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale - weights = weights.squeeze(-1) - query_start_loc = torch.empty((B + 1), device="cuda") - query_start_loc[0] = 0 - query_start_loc[1:] = query_seq_len.cumsum(dim=0).to(dtype=torch.int32) - - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, - seq_len) - - logits = fp8_mqa_logits(q_fp8, (kv_fp8, kv_scale), weights, cu_seqlen_ks, - cu_seqlen_ke) - topk_indices = logits.topk(topk, dim=-1)[1] - mask_lo = topk_indices >= cu_seqlen_ks[:, None] - mask_hi = topk_indices < cu_seqlen_ke[:, None] - mask = mask_lo & mask_hi - topk_indices = topk_indices.masked_fill(~mask, -1) - return logits - - -def test_prefill_indexer(): - B = 3 - S = 128 - SKV = S - H = 64 - # HKV = 1 - D = 128 - block_size = 128 - topk = 64 - device = "cuda" - seq_len = torch.randint(low=64, high=S, size=(B, )) - - q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) - kv = torch.randn(B, SKV, D, device="cuda", dtype=torch.bfloat16) - weights = torch.randn(B, S, H, device=device, - dtype=torch.float32) * H**-0.5 - - ref_logits = ref_indexer(seq_len, q, kv, weights, block_size, topk) - deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, - block_size, topk) - torch.testing.assert_close(ref_logits, deepgemm_logits) - - -def test_decode_paged_indexer(): - num_blocks, blocksize = 111 * 3000, 64 - B = 3 - S = 128 - # SKV = S - H = 64 - # HKV = 1 - D = 128 - block_size = 128 - topk = 64 - device = "cuda" - seq_len = torch.randint(low=64, high=S, size=(B, ), device="cuda") - - query_seq_len = torch.ones(B, device="cuda") - - q = torch.randn((B, 1, H, D), device='cuda', dtype=torch.bfloat16) - kv_cache = torch.randn((num_blocks, blocksize, 1, D), - device='cuda', - dtype=torch.bfloat16) - weights = torch.randn( - (B * 1, H), device='cuda', dtype=torch.float32) * H**-0.5 - max_block_len = (seq_len.max().item() + blocksize - - 1) // blocksize * blocksize - - block_tables = torch.zeros((B, max_block_len), - device='cuda', - dtype=torch.int32) - - counter = 0 - block_idx_pool = list(range(num_blocks)) - random.shuffle(block_idx_pool) - for i in range(B): - ctx_len = seq_len[i].item() - for j in range(cdiv(ctx_len, blocksize)): - block_tables[i][j] = block_idx_pool[counter] - counter += 1 - - flatten_kv = torch.empty([seq_len.sum(), D], - device="cuda", - dtype=torch.bfloat16) - cu_seq_lens = torch.cat([ - torch.zeros(1, dtype=torch.int32, device=device), - seq_len.cumsum(dim=0) - ]).to(torch.int32).cuda() - - ops.cp_gather_cache( - kv_cache, - flatten_kv, - block_tables, - cu_seq_lens, - B, - ) - - ref_logits = deepgemm_mqa_indexer(seq_len, - query_seq_len, - q, - flatten_kv, - weights, - block_size, - topk, - is_kv_batched=False) - - q_fp8, q_scale = act_quant(q, block_size, "ue8m0") - kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - - schedule_metadata = get_paged_mqa_logits_metadata(seq_len.int(), blocksize, - get_num_sms()) - - weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale.squeeze(1) - weights = 
weights.squeeze(-1) - - logits = fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, seq_len.int(), - block_tables, schedule_metadata, 4096) - - concat_logit = [] - context = 0 - for i in range(B): - per_seq_logits = torch.zeros(4096, device="cuda") - S = seq_len[i] - per_seq_logits[:S] = ref_logits[i][context:context + S] - concat_logit.append(per_seq_logits) - context += S - ref_logits = torch.stack(concat_logit, dim=0) - logits[logits == float("-inf")] = 0 - diff = calc_diff(logits, ref_logits) - assert diff < 1e-3, f"{diff=}" - - -if __name__ == "__main__": - test_prefill_indexer() - test_decode_paged_indexer() From d9693e8c4bddf31b2020924ada8e10d91fd48a05 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 30 Sep 2025 14:53:11 +0800 Subject: [PATCH 78/82] rm files Signed-off-by: youkaichao --- vllm/utils/tilelang_kernels.py | 497 --------------------------------- 1 file changed, 497 deletions(-) delete mode 100644 vllm/utils/tilelang_kernels.py diff --git a/vllm/utils/tilelang_kernels.py b/vllm/utils/tilelang_kernels.py deleted file mode 100644 index 257f95891461..000000000000 --- a/vllm/utils/tilelang_kernels.py +++ /dev/null @@ -1,497 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - -import tilelang -import tilelang.language as T -import torch - -tilelang.set_log_level("WARNING") - -pass_configs = { - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, -} - -FP8 = "float8_e4m3" -BF16 = "bfloat16" -FP32 = "float32" - - -def fast_log2_ceil(x): - bits_x = T.reinterpret("uint32", x) - exp_x = (bits_x >> 23) & 0xFF - man_bits = bits_x & ((1 << 23) - 1) - return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) - - -def fast_pow2(x): - bits_x = (x + 127) << 23 - return T.reinterpret("float32", bits_x) - - -def fast_round_scale(amax, fp8_max_inv): - return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) - - -@tilelang.jit(pass_configs=pass_configs) -def act_quant_kernel(N, - in_dtype=BF16, - out_dtype=FP8, - scale_dtype=FP32, - round_scale=False): - M = T.symbolic("M") - fp8_min = -448.0 - fp8_max = 448.0 - fp8_max_inv = 1 / fp8_max - num_stages = 0 if round_scale else 2 - blk_m = 32 - group_size = 128 - - @T.prim_func - def act_quant_kernel_( - X: T.Tensor[(M, N), in_dtype], - Y: T.Tensor[(M, N), out_dtype], - S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype], - ): - with T.Kernel(T.ceildiv(M, blk_m), - T.ceildiv(N, group_size), - threads=128) as ( - pid_m, - pid_n, - ): - x_shared = T.alloc_shared((blk_m, group_size), in_dtype) - x_local = T.alloc_fragment((blk_m, group_size), in_dtype) - amax_local = T.alloc_fragment((blk_m, ), scale_dtype) - s_local = T.alloc_fragment((blk_m, ), scale_dtype) - y_local = T.alloc_fragment((blk_m, group_size), out_dtype) - y_shared = T.alloc_shared((blk_m, group_size), out_dtype) - - for _ in T.Pipelined(1, num_stages=num_stages): - T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) - T.copy(x_shared, x_local) - T.reduce_absmax(x_local, amax_local, dim=1) - for i in T.Parallel(blk_m): - amax_local[i] = T.max(amax_local[i], 1e-4) - if round_scale: - s_local[i] = fast_round_scale(amax_local[i], - fp8_max_inv) - else: - s_local[i] = amax_local[i] * fp8_max_inv - for i, j in T.Parallel(blk_m, group_size): - y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], - 
fp8_min, fp8_max) - for i in T.Parallel(blk_m): - S[pid_m * blk_m + i, pid_n] = s_local[i] - T.copy(y_local, y_shared) - T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) - - return act_quant_kernel_ - - -def act_quant( - x: torch.Tensor, - block_size: int = 128, - scale_fmt: Optional[str] = None) -> tuple[torch.Tensor, torch.Tensor]: - """ - Quantizes the input tensor `x` using block-wise quantization. - - Args: - - x (torch.Tensor): - The input tensor to be quantized. - Must be contiguous and its last dimension size - must be divisible by `block_size`. - - block_size (int, optional): - The size of the blocks to be used for quantization. Default is 128. - - scale_fmt (Optional[str], optional): - The format of the scale. Default is None. - Returns: - - tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The quantized tensor with dtype `torch.float8_e4m3fn`. - - A tensor of scaling factors with dtype `torch.float32`. - """ - assert x.is_contiguous(), "Input tensor must be contiguous" - assert x.size(-1) % block_size == 0, ( - f"Last dimension must be divisible by block_size={block_size}") - N = x.size(-1) - y = torch.empty_like(x, dtype=torch.float8_e4m3fn) - s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) - kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) - kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) - return y, s - - -@tilelang.jit(pass_configs=pass_configs) -def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): - assert out_dtype in [BF16, "float32"] - - M = T.symbolic("M") - group_size = 128 - block_M = 32 - block_N = 128 - block_K = 128 - - @T.prim_func - def fp8_gemm_kernel_( - A: T.Tensor[(M, K), FP8], - B: T.Tensor[(N, K), FP8], - C: T.Tensor[(M, N), out_dtype], - scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32], - scales_b: T.Tensor[(T.ceildiv(N, group_size), - T.ceildiv(K, group_size)), FP32], - ): - with T.Kernel(T.ceildiv(N, block_N), - T.ceildiv(M, block_M), - threads=128) as ( - bx, - by, - ): - A_shared = T.alloc_shared((block_M, block_K), FP8) - B_shared = T.alloc_shared((block_N, block_K), FP8) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - Scale_C_shared = T.alloc_shared((block_M), FP32) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Improve L2 Cache - T.use_swizzle(panel_size=10) - - T.clear(C_local) - T.clear(C_local_accum) - K_iters = T.ceildiv(K, block_K) - for k in T.Pipelined(K_iters, num_stages=4): - # Load A into shared memory - T.copy(A[by * block_M, k * block_K], A_shared) - # Load B into shared memory - T.copy(B[bx * block_N, k * block_K], B_shared) - # Load scale into shared memory - Scale_B = scales_b[bx * block_N // group_size, k] - for i in T.Parallel(block_M): - Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B - - T.gemm(A_shared, B_shared, C_local, transpose_B=True) - # Promote to enable 2xAcc - for i, j in T.Parallel(block_M, block_N): - C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] - T.clear(C_local) - # TMA store - T.copy(C_local_accum, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return fp8_gemm_kernel_ - - -def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, - b_s: torch.Tensor) -> torch.Tensor: - """ - Perform a matrix multiplication using FP8 precision. - - Args: - - a (torch.Tensor): - The first input matrix, must be contiguous. 
- - a_s (torch.Tensor): - The scaling factor for the first input matrix, must be contiguous. - - b (torch.Tensor): - The second input matrix, must be contiguous. - - b_s (torch.Tensor): - The scaling factor for the second input matrix, must be contiguous. - - Returns: - torch.Tensor: The result of the matrix multiplication. - """ - assert a.is_contiguous() and b.is_contiguous( - ), "Input tensors must be contiguous" - assert a_s.is_contiguous() and b_s.is_contiguous(), ( - "Scaling factor tensors must be contiguous") - K = a.size(-1) - M = a.numel() // K - N = b.size(0) - c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) - kernel = fp8_gemm_kernel(N, K) - kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) - return c - - -@tilelang.jit(out_idx=[4], pass_configs=pass_configs) -def fp8_index_kernel(h: int, d: int): - b = T.symbolic("b") - m = T.symbolic("m") - n = T.symbolic("n") - - blk_n1 = 512 - blk_n2 = 128 - - @T.prim_func - def fp8_index_kernel_( - q: T.Tensor[(b, m, h, d), FP8], - q_s: T.Tensor[(b, m, h), FP32], - k: T.Tensor[(b, n, d), FP8], - k_s: T.Tensor[(b, n), FP32], - o: T.Tensor[(b, m, n), FP32], - ) -> None: - with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): - q_smem = T.alloc_shared((h, d), FP8) - T.copy(q[i_b, i_m, 0, 0], q_smem) - - q_s_frag = T.alloc_fragment(h, FP32) - T.copy(q_s[i_b, i_m, 0], q_s_frag) - - for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): - k_smem = T.alloc_shared((blk_n2, d), FP8) - T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) - - k_s_frag = T.alloc_fragment(blk_n2, FP32) - T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) - - logits = T.alloc_fragment((blk_n2, h), FP32) - T.gemm( - k_smem, - q_smem, - logits, - transpose_A=False, - transpose_B=True, - clear_accum=True, - ) - - for i_h, i3_n in T.Parallel(h, blk_n2): - logits[i3_n, - i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] - - logits_sum = T.alloc_fragment(blk_n2, FP32) - T.reduce_sum(logits, logits_sum, dim=1) - - for i3_n in T.Parallel(blk_n2): - logits_sum[i3_n] *= k_s_frag[i3_n] - - T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) - - return fp8_index_kernel_ - - -def fp8_index( - q: torch.Tensor, - q_s: torch.Tensor, - k: torch.Tensor, - k_s: torch.Tensor, -) -> torch.Tensor: - """ - Perform index score using FP8 precision. - - Args: - - q (torch.Tensor): - The Q tensor, must be contiguous. - - q_s (torch.Tensor): - The scaling factor for Q (float), must be contiguous. - - k (torch.Tensor): - The K tensor, must be contiguous. - - k_s (torch.Tensor): - The scaling factor for K (e8m0 here), must be contiguous. 
- - fp8 q @ fp8 k -> fp32 logits - relu(fp32 logits) * q_s (weights) -> fp32 logits - fp32 logits -> fp32 logits_sum - fp32 logits_sum * k_s (e8m0) -> fp32 index_score - """ - return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) - - -def convert_to_uint16(x): - hval = T.Cast("float16", x) - bits_uint = T.reinterpret("uint16", hval) - bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), - bits_uint | (0x8000)) - return bits_uint >> 8 - - -def convert_to_uint32(x): - bits_uint = T.reinterpret("uint32", x) - bits_uint = T.if_then_else( - x < 0, - ~bits_uint & T.Cast("uint32", (0xFFFFFFFF)), - bits_uint | T.Cast("uint32", (0x80000000)), - ) - return bits_uint - - -@tilelang.jit(pass_configs=pass_configs) -def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): - batch = T.symbolic("batch") - seq_len = T.symbolic("seq_len") - RADIX = 1 << 8 - BLOCK_SIZE = 1024 - # assume the threshold bucket size after first pass is less than 4K - SMEM_INPUT_SIZE = 4096 - - @T.prim_func - def tl_topk_kernel( - input: T.Tensor[(batch, seq_len), in_dtype], - index: T.Tensor[(batch, topk), out_dtype], - starts: T.Tensor[(batch), "int32"], # noqa: F821 - ends: T.Tensor[(batch), "int32"], # noqa: F821 - ): - with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): - tx = T.get_thread_binding() - - s_threshold_bin_id = T.alloc_shared([1], "int32") - s_histogram = T.alloc_shared([RADIX + 1], "int32") - s_num_input = T.alloc_shared([2], "int32") - s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32") - - l_threshold_bin_id = T.alloc_var("int32") - l_new_topk = T.alloc_var("int32") - l_num_input = T.alloc_var("int32") - l_bin_id32 = T.alloc_var("int32") - l_val = T.alloc_var("int32") - l_start_pos = T.alloc_var("int32") - l_start_idx = T.alloc_var("int32") - l_end_idx = T.alloc_var("int32") - l_out_pos = T.alloc_var("int32") - - l_new_topk = topk - l_start_idx = starts[bx] - l_end_idx = ends[bx] - - # stage 1: use 8bit to do quick topk - T.fill(s_histogram, 0) - T.fill(s_num_input[0], 0) - - T.sync_threads() - for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): - input_idx = s * BLOCK_SIZE + tx - if (input_idx < l_end_idx and input_idx >= l_start_idx - and input_idx < seq_len): - inval_int16 = convert_to_uint16(input[bx, input_idx]) - T.atomic_add(s_histogram[inval_int16], 1) - T.sync_threads() - - # cumsum - if tx < RADIX: - for i in T.serial(8): - offset = 1 << i - T.sync_threads(3, RADIX) - if tx < RADIX - offset: - l_val = s_histogram[tx] + s_histogram[tx + offset] - T.sync_threads(3, RADIX) - if tx < RADIX - offset: - s_histogram[tx] = l_val - - # find threshold bin id - T.sync_threads(3, RADIX) - if s_histogram[tx] > l_new_topk and s_histogram[ - tx + 1] <= l_new_topk: - s_threshold_bin_id[0] = tx - T.sync_threads() - l_threshold_bin_id = s_threshold_bin_id[0] - l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] - T.sync_threads() - - # collect all elements with exponent ≥ threshold - for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)): - T.sync_threads() - input_idx = s * BLOCK_SIZE + tx - if (input_idx < l_end_idx and input_idx >= l_start_idx - and input_idx < seq_len): - bin_id = convert_to_uint16(input[bx, input_idx]) - l_bin_id32 = T.Cast("int32", bin_id) - if l_bin_id32 > l_threshold_bin_id: - # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) - pos = T.atomic_add(s_histogram[l_bin_id32 + 1], - 1, - return_prev=True) - index[bx, pos] = input_idx - - elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0: - # pos = s_num_input[0] - pos = T.atomic_add(s_num_input[0], 1, 
return_prev=True) - s_input_idx[0, pos] = input_idx - - # stage 2: tail pass - for round in T.serial(4): - if l_new_topk <= 0: - T.loop_break() - - r_idx = round % 2 - l_start_pos = topk - l_new_topk - - T.sync_threads() - T.fill(s_histogram, 0) - if tx == 0: - s_num_input[r_idx ^ 1] = 0 - T.sync_threads() - - l_num_input = s_num_input[r_idx] - for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): - if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast( - "int32", ((convert_to_uint32(input[bx, s_input_idx[ - r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) - T.atomic_add(s_histogram[l_bin_id32], 1) - T.sync_threads() - # cumsum - if tx < RADIX: - for i in T.serial(8): - offset = 1 << i - T.sync_threads(3, RADIX) - if tx < RADIX - offset: - l_val = s_histogram[tx] + s_histogram[tx + offset] - T.sync_threads(3, RADIX) - if tx < RADIX - offset: - s_histogram[tx] = l_val - - # find threshold bin id - T.sync_threads(3, RADIX) - if s_histogram[tx] > l_new_topk and s_histogram[ - tx + 1] <= l_new_topk: - s_threshold_bin_id[0] = tx - T.sync_threads() - - l_threshold_bin_id = s_threshold_bin_id[0] - l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1] - T.sync_threads() - - for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): - T.sync_threads() - if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast( - "int32", ((convert_to_uint32(input[bx, s_input_idx[ - r_idx, s * BLOCK_SIZE + tx]]) >> - (24 - round * 8)) & 0xFF)) - if l_bin_id32 > l_threshold_bin_id: - pos = T.atomic_add(s_histogram[l_bin_id32 + 1], - 1, - return_prev=True) + l_start_pos - index[bx, pos] = s_input_idx[r_idx, - s * BLOCK_SIZE + tx] - elif (l_bin_id32 == l_threshold_bin_id - and l_new_topk > 0): - if round == 3: - l_out_pos = T.atomic_add( - s_histogram[l_bin_id32 + 1], - 1, - return_prev=True) + l_start_pos - if l_out_pos < topk: - index[bx, - l_out_pos] = s_input_idx[r_idx, s * - BLOCK_SIZE + - tx] - else: - pos = T.atomic_add(s_num_input[r_idx ^ 1], - 1, - return_prev=True) - s_input_idx[r_idx ^ 1, - pos] = s_input_idx[r_idx, - s * BLOCK_SIZE + - tx] - - return tl_topk_kernel - - -def tl_topk(input, starts, ends, topk): - batch, seq_len = input.shape - indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device) - kernel = tl_topk_impl(topk) - kernel(input, indexes, starts, ends) - return indexes From 39d9d0e2313ded59e9750e8815cfef69c81cd9a1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 30 Sep 2025 15:25:54 +0800 Subject: [PATCH 79/82] fix spacing Signed-off-by: youkaichao --- vllm/utils/deep_gemm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 2b5402bc5f6b..0e3bdaec829e 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -164,7 +164,7 @@ def fp8_mqa_logits( Args: q: Query tensor of shape [M, H, D]. Casted to - `torch.float8_e4m3fn` by caller. + `torch.float8_e4m3fn` by caller. kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or [N, 1]) with dtype `torch.float32`. @@ -219,13 +219,13 @@ def fp8_paged_mqa_logits( q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to `torch.float8_e4m3fn` by caller. kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape - [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last - 4 bytes per (block,pos) store the `float` dequant scale. + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. 
The last + 4 bytes per (block,pos) store the `float` dequant scale. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. context_lens: Tensor of shape [B], dtype int32; effective context length - for each batch element. + for each batch element. block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical - block indices to physical blocks in the paged cache. + block indices to physical blocks in the paged cache. schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; used to distribute work across SMs. max_model_len: Maximum sequence length used to size the logits output. From c80dfd504a3c6a4db3ffd913855029c832523e92 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 30 Sep 2025 15:30:23 +0800 Subject: [PATCH 80/82] add type for return value Signed-off-by: youkaichao --- vllm/v1/attention/backends/mla/indexer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 59e2958880ee..4e6b974ad74d 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -97,8 +97,9 @@ class DeepseekV32IndexerMetadata: # TODO (zyongye) optimize this, this is now vibe coded -def kv_spans_from_batches(start_seq_loc: torch.Tensor, - seq_len_per_batch: torch.Tensor): +def kv_spans_from_batches( + start_seq_loc: torch.Tensor, + seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Args: start_seq_loc: 1D long tensor [B+1], cumulative counts of From 07be34b65336a7d07beb6e8705a16d117e28935b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 30 Sep 2025 15:31:45 +0800 Subject: [PATCH 81/82] add type for return value Signed-off-by: youkaichao --- vllm/attention/ops/common.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index eb6d11c141c7..e659f1f3eae9 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -189,11 +189,11 @@ def _pack_seq_kernel( tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask) -def pack_seq_triton(x, - lengths, - pad_value=-float('inf'), - block_t=64, - block_d=64): +def pack_seq_triton(x: torch.Tensor, + lengths: torch.Tensor, + pad_value: float = -float('inf'), + block_t: int = 64, + block_d: int = 64) -> torch.Tensor: """ Pack sequences of different lengths into a batched tensor. @@ -290,7 +290,10 @@ def _unpack_seq_triton_kernel( tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask) -def unpack_seq_triton(packed_tensor, lengths, block_t=64, block_d=64): +def unpack_seq_triton(packed_tensor: torch.Tensor, + lengths: torch.Tensor, + block_t: int = 64, + block_d: int = 64) -> torch.Tensor: """ Unpack a packed decode query tensor back to the original format. Efficient Triton implementation. 
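The two helpers annotated in PATCH 81 are easiest to read with a concrete call. The sketch below is illustrative only: the concatenated-rows input layout, the int32 dtype for lengths, and the exact pack/unpack round trip are assumptions drawn from the docstrings and the pad_value default visible in the diff, not from the Triton kernel bodies.

import torch

from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton

# Three sequences of different lengths, assumed concatenated along dim 0.
lengths = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
x = torch.randn(int(lengths.sum()), 8, 64, device="cuda", dtype=torch.bfloat16)

# Assumed result: [batch=3, max_len=5, 8, 64], with positions past each
# sequence's length filled with pad_value (default -inf per the signature).
packed = pack_seq_triton(x, lengths)

# Assumed to restore the original concatenated layout; if the round trip
# is exact (assumed), the check below passes.
unpacked = unpack_seq_triton(packed, lengths)
torch.testing.assert_close(unpacked, x)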
From a0264c7567bfb553a0d0217d7fa63ef60c81cf71 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 30 Sep 2025 15:36:29 +0800 Subject: [PATCH 82/82] fix for amd Signed-off-by: youkaichao --- csrc/cache_kernels.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b014fd27a8d6..b1c43163c6a5 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -464,7 +464,11 @@ __global__ void concat_and_cache_ds_mla_kernel( float max_abs = fabsf(src_val); #pragma unroll for (int offset = 16; offset > 0; offset /= 2) { +#ifdef USE_ROCM + max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset)); +#else max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset)); +#endif } // The first lane of each warp in each tile writes the max_abs of this part @@ -536,7 +540,11 @@ __global__ void indexer_k_quant_and_cache_kernel( // Reduced amax for (int mask = 16; mask > 0; mask /= 2) { +#ifdef USE_ROCM + amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask)); +#else amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask)); +#endif } __syncwarp(); float scale = fmaxf(amax, 1e-4) / 448.0f;
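Both kernels touched in PATCH 82 finish their warp reduction by turning the reduced amax into an FP8 scale via fmaxf(amax, 1e-4) / 448.0f. As a reading aid, here is a minimal torch sketch of that per-group scale computation; the group size of 128 and the tensor layout are illustrative assumptions, and only the 1e-4 clamp and the 448.0 e4m3 maximum come from the kernels above.

import torch

def ref_fp8_group_scale(x: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    # One scale per quantization group along the last dimension:
    # clamp(max(|x|), 1e-4) / 448.0, matching the amax -> scale step
    # that the warp shuffles above compute in-register.
    assert x.shape[-1] % group_size == 0
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    amax = groups.abs().amax(dim=-1).clamp_min(1e-4)
    return amax / 448.0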