@@ -2,13 +2,12 @@
 from typing import Callable, List, Optional
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -31,6 +30,10 @@ def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
 
         # Weight
         weight = ModelWeightParameter(data=torch.empty(
@@ -60,48 +63,30 @@ def create_weights(self, layer: torch.nn.Module,
 
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer) -> None:
-        layer.weight_global_scale = Parameter(
-            layer.weight_global_scale.max().to(torch.float32),
+        # Process parameters for marlin repacking
+
+        # Rename weight_packed to weight that marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+        # Rename weight_global_scale to weight_scale_2 that marlin expects
+        # Note: ct stores the inverse of what is expected by the marlin kernel
+        layer.weight_scale_2 = Parameter(
+            1 / layer.weight_global_scale.max().to(torch.float32),
             requires_grad=False)
-        # Note: a post weight loading step but not required for the emulation
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
-        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                requires_grad=False)
+        del layer.weight_global_scale
 
+        prepare_fp4_layer_for_marlin(layer)
+
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-
-        w_fp4 = layer.weight_packed.data
-        w_global_scale = layer.weight_global_scale
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   x.dtype, x.device, self.group_size)
-        out = F.linear(x, w_dq)
-        del w_dq, w_fp4, w_global_scale, w_blockscale
-        return out
+        return apply_fp4_marlin_linear(input=x,
+                                       weight=layer.weight,
+                                       weight_scale=layer.weight_scale,
+                                       weight_scale_2=layer.weight_scale_2,
+                                       workspace=layer.workspace,
+                                       size_n=layer.output_size_per_partition,
+                                       size_k=layer.input_size_per_partition,
+                                       bias=bias)
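
For context on what this diff replaces: the deleted emulation path dequantized the packed FP4 weight using its per-group FP8 block scales and the global scale, then ran a plain F.linear; apply_fp4_marlin_linear performs the equivalent GEMM in one fused kernel. A minimal, self-contained sketch of that emulation math is below. The shapes, the group size of 16, and the division by the global scale (following the "ct stores the inverse" comment in the diff) are illustrative assumptions, not the exact nvfp4_emulation_utils implementation.

# Illustrative-only sketch of the removed emulation path (not the actual
# nvfp4_emulation_utils.dequantize_to_dtype implementation).
import torch
import torch.nn.functional as F

group_size = 16
N, K = 8, 32                                   # toy partition sizes
w_vals = torch.randn(N, K)                     # stand-in for decoded FP4 values
block_scales = torch.rand(N, K // group_size)  # one FP8-style scale per group
global_scale = torch.tensor(2.0)               # ct-style global scale, stored inverted

# Broadcast each group scale across its 16 weights, then divide out the
# global scale: the same inversion the diff applies when it sets
# weight_scale_2 = 1 / weight_global_scale for the Marlin kernel.
w_dq = w_vals * block_scales.repeat_interleave(group_size, dim=1) / global_scale

x = torch.randn(4, K)
out = F.linear(x, w_dq)  # what the old apply_weights returned, now fused in Marlin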