diff --git a/vllm/envs.py b/vllm/envs.py
index 8f0fd84f32f8..24268f8caefe 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -110,7 +110,10 @@
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
-
+    VLLM_AITER_TRITON_FP8_BMM: bool = False
+    VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS: bool = False
+    VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT: bool = False
+    VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT: bool = False
 
 def get_default_cache_root():
     return os.getenv(
@@ -728,6 +731,19 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # limit will actually be zero-copy decoded.
     "VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
     lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
+
+    "VLLM_AITER_TRITON_FP8_BMM":
+    lambda: bool(int(os.getenv("VLLM_AITER_TRITON_FP8_BMM", "0"))),
+
+    "VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS":
+    lambda: bool(int(os.getenv("VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS", "0"))),
+
+    "VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT":
+    lambda: bool(int(os.getenv("VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT", "0"))),
+
+    "VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT":
+    lambda: bool(int(os.getenv("VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT", "0"))),
+
 }
 
 # end-env-vars-definition
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index d3f02ee70b9a..0ff76a583a2a 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -220,6 +220,37 @@
     from vllm.v1.worker.gpu_input_batch import InputBatch
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
+if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT:
+    from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat_and_cache_mla
+
+if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT:
+    from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat_and_cache_mla_q_per_token_fp8_quant
+    import aiter as rocm_aiter
+    rocm_aiter_fp8 = rocm_aiter.dtypes.fp8
+
+if envs.VLLM_AITER_TRITON_FP8_BMM:
+    def dynamic_per_batched_tensor_quant(
+        x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+    ):
+        DTYPE_MAX = torch.finfo(dtype).max
+        min_val, max_val = x.aminmax()
+        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
+        scale = DTYPE_MAX / amax
+        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
+        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
+    # @torch.compiler.disable
+    def aiter_triton_fp8_bmm_wrapper(x, w, w_s, group_size = 128, y = None, transpose_bm = False):
+        if y is not None:
+            batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(x, w, w_s, group_size = group_size, YQ=y, transpose_bm=transpose_bm)
+        else:
+            y = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(x, w, w_s, group_size = group_size, transpose_bm = transpose_bm)
+        return y
+
+if envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
+    from aiter.ops.triton.fused_concat_zeros import fused_concat_zeros
+
 logger = init_logger(__name__)
 
@@ -636,10 +667,14 @@ def __init__(
 
         if self.use_rocm_aiter:
             self.rotary_emb = rotary_emb.forward_hip
+            self.cos_cache, self.sin_cache = rotary_emb.cos_cache, rotary_emb.sin_cache
+            self.rotary_emb_is_neox_style = rotary_emb.is_neox_style
         else:
             self.rotary_emb = rotary_emb.forward_native
             if current_platform.is_cuda():
                 self.rotary_emb = rotary_emb.forward_cuda
+            self.cos_cache, self.sin_cache = rotary_emb.cos_sin_cache.chunk(2, dim = -1)
+            self.rotary_emb_is_neox_style = rotary_emb.is_neox_style
 
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
@@ -703,10 +738,17 @@ def _flash_attn_varlen_diff_headdims(self,
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-        x = torch.bmm(x, self.W_UV)
-        # Convert from (N, B, V) to (B, N * V)
-        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        if envs.VLLM_AITER_TRITON_FP8_BMM:
+            # Multiply + Transpose (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N, V)
+            # print(f"{x.dtype=}")
+            x = aiter_triton_fp8_bmm_wrapper(x, self.W_V, self.W_V_scale, group_size = 128, transpose_bm = True)
+            # Convert from (B, N, V) to (B, N * V)
+            x = x.reshape(-1, self.num_heads * self.v_head_dim)
+        else:
+            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+            x = torch.bmm(x, self.W_UV)
+            # Convert from (N, B, V) to (B, N * V)
+            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
         return self.o_proj(x)[0]
 
     # Return `ql_nope`, `q_pe`
@@ -717,10 +759,15 @@ def _q_proj_and_k_up_proj(self, x):
 
         # Convert from (B, N, P) to (N, B, P)
         q_nope = q_nope.transpose(0, 1)
-        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-        ql_nope = torch.bmm(q_nope, self.W_UK_T)
-        # Convert from (N, B, L) to (B, N, L)
-        return ql_nope.transpose(0, 1), q_pe
+        if envs.VLLM_AITER_TRITON_FP8_BMM:
+            # Multiply + Transpose (N, B, P) x (N, P, L) -> (N, B, L) -> (B, N, L)
+            ql_nope = aiter_triton_fp8_bmm_wrapper(q_nope, self.W_K, self.W_K_scale, group_size = 128, transpose_bm = True)
+        else:
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            ql_nope = torch.bmm(q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            ql_nope = ql_nope.transpose(0, 1)
+        return ql_nope, q_pe
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
@@ -751,6 +798,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -767,11 +815,89 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
 
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-
-        # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1)
-        # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0)
+
+        if (envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT or envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT):
+            kv_cache_size = 8192
+            max_position_embedding = self.cos_cache.shape[0]
+            for prefill_decode_size in [1, 256, 2048]:
+                for decode_batch_size in [0, 1, 256]:
+                    if decode_batch_size > prefill_decode_size:
+                        continue
+
+                    k_scale = torch.ones([1,], dtype=torch.float32, device=W_UK.device)[0]
+
+                    q = torch.empty((decode_batch_size, self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
+                    decode_ql_nope = q[..., :self.kv_lora_rank]
+                    decode_q_pe = q[..., self.kv_lora_rank:]
+
+                    k = torch.empty((prefill_decode_size, 1, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
+                    k_c_normed = k[..., :self.kv_lora_rank].squeeze(1)
+                    k_pe = k[..., self.kv_lora_rank:]
+
+                    input_positions = torch.randint(0, max_position_embedding, (decode_batch_size, ), device=W_UK.device)
+                    slot_mapping = torch.randperm(kv_cache_size, device=W_UK.device)[:prefill_decode_size]
+                    kv_cache = torch.empty((kv_cache_size, 1, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
+
+                    if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT:
+                        logger.info(f"[Triton] compiling fused_qk_rope_cat_and_cache_mla_q_per_token_fp8_quant with (decode tokens, total tokens) = ({decode_batch_size}, {prefill_decode_size})")
+                        fused_qk_rope_cat_and_cache_mla_q_per_token_fp8_quant(
+                            decode_ql_nope,
+                            decode_q_pe,
+                            k_c_normed.unsqueeze(1),
+                            k_pe,
+                            kv_cache,
+                            slot_mapping,
+                            input_positions,
+                            self.cos_cache,
+                            self.sin_cache,
+                            k_scale,
+                            self.rotary_emb_is_neox_style,
+                            dtype_quant=rocm_aiter_fp8
+                        )
+                    if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT:
+                        logger.info(f"[Triton] compiling fused_qk_rope_cat_and_cache_mla with (decode tokens, total tokens) = ({decode_batch_size}, {prefill_decode_size})")
+                        fused_qk_rope_cat_and_cache_mla(
+                            decode_ql_nope,
+                            decode_q_pe,
+                            k_c_normed.unsqueeze(1),
+                            k_pe,
+                            kv_cache,
+                            slot_mapping,
+                            input_positions,
+                            self.cos_cache,
+                            self.sin_cache,
+                            k_scale,
+                            self.rotary_emb_is_neox_style,
+                        )
+
+        if envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
+            max_batch_size = 256
+            logger.info(f"[Triton] compiling fused_concat_zeros with shape = [1~{max_batch_size}] {self.num_heads} [{self.kv_lora_rank} : {self.qk_rope_head_dim}]")
+            for m in range(1, max_batch_size+1):
+                x1 = torch.empty((m, self.num_heads, self.kv_lora_rank), dtype=torch.bfloat16, device=W_UK.device)
+                x2 = torch.empty((m, self.num_heads, self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
+                fused_concat_zeros(x1, x2)
+
+        if envs.VLLM_AITER_TRITON_FP8_BMM:
+            max_batch_size = 256
+            W_K = W_UK.transpose(0, 1)  # 16 512 128
+            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
+            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(W_K, dtype=torch.float8_e4m3fnuz)
+            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(W_V, dtype=torch.float8_e4m3fnuz)
+            logger.info(f"[Triton] compiling fp8 BMM with shape = {self.W_K.shape[0]} [1~{max_batch_size}] {self.W_K.shape[1]} {self.W_K.shape[2]}")
+            logger.info(f"[Triton] compiling fp8 BMM with shape = {self.W_V.shape[0]} [1~{max_batch_size}] {self.W_V.shape[1]} {self.W_V.shape[2]}")
+            for m in range(1, max_batch_size+1):
+                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), dtype=torch.bfloat16, device=self.W_K.device)
+                aiter_triton_fp8_bmm_wrapper(x, self.W_K, self.W_K_scale, group_size = 128, transpose_bm = True)
+
+                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), dtype=torch.bfloat16, device=self.W_V.device)
+                aiter_triton_fp8_bmm_wrapper(x, self.W_V, self.W_V_scale, group_size = 128, transpose_bm = True)
+
+        else:
+            # Convert from (L, N, V) to (N, L, V)
+            self.W_UV = W_UV.transpose(0, 1)
+            # Convert from (L, N, P) to (N, P, L)
+            self.W_UK_T = W_UK.permute(1, 2, 0)
 
     def _compute_prefill_context(
         self,
@@ -951,7 +1077,10 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
 
-            if self.use_rocm_aiter:
+            if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT or envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT:
+                # the decode rope is now fused with the concat_and_cache_mla operator via fused_qk_rope_cat_and_cache_mla
+                pass
+            elif self.use_rocm_aiter:
                 self.rotary_emb(attn_metadata.decode.input_positions,
                                 decode_q_pe, decode_k_pe)
             else:
@@ -974,7 +1103,38 @@ def forward(
                             prefill_q_pe.contiguous(), prefill_k_pe)
 
         # write the latent and rope to kv cache
-        if kv_cache.numel() > 0:
+        q_nope_pe, q_scale = None, None
+        if (envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT or envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT) and has_decode and kv_cache.numel() > 0:
+            if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT:
+                q_nope_pe, q_scale = fused_qk_rope_cat_and_cache_mla_q_per_token_fp8_quant(
+                    decode_ql_nope,
+                    decode_q_pe,
+                    k_c_normed.unsqueeze(1),
+                    k_pe,
+                    kv_cache,
+                    attn_metadata.slot_mapping.flatten(),
+                    attn_metadata.decode.input_positions,
+                    self.cos_cache,
+                    self.sin_cache,
+                    layer._k_scale,
+                    self.rotary_emb_is_neox_style,
+                    dtype_quant=rocm_aiter_fp8
+                )
+            else:
+                q_nope_pe = fused_qk_rope_cat_and_cache_mla(
+                    decode_ql_nope,
+                    decode_q_pe,
+                    k_c_normed.unsqueeze(1),
+                    k_pe,
+                    kv_cache,
+                    attn_metadata.slot_mapping.flatten(),
+                    attn_metadata.decode.input_positions,
+                    self.cos_cache,
+                    self.sin_cache,
+                    layer._k_scale,
+                    self.rotary_emb_is_neox_style,
+                )
+        elif kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
                 k_c_normed,
                 k_pe.squeeze(1),
@@ -991,6 +1151,6 @@ def forward(
 
         if has_decode:
             output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
+                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, q_nope_pe=q_nope_pe, q_scale=q_scale)
 
         return output_padded
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 153442b91bf8..780427633215 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -15,6 +15,8 @@
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
 
+if envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
+    from aiter.ops.triton.fused_concat_zeros import fused_concat_zeros
 # yapf: enable
 
@@ -180,18 +182,43 @@ def _forward_decode(
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AiterMLAMetadata,
+        q_nope_pe: torch.Tensor = None,
+        q_scale: torch.Tensor = None,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        B = q_nope.shape[0]
-
-        q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.zeros(B,
-                        self.num_heads,
-                        self.kv_lora_rank,
-                        dtype=q.dtype,
-                        device=q.device)
+        if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT_QUANT and q_nope_pe is not None and q_scale is not None:
+            # q_nope_pe.dtype == torch.float8_e4m3fnuz
+            # q_scale.dtype == torch.float32
+            # upcast back to bf16 for the current implementation; this section can be removed once aiter_mla_decode_fwd supports fp8 and no longer needs a zero-tensor output
+            q = (q_nope_pe.to(torch.float32) * q_scale).to(q_nope.dtype)
+            B = q_nope.shape[0]
+            o = torch.zeros(B,
+                            self.num_heads,
+                            self.kv_lora_rank,
+                            dtype=q.dtype,
+                            device=q.device)
+        elif envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT and q_nope_pe is not None:
+            # q_nope_pe.dtype == torch.bfloat16
+            q = q_nope_pe
+            B = q_nope.shape[0]
+            o = torch.zeros(B,
+                            self.num_heads,
+                            self.kv_lora_rank,
+                            dtype=q.dtype,
+                            device=q.device)
+        elif envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
+            q, o = fused_concat_zeros(q_nope, q_pe)
+        else:
+            B = q_nope.shape[0]
+
+            q = torch.cat([q_nope, q_pe], dim=-1)
+            o = torch.zeros(B,
+                            self.num_heads,
+                            self.kv_lora_rank,
+                            dtype=q.dtype,
+                            device=q.device)
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
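
The four flags are opt-in: each defaults to "0" and is parsed with bool(int(os.getenv(name, "0"))), and they gate module-level imports in the MLA backend files. A hedged sketch of enabling them from Python before vLLM is imported (exporting the variables in the shell works the same way; the flag names come from the envs.py hunk above):

    import os

    # Each flag defaults to "0" and is read as bool(int(os.getenv(name, "0"))).
    # The flags are consulted when the MLA backend modules are imported, so set
    # them before importing vLLM. Enable only the paths you want, e.g.:
    os.environ.setdefault("VLLM_AITER_TRITON_FP8_BMM", "1")
    os.environ.setdefault("VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS", "1")

    import vllm  # noqa: E402  (imported after the flags are set)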
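For reference, the per-batched-tensor quantization used to prepare W_K and W_V can be exercised without aiter. A minimal torch-only sketch (assumes a PyTorch build with float8 dtypes; shapes are illustrative, not taken from the PR) that mirrors the math of dynamic_per_batched_tensor_quant above and checks the dequantization round trip:

    import torch

    def per_batched_tensor_fp8_quant(x: torch.Tensor,
                                     dtype: torch.dtype = torch.float8_e4m3fn):
        # Same math as dynamic_per_batched_tensor_quant: one symmetric scale for
        # the whole tensor so that max(|x|) lands on the dtype's finite maximum.
        dtype_max = torch.finfo(dtype).max
        amax = x.abs().amax().clamp(min=1e-10)
        scale = dtype_max / amax
        x_q = (x * scale).clamp(min=-dtype_max, max=dtype_max).to(dtype)
        return x_q.contiguous(), scale.float().reciprocal()

    # Quantize a W_K-shaped weight and check the dequantization round trip.
    w = torch.randn(16, 512, 128, dtype=torch.bfloat16)
    w_q, w_scale = per_batched_tensor_fp8_quant(w)
    w_dq = w_q.to(torch.float32) * w_scale
    print((w_dq - w.float()).abs().max())  # small fp8 rounding error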
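The VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS path replaces the torch.cat plus torch.zeros pair in _forward_decode with a single fused_concat_zeros launch. A torch-only sketch of the unfused reference it stands in for, following the else branch above (the zero tensor is the output buffer the decode kernel currently expects; shapes below are illustrative):

    import torch

    def concat_zeros_reference(q_nope: torch.Tensor, q_pe: torch.Tensor):
        # Unfused reference: concatenate the nope/rope query halves and allocate
        # the zero-filled (B, num_heads, kv_lora_rank) output buffer, matching
        # what the else branch builds with torch.cat + torch.zeros.
        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros_like(q_nope)
        return q, o

    # Illustrative decode shapes: 4 tokens, 16 heads, 512 latent + 64 rope dims.
    q_nope = torch.randn(4, 16, 512, dtype=torch.bfloat16)
    q_pe = torch.randn(4, 16, 64, dtype=torch.bfloat16)
    q, o = concat_zeros_reference(q_nope, q_pe)
    print(q.shape, o.shape)  # torch.Size([4, 16, 576]) torch.Size([4, 16, 512])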