From 541fe067e94b592416181c14fddceba4f62a6017 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 26 Feb 2025 11:20:52 -0800
Subject: [PATCH 1/2] [BugFix] Make FP8 Linear compatible with torch.compile

Signed-off-by: Woosuk Kwon
---
 .../model_executor/layers/quantization/fp8.py |  5 +----
 .../layers/quantization/utils/fp8_utils.py    | 20 +++++++++++++++++++
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 76a7d4df8a36..a705f63be4ac 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -369,12 +369,9 @@ def apply(self,
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        # Note: lazy import to avoid triton import error.
-        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-            apply_w8a8_block_fp8_linear)
         if self.block_quant:
             assert self.quant_config.weight_block_size is not None
-            return apply_w8a8_block_fp8_linear(
+            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 61706f485f46..e51a86898e8a 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -17,6 +17,7 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)
 
 
+def apply_w8a8_block_fp8_linear_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+    return torch.empty_like(input, shape=output_shape)
+
+
+direct_register_custom_op(
+    op_name="apply_w8a8_block_fp8_linear",
+    op_func=apply_w8a8_block_fp8_linear,
+    mutates_args=[],
+    fake_impl=apply_w8a8_block_fp8_linear_fake,
+)
+
+
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally

From f113803c604184b449f84c3756325358e3591ea0 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 26 Feb 2025 13:38:11 -0800
Subject: [PATCH 2/2] Fix

Signed-off-by: Woosuk Kwon
---
 vllm/model_executor/layers/quantization/utils/fp8_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index e51a86898e8a..7d91d2cf1c6e 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -90,7 +90,7 @@ def apply_w8a8_block_fp8_linear_fake(
     input_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     output_shape = [*input.shape[:-1], weight.shape[0]]
-    return torch.empty_like(input, shape=output_shape)
+    return torch.empty(output_shape, dtype=input.dtype, device=input.device)
 
 
 direct_register_custom_op(
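
Context for the pattern used above: the patches wrap the block-FP8 linear in a registered custom op so torch.compile sees a single opaque call, with a fake (meta) implementation that only reports the output shape and dtype during tracing. Because the output's last dimension differs from the input's, the fake implementation must allocate via torch.empty with the computed output shape rather than deriving it from the input, which is what the second patch corrects. The sketch below is a minimal standalone illustration using plain torch.library.custom_op (not vLLM's direct_register_custom_op helper); the names "mylib::scaled_mm", scaled_mm, and f are hypothetical and exist only for this example.

# Minimal sketch of the custom-op + fake-impl pattern, under the assumptions
# stated above (PyTorch 2.4+; "mylib::scaled_mm" is a made-up op name).
import torch


@torch.library.custom_op("mylib::scaled_mm", mutates_args=())
def scaled_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Real implementation: executed at runtime.
    return x @ w.t()


@scaled_mm.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation used during torch.compile tracing: it only
    # describes the output tensor. Note the fresh torch.empty with an explicit
    # output shape; the output's last dim comes from w, not from x.
    return torch.empty((*x.shape[:-1], w.shape[0]),
                       dtype=x.dtype, device=x.device)


def f(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # torch.compile traces through torch.ops.mylib.scaled_mm via the fake impl,
    # without running the real kernel at trace time.
    return torch.ops.mylib.scaled_mm(x, w) + 1.0


compiled = torch.compile(f)
out = compiled(torch.randn(4, 8), torch.randn(16, 8))
assert out.shape == (4, 16)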