Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,12 +369,9 @@ def apply(self,
size_k=layer.input_size_per_partition,
bias=bias)

# Note: lazy import to avoid triton import error.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
if self.block_quant:
assert self.quant_config.weight_block_size is not None
return apply_w8a8_block_fp8_linear(
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
Expand Down
20 changes: 20 additions & 0 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)

Expand Down Expand Up @@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
return output.to(dtype=input.dtype).view(*output_shape)


def apply_w8a8_block_fp8_linear_fake(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
output_shape = [*input.shape[:-1], weight.shape[0]]
return torch.empty(output_shape, dtype=input.dtype, device=input.device)


direct_register_custom_op(
op_name="apply_w8a8_block_fp8_linear",
op_func=apply_w8a8_block_fp8_linear,
mutates_args=[],
fake_impl=apply_w8a8_block_fp8_linear_fake,
)


# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
Expand Down