21 changes: 15 additions & 6 deletions tests/compile/test_fusion.py
@@ -1,6 +1,5 @@
import pytest
import torch
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
@@ -9,7 +8,8 @@
from vllm.config import CompilationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
apply_fp8_linear, cutlass_fp8_supported)
from vllm.utils import FP8_DTYPE

from .backend import TestBackend

@@ -24,16 +24,25 @@ def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
]
self.cutlass_fp8_supported = cutlass_fp8_supported()

def forward(self, x):
resid = torch.relu(x)
y = self.norm[0](x)

x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
x2 = apply_fp8_linear(y,
self.w[0],
self.scale[0],
self.scale[1],
cutlass_fp8_supported=self.cutlass_fp8_supported)
# make sure resid is used for the replacement to work
y2, resid = self.norm[1](x2, resid)

x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
x3 = apply_fp8_linear(y2,
self.w[1],
self.scale[2],
self.scale[3],
cutlass_fp8_supported=self.cutlass_fp8_supported)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3

@@ -42,8 +51,8 @@ def forward(self, x):
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and Rocm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
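
A minimal sketch of the pattern the test now follows (the Fp8Linear class below is a hypothetical standalone module, not part of the test file; the positional argument order mirrors the calls above): probe CUTLASS support once at construction time and reuse the cached flag on every forward pass.

import torch

from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear, cutlass_fp8_supported)
from vllm.utils import FP8_DTYPE


class Fp8Linear(torch.nn.Module):
    """Hypothetical single FP8 linear layer reusing the cached capability flag."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.xscale = torch.rand(1, dtype=torch.float32)
        # Query CUTLASS availability once; the answer cannot change at runtime.
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return apply_fp8_linear(x, self.w, self.wscale, self.xscale,
                                cutlass_fp8_supported=self.cutlass_fp8_supported)
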
3 changes: 1 addition & 2 deletions tests/kernels/quant_utils.py
@@ -3,12 +3,11 @@
import torch

from vllm.platforms import current_platform
from vllm.utils import FP8_DTYPE

# Using the default value (240.0) from PyTorch causes accuracy issues
# on dynamic quantization models, so use 224.0 for ROCm.
ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
else torch.float8_e4m3fn


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
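
A minimal sketch of how the reduced ROCm maximum comes into play in a reference per-tensor quantization (the helper below is illustrative, not this file's exact implementation):

import torch

from vllm.platforms import current_platform
from vllm.utils import FP8_DTYPE

# 224.0 on ROCm instead of torch.finfo(torch.float8_e4m3fnuz).max == 240.0.
ROCM_FP8_MAX = 224.0
FP8_MAX = ROCM_FP8_MAX if current_platform.is_rocm() \
    else torch.finfo(FP8_DTYPE).max


def ref_per_tensor_fp8_quant(x: torch.Tensor):
    # Map the largest magnitude in x onto the platform FP8 maximum.
    x_f32 = x.to(torch.float32)
    scale = x_f32.abs().max().clamp(min=1e-12) / FP8_MAX
    q = (x_f32 / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return q, scale
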
4 changes: 2 additions & 2 deletions tests/kernels/test_fp8_quant.py
@@ -2,11 +2,11 @@
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
from vllm.utils import FP8_DTYPE

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
6 changes: 2 additions & 4 deletions vllm/_custom_ops.py
@@ -10,6 +10,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
from vllm.utils import FP8_DTYPE

logger = init_logger(__name__)

@@ -703,12 +704,9 @@ def scaled_fp8_quant(
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz \
if current_platform.is_rocm() else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
output = torch.empty(shape, device=input.device, dtype=FP8_DTYPE)

if scale is None:
if use_per_token_if_dynamic:
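
A quick usage sketch of the padding behaviour touched above (assuming a CUDA device; shapes are illustrative): with num_token_padding set, the quantized output is allocated with at least that many rows, and it now always carries the platform FP8 dtype.

import torch

import vllm._custom_ops as ops
from vllm.utils import FP8_DTYPE

x = torch.randn(7, 4096, dtype=torch.float16, device="cuda")
# Dynamic per-tensor quantization (no scale passed) with row padding.
q, scale = ops.scaled_fp8_quant(x, num_token_padding=17)
assert q.dtype == FP8_DTYPE
assert q.shape == (17, 4096)  # padded up from 7 rows
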
4 changes: 2 additions & 2 deletions vllm/compilation/fusion.py
@@ -8,6 +8,7 @@

from vllm.config import CompilationConfig
from vllm.logger import init_logger
from vllm.utils import FP8_DTYPE

from .vllm_inductor_pass import VllmInductorPass, is_func

@@ -82,8 +83,7 @@ def empty_bf16(*args, **kwargs):


def empty_fp8(*args, **kwargs):
fp8 = torch.float8_e4m3fn
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
return torch.empty(*args, **kwargs, dtype=FP8_DTYPE, device="cuda")


def empty_fp32(*args, **kwargs):
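
A quick check of the helper above (assuming a GPU build where empty_fp8 can be imported from vllm.compilation.fusion): the dummy tensors used as example inputs for the fusion patterns now carry the platform FP8 dtype.

from vllm.compilation.fusion import empty_fp8
from vllm.utils import FP8_DTYPE

t = empty_fp8(5, 4096)
assert t.dtype == FP8_DTYPE
assert t.device.type == "cuda"  # ROCm builds of PyTorch also report "cuda" here
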
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -125,7 +125,7 @@ def apply_fp8_linear(
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=17,
num_token_padding=17 if current_platform.is_cuda() else None,
use_per_token_if_dynamic=use_per_token_if_dynamic)

per_tensor_weights = (weight_scale.numel() == 1)
@@ -144,8 +144,7 @@
if type(output) is tuple and len(output) == 2:
output = output[0]

return torch.narrow(output, 0, 0,
input_2d.shape[0]).view(*output_shape)
return output[0:input_2d.shape[0], ...].view(*output_shape)

else:
# Fallback for channelwise case, where we use unfused DQ
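
A minimal sketch confirming the slice above is a drop-in for the previous torch.narrow call: both take the first input_2d.shape[0] rows of the padded output, and both return views of the same storage rather than copies.

import torch

output = torch.randn(17, 4096)  # GEMM output padded to 17 rows
n_tokens = 7                    # actual number of input rows

narrowed = torch.narrow(output, 0, 0, n_tokens)
sliced = output[0:n_tokens, ...]

assert torch.equal(narrowed, sliced)
assert narrowed.data_ptr() == sliced.data_ptr()  # same storage, no copy
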
4 changes: 4 additions & 0 deletions vllm/utils.py
@@ -148,6 +148,10 @@
torch.int64: np.int64,
}

# On ROCm, use float8_e4m3fnuz for FP8; otherwise use float8_e4m3fn.
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
) else torch.float8_e4m3fn

P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
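
A small illustration of how callers consume the new constant: buffers get the platform-correct FP8 dtype without repeating the ROCm check at each call site, and the two dtypes differ in representable range (the printed maxima are the torch.finfo values).

import torch

from vllm.utils import FP8_DTYPE

buf = torch.empty(16, 4096, dtype=FP8_DTYPE)
print(FP8_DTYPE, torch.finfo(FP8_DTYPE).max)
# CUDA: torch.float8_e4m3fn   -> 448.0
# ROCm: torch.float8_e4m3fnuz -> 240.0
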