diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 02f8c593392c..a5a5b52f6039 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -17,7 +17,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_triton_block_scaled_mm, + w8a8_block_fp8_matmul, ) from vllm.utils import FlexibleArgumentParser, cdiv @@ -158,7 +158,7 @@ def bench_fp8( "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) ), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm( + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) ), "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index 2010b8038563..db2398fc40a4 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -9,7 +9,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, - w8a8_triton_block_scaled_mm, + w8a8_block_fp8_matmul, ) from vllm.triton_utils import triton from vllm.utils.deep_gemm import ( @@ -63,7 +63,7 @@ def deepgemm_gemm(): # === vLLM Triton Implementation === def vllm_triton_gemm(): - return w8a8_triton_block_scaled_mm(A_vllm, + return w8a8_block_fp8_matmul(A_vllm, B_vllm, A_scale_vllm, B_scale_vllm, diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index e02df540ce9d..211d1ecfe6e4 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ native_w8a8_block_matmul) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm) + cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import (fp8_gemm_nt, @@ -91,8 +91,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 3d4c851a9b88..720eee62760d 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -20,11 +20,9 @@ (8, 513, 64), # Non-divisible (native only) ]) @pytest.mark.parametrize("seed", [42]) -@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, - group_size: int, seed: int, - use_ue8m0: bool) -> None: + group_size: int, seed: int) -> None: """Test QuantFP8 group quantization with various configurations. Tests both CUDA and native implementations, column-major scales, @@ -40,8 +38,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False, - use_ue8m0=use_ue8m0) + column_major_scales=False) # 1. Test native implementation (always available) x_quant_native, scales_native = quant_op.forward_native(x.clone()) @@ -51,15 +48,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, # 2. Test column-major scales configuration quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True, - use_ue8m0=use_ue8m0) + column_major_scales=True) _, scales_col = quant_op_col.forward_native(x.clone()) - assert scales_col.shape == (batch_size, expected_num_groups) - assert scales_col.stride(0) == 1 - assert scales_col.stride(1) == batch_size - - # Test column-major scales consistency - assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) + assert scales_col.shape == (expected_num_groups, batch_size) # 3. Test CUDA implementation (only for divisible dimensions) if is_divisible: @@ -77,9 +68,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, @pytest.mark.parametrize("seed", [42]) -@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() -def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: +def test_quantfp8_group_multidimensional(seed: int) -> None: current_platform.seed_everything(seed) group_size = 64 @@ -92,8 +82,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False, - use_ue8m0=use_ue8m0) + column_major_scales=False) x_quant, scales = quant_op.forward_native(x_3d.clone()) assert x_quant.shape == x_3d.shape @@ -102,8 +91,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: # Test column_major_scales with multi-dim quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True, - use_ue8m0=use_ue8m0) + column_major_scales=True) _, scales_col = quant_op_col.forward_native(x_3d.clone()) assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 200b6ecd5852..92ce10a9efc0 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -17,6 +17,8 @@ from vllm.model_executor.layers.layernorm import (RMSNorm, dispatch_rocm_rmsnorm_func, fused_add_rms_norm, rms_norm) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] @@ -109,6 +111,34 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() +@pytest.mark.skipif( + not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), + reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") +@pytest.mark.parametrize("use_cutlass", [True, False]) +@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) +@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) +def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, + use_rocm_aiter_gemm_w8a8_blockscale: str, + monkeypatch): + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", + use_rocm_aiter_gemm_w8a8_blockscale) + + use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( + int(use_rocm_aiter_gemm_w8a8_blockscale))) + block_scale_func = dispatch_w8a8_blockscale_func( + use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) + if use_cutlass: + assert block_scale_func == cutlass_scaled_mm + elif current_platform.is_rocm() and int(use_rocm_aiter) and int( + use_rocm_aiter_gemm_w8a8_blockscale): + assert block_scale_func == ( + torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) + else: + assert block_scale_func == w8a8_block_fp8_matmul + + @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index af8c7ec3b482..c0ab3fbb1062 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -18,9 +18,6 @@ CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -745,35 +742,3 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, perplexity = llm.generate_prompt_perplexity([prompt])[0] print(perplexity) assert perplexity <= exp_perplexity - - -def test_compressed_tensors_fp8_block_enabled(vllm_runner): - model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" - with vllm_runner(model_path) as llm: - - fp8_dtype = current_platform.fp8_dtype() - - def check_model(model): - layer = model.model.layers[0] - - qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) - assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear, - W8A8BlockFp8LinearOp) - - assert qkv_proj.weight.dtype is fp8_dtype - assert qkv_proj.weight_scale.dtype is torch.float32 - assert len(qkv_proj.weight.shape) == 2 - assert len(qkv_proj.weight_scale.shape) == 2 - - input_quant_op = \ - qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op - assert isinstance(input_quant_op, QuantFP8) - assert input_quant_op._forward_method == input_quant_op.forward_cuda - - llm.apply_model(check_model) - - output = llm.generate_greedy("Hello my name is", max_tokens=20) - assert output diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index df6564077e8a..bf2cb325a23d 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -545,23 +545,6 @@ def __post_init__(self): # local attention. self.scheduler_config.disable_hybrid_kv_cache_manager = True - def has_blocked_weights(): - if self.quant_config is not None: - if hasattr(self.quant_config, "weight_block_size"): - return self.quant_config.weight_block_size is not None - elif hasattr(self.quant_config, "has_blocked_weights"): - return self.quant_config.has_blocked_weights() - return False - - # Enable quant_fp8 CUDA ops (TODO disable in follow up) - # On H100 the CUDA kernel is faster than - # native implementation - # https://github.com/vllm-project/vllm/issues/25094 - if has_blocked_weights(): - custom_ops = self.compilation_config.custom_ops - if "none" not in custom_ops and "-quant_fp8" not in custom_ops: - custom_ops.append("+quant_fp8") - def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 3f771ea2abd1..d6550dd16892 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -644,14 +644,6 @@ def get_cache_scale(self, name: str) -> Optional[str]: # If no matches, return None return None - def has_blocked_weights(self) -> bool: - for scheme in self.target_scheme_map.values(): - weight_quant = scheme.get("weights") - if (weight_quant is not None - and weight_quant.strategy == QuantizationStrategy.BLOCK): - return True - return False - @staticmethod def supports_cutlass_24( weight_quant: Optional[QuantizationArgs], diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index fa0816959fcd..d42ae22c5139 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, + apply_fp8_block_linear, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, @@ -41,30 +41,16 @@ def __init__(self, weight_quant: QuantizationArgs, self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme + self.act_q_group_shape = GroupShape.PER_TENSOR \ + if is_static_input_scheme else GroupShape.PER_TOKEN + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape) self.weight_block_size = self.weight_quant.block_structure - if self.weight_block_size is not None: - self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) - else: - self.act_q_group_shape = GroupShape.PER_TENSOR \ - if is_static_input_scheme else GroupShape.PER_TOKEN - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() - if self.weight_block_size is not None: - assert not self.is_static_input_scheme - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(*self.weight_block_size), - act_quant_group_shape=self.act_q_group_shape, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, - ) - else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape) - @classmethod def get_min_capability(cls) -> int: # lovelace and up @@ -155,14 +141,13 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.weight_block_size is not None: - return self.w8a8_block_fp8_linear.apply( + if layer.weight_block_size is not None: + return apply_fp8_block_linear( + layer, input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, bias=bias, - ) + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported) return self.fp8_linear.apply(input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index 8452f686b3ac..2236824ce910 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -42,7 +42,7 @@ def prepare_block_fp8_matmul_inputs( return M, N, K, C -def w8a8_deepgemm_block_scaled_mm( +def w8a8_block_fp8_matmul_deepgemm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -58,7 +58,7 @@ def w8a8_deepgemm_block_scaled_mm( return C -def w8a8_deepgemm_block_scaled_mm_fake( +def w8a8_block_fp8_matmul_deepgemm_fake( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -72,7 +72,7 @@ def w8a8_deepgemm_block_scaled_mm_fake( direct_register_custom_op( - op_name="w8a8_deepgemm_block_scaled_mm", - op_func=w8a8_deepgemm_block_scaled_mm, - fake_impl=w8a8_deepgemm_block_scaled_mm_fake, + op_name="w8a8_block_fp8_matmul_deepgemm", + op_func=w8a8_block_fp8_matmul_deepgemm, + fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f77e5880209d..5fbc1545ea79 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -31,7 +31,7 @@ register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, + apply_fp8_block_linear, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, expert_weight_is_col_major, maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, @@ -236,28 +236,15 @@ def __init__(self, quant_config: Fp8Config): self.weight_block_size = self.quant_config.weight_block_size self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" - if self.weight_block_size: - self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN else: - # Use per-token quantization for better perf if dynamic and cutlass - if not self.act_q_static and cutlass_fp8_supported(): - self.act_q_group_shape = GroupShape.PER_TOKEN - else: - self.act_q_group_shape = GroupShape.PER_TENSOR + self.act_q_group_shape = GroupShape.PER_TENSOR - if self.block_quant: - assert not self.act_q_static - assert self.weight_block_size is not None - self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(*self.weight_block_size), - act_quant_group_shape=self.act_q_group_shape, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, - ) - else: - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape) + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, @@ -406,15 +393,12 @@ def apply(self, bias=bias) if self.block_quant: - assert self.weight_block_size is not None - - return self.w8a8_block_fp8_linear.apply( + return apply_fp8_block_linear( + layer, input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, bias=bias, - ) + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported) return self.fp8_linear.apply(input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index ece3e5817116..31182f40b48f 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -27,14 +27,11 @@ class QuantFP8(CustomOp): This CustomOp supports both static and dynamic quantization. """ - def __init__( - self, - static: bool, - group_shape: GroupShape, - num_token_padding: Optional[int] = None, - column_major_scales: bool = False, - use_ue8m0: Optional[bool] = None, # for Torch compile - ): + def __init__(self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None, + column_major_scales: bool = False): """ :param static: static or dynamic quantization :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, @@ -49,7 +46,6 @@ def __init__( self.group_shape = group_shape self.num_token_padding = num_token_padding self.column_major_scales = column_major_scales - self.use_ue8m0 = use_ue8m0 self.is_group_quant = group_shape.is_per_group() if self.is_group_quant: @@ -74,8 +70,7 @@ def forward_cuda( x, group_size=self.group_size, column_major_scales=self.column_major_scales, - dtype=_FP8_DTYPE, - use_ue8m0=self.use_ue8m0) + dtype=_FP8_DTYPE) assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape @@ -142,10 +137,7 @@ def _quantize_group_native( x_grouped = x.view(-1, num_groups, self.group_size) absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() - scales_raw = absmax / _FP8_MAX - if self.use_ue8m0: - scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw))) - scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR) + scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) x_scaled = x_grouped / scales x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) @@ -159,6 +151,6 @@ def _quantize_group_native( scales = scales.reshape(orig_shape[:-1] + (num_groups, )) if self.column_major_scales: - scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2) + scales = scales.transpose(-2, -1).contiguous() return x_quant, scales diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 441bba6baacc..b32c67dec7ff 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -13,9 +13,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, group_broadcast) + group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.model_executor.parameter import (BlockQuantScaleParameter, @@ -25,7 +24,6 @@ from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, - is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -37,8 +35,6 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz -# We need to pass in the is_hopper flag as argument because the function -# current_platform.is_device_capability() is not supported by Torch compiler. def cutlass_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -46,17 +42,15 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, - is_hopper: Optional[bool] = None, ) -> torch.Tensor: - if is_hopper is None: - is_hopper = current_platform.is_device_capability(90) return ops.cutlass_scaled_mm( A, B.T, out_dtype=output_dtype, scale_a=As, # SM90 block FP8 requires row-major scale_b, which we do ahead of time - scale_b=Bs if block_size is not None and is_hopper else Bs.T) + scale_b=Bs if block_size is not None + and current_platform.is_device_capability(90) else Bs.T) def rocm_aiter_gemm_w8a8_blockscale_impl( @@ -102,190 +96,122 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -# TODO we should be able to change the type of block_size to GroupShape -# after we resolve GroupShape compilation issue -# https://github.com/vllm-project/vllm/issues/25270 -def _w8a8_triton_block_scaled_mm_func( - qx: torch.Tensor, - weight: torch.Tensor, - x_scale: torch.Tensor, - weight_scale: torch.Tensor, - block_size: list[int], - output_dtype: torch.dtype, -) -> torch.Tensor: - return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale, - block_size, output_dtype) +def dispatch_w8a8_blockscale_func( + use_cutlass: bool, use_aiter_and_is_supported: bool +) -> Callable[[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + list[int], + torch.dtype, +], torch.Tensor]: + if use_cutlass: + return cutlass_scaled_mm + if (use_aiter_and_is_supported): + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale + return w8a8_block_fp8_matmul -def _w8a8_triton_block_scaled_mm_fake( - qx: torch.Tensor, +# TODO fix ROCm->Triton custom path: +# https://github.com/vllm-project/vllm/issues/14397 +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, weight: torch.Tensor, - x_scale: torch.Tensor, - weight_scale: torch.Tensor, block_size: list[int], - output_dtype: torch.dtype, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, ) -> torch.Tensor: - return torch.empty((qx.size(0), weight.size(0)), - dtype=output_dtype, - device=qx.device) + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + output_dtype = input.dtype + if should_use_deepgemm_for_fp8_linear(output_dtype, weight): -# Note: the check can be removed when CPU torch > 2.7 -if not current_platform.is_cpu(): - direct_register_custom_op( - "w8a8_triton_block_scaled_mm_func", - _w8a8_triton_block_scaled_mm_func, - fake_impl=_w8a8_triton_block_scaled_mm_fake, - dispatch_key="CUDA", - ) - - -# TODO fix ROCm->Triton custom path: -# https://github.com/vllm-project/vllm/issues/14397 -class W8A8BlockFp8LinearOp: - """ - This class executes a Blocked FP8 linear layer using cutlass if supported - and torch.scaled_mm otherwise. - """ - - def __init__( - self, - weight_group_shape: GroupShape, - act_quant_group_shape: GroupShape, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, - ): - self.weight_group_shape = weight_group_shape - self.act_quant_group_shape = act_quant_group_shape - self.is_deep_gemm_supported = is_deep_gemm_supported() - self.is_hopper = current_platform.is_device_capability(90) - - # Get the correct blockscale mul and input quant operations. - # We can't use _dispatch_w8a8_blockscale_op to figure out if we want - # to use deepgemm because we don't know the shape of weights (and - # whether deepgemm supports it) at the init time. - self.w8a8_blockscale_op, self.input_quant_op = \ - self._dispatch_w8a8_blockscale_op( - cutlass_block_fp8_supported, use_aiter_and_is_supported) - self.deepgemm_input_quant_op = (QuantFP8( - False, - self.act_quant_group_shape, - column_major_scales=True, - use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported - else None) - - def apply( - self, - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert input_scale is None - # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] - output_dtype = input.dtype - if should_use_deepgemm_for_fp8_linear(output_dtype, weight, - self.is_deep_gemm_supported): - output = self._run_deepgemm(input, weight, weight_scale) - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) - - output = self.w8a8_blockscale_op(input_2d, weight, weight_scale) - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + q_input, x_scale = per_token_group_quant_fp8( + input_2d, + block_size[1], + column_major_scales=True, + ) - def _run_deepgemm( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ) -> torch.Tensor: # ensure DeepGEMM-backed custom op is registered before use import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 - assert self.deepgemm_input_quant_op is not None - q_input, x_scale = self.deepgemm_input_quant_op(input_2d) - return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm( + output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( q_input, weight, x_scale, weight_scale, - self.weight_group_shape, - output_dtype=input_2d.dtype) - - def _run_cutlass( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ) -> torch.Tensor: - assert self.input_quant_op is not None - if self.is_hopper: - # We pad unconditionally (even if shape is already divisible by 4) - # to support dynamic shape for input_2d.shape[0] in torch.compile - x = torch.nn.functional.pad(input_2d, - (0, 0, 0, -input_2d.shape[0] % 4)) - else: - x = input_2d - - q_input, x_scale = self.input_quant_op(x) - output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, - list(self.weight_group_shape), - input_2d.dtype, self.is_hopper) - output = output[0:input_2d.shape[0], ...] - return output - - def _run_aiter( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ) -> torch.Tensor: - assert self.act_quant_group_shape == GroupShape(1, 128) - q_input, x_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( - q_input, weight, x_scale, weight_scale, self.weight_group_shape, - input_2d.dtype) - - def _run_triton( - self, - input_2d: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - ) -> torch.Tensor: - assert self.input_quant_op is not None - q_input, x_scale = self.input_quant_op(input_2d) - return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( - q_input, weight, x_scale, weight_scale, self.weight_group_shape, - input_2d.dtype) - - def _dispatch_w8a8_blockscale_op( - self, - use_cutlass: bool, - use_aiter_and_is_supported: bool, - ) -> tuple[Callable[[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - ], torch.Tensor], Optional[QuantFP8]]: - if use_cutlass: - return self._run_cutlass, (QuantFP8(False, - self.act_quant_group_shape, - column_major_scales=True, - use_ue8m0=False)) + block_size, + output_dtype=output_dtype) + if bias is not None: + output += bias + return output.to(dtype=output_dtype).view(*output_shape) + + w8a8_blockscale_func = dispatch_w8a8_blockscale_func( + cutlass_block_fp8_supported, use_aiter_and_is_supported) + if cutlass_block_fp8_supported: + num_pad = 0 + if current_platform.is_device_capability(90): + # pad first dimension to be divisible by 4 due to + # cutlass blockwise gemm limitation for hopper + num_pad = 4 - (input_2d.shape[0] % 4) + if num_pad > 0: + input_2d = torch.nn.functional.pad(input_2d, + (0, 0, 0, num_pad), + "constant", 0) + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=True) + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) + if num_pad > 0: + output = output[:-num_pad] + else: if use_aiter_and_is_supported: - return self._run_aiter, None - return self._run_triton, (QuantFP8(False, - self.act_quant_group_shape, - column_major_scales=False, - use_ue8m0=False)) + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=False) + + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def apply_w8a8_block_fp8_linear_fake( + input: torch.Tensor, + weight: torch.Tensor, + block_size: list[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, +) -> torch.Tensor: + output_shape = [*input.shape[:-1], weight.shape[0]] + return torch.empty(output_shape, dtype=input.dtype, device=input.device) + + +if not current_platform.is_cpu(): + direct_register_custom_op( + op_name="apply_w8a8_block_fp8_linear", + op_func=apply_w8a8_block_fp8_linear, + mutates_args=[], + fake_impl=apply_w8a8_block_fp8_linear_fake, + ) def input_to_float8( @@ -537,7 +463,7 @@ def per_token_group_quant_fp8( @triton.jit -def _w8a8_triton_block_scaled_mm( +def _w8a8_block_fp8_matmul( # Pointers to inputs and output A, B, @@ -662,7 +588,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, return None -def w8a8_triton_block_scaled_mm( +def w8a8_block_fp8_matmul( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -722,7 +648,7 @@ def grid(META): return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - _w8a8_triton_block_scaled_mm[grid]( + _w8a8_block_fp8_matmul[grid]( A, B, C, @@ -1005,6 +931,25 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module, layer.weight_scale.data.T.contiguous(), requires_grad=False) +def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, + bias: Optional[torch.Tensor], + cutlass_block_fp8_supported: bool, + use_aiter_and_is_supported: bool) -> torch.Tensor: + """Apply block-wise FP8 linear operation.""" + assert layer.weight_block_size is not None + + return torch.ops.vllm.apply_w8a8_block_fp8_linear( + input=input, + weight=layer.weight, + block_size=layer.weight_block_size, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_block_fp8_supported=cutlass_block_fp8_supported, + use_aiter_and_is_supported=use_aiter_and_is_supported, + ) + + def expert_weight_is_col_major(x: torch.Tensor) -> bool: assert x.dim() == 3 b, m, n = x.shape diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 979c10f2c3e9..f955beb92b36 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -9,7 +9,7 @@ import functools import importlib import os -from typing import Any, Callable, NoReturn, Optional +from typing import Any, Callable, NoReturn import torch @@ -184,13 +184,9 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim -def should_use_deepgemm_for_fp8_linear( - output_dtype: torch.dtype, - weight: torch.Tensor, - supports_deep_gemm: Optional[bool] = None): - if supports_deep_gemm is None: - supports_deep_gemm = is_deep_gemm_supported() - return (supports_deep_gemm and output_dtype == torch.bfloat16 +def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, + weight: torch.Tensor): + return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)