Commit 1b31d57

Merge branch 'main' into fix/gemma3-gguf-quantization
2 parents c2bc592 + a269173 commit 1b31d57

5 files changed: +85 additions, -56 deletions

csrc/cache_kernels.cu

Lines changed: 2 additions & 1 deletion

@@ -16,7 +16,7 @@
 
 #include <algorithm>
 #include <cassert>
-#include <cfloat>  // FLT_MIN
+#include <cfloat>
 
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
@@ -479,6 +479,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
 
   // Compute the scale for the tile
   float tile_scale = max_abs / 448.f;
+  tile_scale = fmaxf(tile_scale, FLT_MIN);
 
   // The first lane of each half-warp writes the scale to kv_cache
   if ((lane_idx == 0) || (lane_idx == 16)) {
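
The functional change here is the FLT_MIN clamp: an all-zero tile gives max_abs == 0 and hence tile_scale == 0, and dividing by that scale during FP8 quantization produces NaNs in the cache. Below is a minimal Python sketch of the same guard, for illustration only (torch.finfo(torch.float32).tiny stands in for FLT_MIN; this is not the CUDA kernel):

import torch

FP8_E4M3_MAX = 448.0  # largest finite magnitude of float8_e4m3fn

def quantize_tile_fp8(tile: torch.Tensor) -> tuple[torch.Tensor, float]:
    # Per-tile scale, same formula as the kernel: max_abs / 448.
    tile_scale = tile.abs().max().item() / FP8_E4M3_MAX
    # Guard against an all-zero tile; without the clamp the division
    # below becomes 0 / 0 and the quantized values turn into NaN.
    tile_scale = max(tile_scale, torch.finfo(torch.float32).tiny)
    quantized = (tile / tile_scale).to(torch.float8_e4m3fn)
    return quantized, tile_scale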

vllm/executor/ray_utils.py

Lines changed: 6 additions & 0 deletions

@@ -16,6 +16,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
+from vllm.v1.outputs import AsyncModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerWrapperBase
 
 if TYPE_CHECKING:
@@ -142,6 +143,11 @@ def execute_model_ray(
             # but may still be finished requests.
             assert not output or not output.req_ids
             output = scheduler_output, None
+        # Ensure outputs crossing Ray compiled DAG are serializable.
+        # AsyncModelRunnerOutput holds CUDA events and cannot be
+        # pickled.
+        if isinstance(output, AsyncModelRunnerOutput):
+            output = output.get_output()
         return output
 
     def override_env_vars(self, vars: Dict[str, str]):
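
The added guard follows a general pattern: unwrap an asynchronous result into a plain, picklable payload before it crosses a process boundary. A rough sketch of that pattern with made-up class names (only AsyncModelRunnerOutput.get_output() above comes from the actual change):

import pickle

class AsyncResult:
    """Made-up stand-in for an output that carries a GPU-side handle
    (e.g. a CUDA event), which pickle cannot serialize."""

    def __init__(self, payload, cuda_event):
        self.payload = payload
        self.cuda_event = cuda_event  # unpicklable handle

    def get_output(self):
        # The real class would wait for the async copy to finish here
        # and return the plain output object.
        return self.payload

def make_serializable(output):
    # Same shape as the guard above: unwrap before the object is
    # pickled and shipped across the Ray compiled DAG.
    if isinstance(output, AsyncResult):
        output = output.get_output()
    return pickle.dumps(output)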

vllm/model_executor/layers/quantization/kv_cache.py

Lines changed: 4 additions & 4 deletions

@@ -86,7 +86,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                logger.warning_once(
                    "Checkpoint does not provide a q scaling factor. "
                    "Setting it to k_scale. This only matters for "
-                   "the flash-attn backend.")
+                   "FP8 Attention backends (flash-attn or flashinfer).")
                layer._q_scale.copy_(k_scale)
                layer._q_scale_float = k_scale
 
@@ -98,9 +98,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
            if (k_scale == 1.0 and v_scale == 1.0
                    and "e5m2" not in layer.kv_cache_dtype):
                logger.warning_once(
-                   "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                   "may cause accuracy issues. Please make sure k/v_scale "
-                   "scaling factors are available in the fp8 checkpoint.")
+                   "Using KV cache scaling factor 1.0 for fp8_e4m3. "
+                   "If this is unintended, verify that k/v_scale "
+                   "scaling factors are properly set in the checkpoint.")
 
            if layer.q_scale > 0.0:
                q_scale = layer.q_scale

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 66 additions & 50 deletions

@@ -1064,7 +1064,7 @@ def __init__(
        self.allow_flashinfer = _nvfp4.allow_flashinfer
        self.use_marlin = _nvfp4.use_marlin
        self.flashinfer_moe_backend = None
-
+        self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
        if self.allow_flashinfer:
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
@@ -1197,19 +1197,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                        weight_loader=weight_loader)
        layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def prepare_static_weight_layouts_for_trtllm_moe(
+    def prepare_static_weights_for_trtllm_fp4_moe(
        self,
-        gemm1_weights: torch.Tensor,
-        gemm2_weights: torch.Tensor,
-        gemm1_scales_linear_fp4_bytes: torch.Tensor,
-        gemm2_scales_linear_fp4_bytes: torch.Tensor,
-        hidden_size: int,
-        intermediate_size: int,
-        num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import nvfp4_block_scale_interleave
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices)
        """Prepare quantized weights for kernel (done offline with weights)."""
-        from flashinfer import (reorder_rows_for_gated_act_gemm,
-                                shuffle_matrix_a, shuffle_matrix_sf_a)
        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
 
        # Convert quantized weights to proper formats
@@ -1227,48 +1231,54 @@ def prepare_static_weight_layouts_for_trtllm_moe(
                                                  intermediate_size //
                                                  16)  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(
-                    gemm1_scales_linear_fp4[i].clone()))
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved).reshape(num_experts,
-                                                   2 * intermediate_size,
-                                                   hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved).reshape(num_experts,
-                                                  2 * intermediate_size,
-                                                  hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
        gemm1_weights_fp4_shuffled = []
        gemm1_scales_fp4_shuffled = []
        gemm2_weights_fp4_shuffled = []
        gemm2_scales_fp4_shuffled = []
        for i in range(num_experts):
-            gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            #    for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm1_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
            gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8),
-                    epilogue_tile_m))
-
-            gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
-                                 epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm1_scales_linear_fp4.device)].contiguous()))
+
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
+            gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view(
+                torch.uint8)[permute_indices.to(
+                    gemm2_weights_fp4.device)].contiguous())
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
+            )
            gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8),
-                    epilogue_tile_m))
+                nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view(
+                    torch.uint8)[permute_sf_indices.to(
+                        gemm2_scales_linear_fp4.device)].contiguous()))
 
        # Stack weights for all experts
        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
@@ -1283,8 +1293,12 @@ def prepare_static_weight_layouts_for_trtllm_moe(
            torch.stack(gemm2_scales_fp4_shuffled).view(
                torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                             intermediate_size // 16))
-        return (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
-                gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled)
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )
 
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # GEMM 1 processing
@@ -1334,9 +1348,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.allow_flashinfer and \
            self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
            # Prepare static weights for TRT-LLM kernel
+            # alternate: prepare_static_weight_layouts_for_trtllm_moe
            (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled,
             gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled
-             ) = self.prepare_static_weight_layouts_for_trtllm_moe(
+             ) = self.prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
@@ -1345,6 +1360,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
+            logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
 
            layer.gemm1_weights_fp4_shuffled = Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False)
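
The new self._cache_permute_indices dict memoizes permutation indices by tensor shape, so the index computation runs once per weight shape instead of once per expert. A schematic sketch of that caching pattern follows; get_cached_permute_indices and compute_indices are hypothetical names, and the real helpers are flashinfer's _maybe_get_cached_* functions:

import torch

def get_cached_permute_indices(
    cache: dict[torch.Size, torch.Tensor],
    weight: torch.Tensor,
    compute_indices,  # callable that builds the permutation for one shape
) -> torch.Tensor:
    # All experts of a layer share the same weight shape, so the permutation
    # depends only on the shape: compute it once, then reuse it per expert.
    if weight.shape not in cache:
        cache[weight.shape] = compute_indices(weight)
    return cache[weight.shape]

# Example usage with a trivial row-identity permutation as the stand-in:
cache: dict[torch.Size, torch.Tensor] = {}
w = torch.zeros(8, 16, dtype=torch.uint8)
idx = get_cached_permute_indices(cache, w, lambda t: torch.arange(t.shape[0]))
shuffled = w[idx].contiguous()  # same gather-then-contiguous step as the diff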

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 7 additions & 1 deletion

@@ -1061,12 +1061,18 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
            return None
        return remapped_name
 
+    if any("mla_attn" in key for key in params_dict):
+        attn_str = "mla_attn.mla_attn"
+        logger.debug_once(f"Found mla_attn with k_scale and v_scale in "
+                          f"the checkpoint, using {attn_str} as attn_str")
+    else:
+        attn_str = "attn"
    # Define scale name mapping patterns in order of precedence
    scale_mapping_patterns = [
        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
        #   .self_attn.attn.{k,v}_scale
        (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
-         r".self_attn.attn.\2_scale"),
+         rf".self_attn.{attn_str}.\2_scale"),
        # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
        #   .self_attn.attn.{k,v}_scale
        (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
