diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 9ea883d4a03c..c3fff1b4f207 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -29,6 +29,7 @@ logger = init_logger(__name__) +ExpertPlacementStrategy = Literal["linear", "round_robin"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index cab610decf90..a829ea7543da 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -12,11 +12,15 @@ from vllm.config import ParallelConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.utils import cdiv from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from triton_kernels.matmul_ogs import PrecisionConfig + logger = init_logger(__name__) @@ -61,43 +65,276 @@ def get_config_quant_dtype( return None +def _quant_flags_to_group_shape( + quant_dtype: Union[torch.dtype, str, None], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]], +) -> tuple[Optional[GroupShape], Optional[GroupShape]]: + """ + Convert MoE quantization flags into more generic GroupShapes. + """ + a_shape: Optional[GroupShape] + w_shape: Optional[GroupShape] + if block_shape is not None: + assert not per_act_token_quant + assert not per_out_ch_quant + # TODO(bnell): this is not quite right for activations since first + # dim should be 1. + a_shape = GroupShape(row=block_shape[0], col=block_shape[1]) + w_shape = GroupShape(row=block_shape[0], col=block_shape[1]) + else: + w_shape = None + a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR + + if per_act_token_quant: + a_shape = GroupShape.PER_TOKEN + + if per_out_ch_quant: + w_shape = GroupShape.PER_TOKEN + + return a_shape, w_shape + + @dataclass -class FusedMoEQuantConfig: - # The post quantization activation type. +class FusedMoEQuantDesc: + """ + A quantization descriptor for fused MoE ops. This class can describe + either activations or weights. + """ + + # The quantized type of this parameters. None means unquantized or + # already quantized. # TODO (bnell): use scalar_type instead of Union. - quant_dtype: Union[torch.dtype, str, None] = None - per_act_token_quant: bool = False - per_out_ch_quant: bool = False - block_shape: Optional[list[int]] = None + dtype: Union[torch.dtype, str, None] = None - # TODO: add col major flag? - # add detailed quant info for input, intermediates, weights, etc? + # A field that describes the quantization group shape, from quant_utils.py. + # * (-1, -1) for per-tensor quantization + # * (1, -1) for per-row quantization + # * (-1, 1) for per-column quantization + # * (128, 128) for 128x128 deepseek style block quantization + # * (1, 128) for deepseek style activation quantization + # (i.e. per-token-per-group) + shape: Optional[GroupShape] = None + + # Quantization scales. + # TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc? + scale: Union[torch.Tensor, "PrecisionConfig", None] = None + + # Quantization alphas or gscales, used for nvfp4 types. 
+ # TODO(bnell): put some of these in subclasses + alpha_or_gscale: Optional[torch.Tensor] = None + + # Zero points for int4/int8 types + zp: Optional[torch.Tensor] = None + + # Biases for GPT triton MoE + bias: Optional[torch.Tensor] = None + + +# TODO(bnell): have subclasses for specific moe methods? +# e.g. for specific arguments bias, precision, etc. +@dataclass +class FusedMoEQuantConfig: + """ + The FusedMoEQuantConfig contains all the quantization parameters for + a single FusedMoEMethodBase operation. It consists of four + FusedMoEQuantDescs, one for each activation and set of weights. + + Each FusedMoEMethodBase must implement a get_fused_moe_quant_config + method to construct a FusedMoEQuantConfig for use with that class. + + FusedMoEQuant configs are only used for modular kernels, fused_experts + (from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and + triton_kernel_moe_forward. Other MoE methods can ignore the + FusedMoEQuantConfig (for now) and hardcode it to None. + + There are currently some restrictions on what can be expressed: + - Most MoE ops only support similar quantization strategies for + each parameter, e.g. both weights must have the same GroupShape + and both activations must share the same GroupShape. One exception to + this is the cutlass moe which allows per channel quantization on the + outputs. Note: these restrictions are not always rigorously checked. + - Not all fused MoE functions support all the parameters, e.g. zero points, + global scales, alphas and biases are not universally supported. + - Fully general GroupShapes are not allowed. Activations only support + per token, per tensor or K-blocked. + - Weights are not required to have a GroupShape since they have already + been quantized. + + Other notes: + - PrecisionConfigs are specific to GPT OSS Triton. + - As a follow up it would probably make sense to subclass FusedMoEQuantDesc + or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses + so that only the required quantization parameters are used/stored. + """ + + # TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking + _a1: FusedMoEQuantDesc + _a2: FusedMoEQuantDesc + _w1: FusedMoEQuantDesc + _w2: FusedMoEQuantDesc def __post_init__(self): assert (not self.per_act_token_quant or self.block_shape is None), "illegal quantization" + # + # Convenience accessors for various properties.
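For illustration, a minimal sketch of how the accessors below reduce to the four descriptors, assuming fp8 scale tensors named w1_s and w2_s (hypothetical) and deepseek-style 128x128 block quantization:

cfg = FusedMoEQuantConfig.make(torch.float8_e4m3fn,
                               w1_scale=w1_s, w2_scale=w2_s,
                               block_shape=[128, 128])
assert cfg.quant_dtype is torch.float8_e4m3fn   # == cfg._a1.dtype
assert cfg.block_shape == [128, 128] and cfg.is_block_quantized
assert not cfg.per_act_token_quant and not cfg.is_per_tensor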
+ # + + @property + def quant_dtype(self) -> Union[torch.dtype, str, None]: + return self._a1.dtype + @property def is_quantized(self) -> bool: return self.quant_dtype is not None @property def is_per_act_token(self) -> bool: - return self.per_act_token_quant + return self._a1.shape == GroupShape.PER_TOKEN + + @property + def per_act_token_quant(self) -> bool: + return self._a1.shape == GroupShape.PER_TOKEN + + @property + def per_out_ch_quant(self) -> bool: + return self._w1.shape == GroupShape.PER_TOKEN + + @property + def is_per_tensor(self) -> bool: + return self._a1.shape == GroupShape.PER_TENSOR + + @property + def block_shape(self) -> Optional[list[int]]: + if (self._a1.shape is not None + and self._a1.shape != GroupShape.PER_TENSOR + and self._a1.shape != GroupShape.PER_TOKEN): + return [self._a1.shape.row, self._a1.shape.col] + else: + return None @property def is_block_quantized(self) -> bool: return self.block_shape is not None @property - def is_per_tensor(self) -> bool: - return not self.per_act_token_quant and self.block_shape is None + def a1_scale(self) -> Optional[torch.Tensor]: + assert self._a1.scale is None or isinstance(self._a1.scale, + torch.Tensor) + return self._a1.scale + + @property + def a1_gscale(self) -> Optional[torch.Tensor]: + return self._a1.alpha_or_gscale + + @property + def a2_scale(self) -> Optional[torch.Tensor]: + assert self._a2.scale is None or isinstance(self._a2.scale, + torch.Tensor) + return self._a2.scale + + @property + def a2_gscale(self) -> Optional[torch.Tensor]: + return self._a2.alpha_or_gscale + + @property + def w1_scale(self) -> Optional[torch.Tensor]: + assert self._w1.scale is None or isinstance(self._w1.scale, + torch.Tensor) + return self._w1.scale + + @property + def w1_zp(self) -> Optional[torch.Tensor]: + return self._w1.zp + + @property + def w1_bias(self) -> Optional[torch.Tensor]: + return self._w1.bias + + @property + def w1_precision(self) -> Optional["PrecisionConfig"]: + assert self._w1.scale is None or isinstance(self._w1.scale, + PrecisionConfig) + return self._w1.scale + + @property + def g1_alphas(self) -> Optional[torch.Tensor]: + return self._w1.alpha_or_gscale + + @property + def w2_scale(self) -> Optional[torch.Tensor]: + assert self._w2.scale is None or isinstance(self._w2.scale, + torch.Tensor) + return self._w2.scale + + @property + def w2_zp(self) -> Optional[torch.Tensor]: + return self._w2.zp + + @property + def w2_bias(self) -> Optional[torch.Tensor]: + return self._w2.bias + + @property + def w2_precision(self) -> Optional["PrecisionConfig"]: + assert self._w2.scale is None or isinstance(self._w2.scale, + PrecisionConfig) + return self._w2.scale + + @property + def g2_alphas(self) -> Optional[torch.Tensor]: + return self._w2.alpha_or_gscale + + @property + def use_fp8_w8a8(self) -> bool: + return self.quant_dtype == torch.float8_e4m3fn + + @property + def use_int8_w8a8(self) -> bool: + return self.quant_dtype == torch.int8 + + @property + def use_int8_w8a16(self) -> bool: + return (self._a1.dtype is None and self._w1.dtype == torch.int8) + + @property + def use_int4_w4a16(self) -> bool: + return (self._a1.dtype is None and self._w1.dtype == "int4") + + @property + def use_mxfp4_w4a4(self) -> bool: + return self.quant_dtype == "mxfp4" + + @property + def use_nvfp4_w4a4(self) -> bool: + return self.quant_dtype == "nvfp4" + + def config_name(self, dtype: torch.dtype) -> Optional[str]: + """ + Return a string used to construct the filename that contains the + tuning info for a particular quantization scheme. 
See + try_get_optimal_moe_config in fused_moe.py. + """ + return _get_config_dtype_str( + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + use_mxfp4_w4a4=self.use_mxfp4_w4a4, + dtype=dtype, + ) def scale_shape( self, max_tokens: int, hidden_dim: int, ) -> Optional[tuple[int, int]]: + """ + Construct the proper activation scale shape for this + config. + """ if self.is_quantized: if self.is_block_quantized: assert self.block_shape is not None @@ -117,6 +354,10 @@ def batched_scale_shape( max_tokens: int, hidden_dim: int, ) -> Optional[tuple[int, int, int]]: + """ + Construct the proper activation batched scale shape for this + config, e.g. (num experts, *scale_shape). + """ if self.is_quantized: scale_shape = self.scale_shape(max_tokens, hidden_dim) assert scale_shape is not None @@ -126,38 +367,154 @@ def batched_scale_shape( @staticmethod def make( - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + quant_dtype: Union[torch.dtype, str, None] = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, block_shape: Optional[list[int]] = None, + w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, + w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + g1_alphas: Optional[torch.Tensor] = None, + g2_alphas: Optional[torch.Tensor] = None, + a1_gscale: Optional[torch.Tensor] = None, + a2_gscale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, ) -> "FusedMoEQuantConfig": - assert sum([ - int(flag) for flag in [ - use_fp8_w8a8, - use_int8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - use_mxfp4_w4a4, - ] - ]) <= 1, "Quantization flags are mutually exclusive." - - quant_dtype = get_config_quant_dtype( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - ) - return FusedMoEQuantConfig( - quant_dtype, - per_act_token_quant, - per_out_ch_quant, - block_shape, + """ + General builder function for a FusedMoEQuantConfig. + - quant_dtype: Optional quantization type. None if activations are + unquantized or quantized prior to calling. Note: "nvfp4" and + "mxfp4" are the only valid string values for quant_dtype. + - per_act_token_quant: Activations have per token quantization. + - per_out_ch_quant: Outputs have per channel quantization. (only + for cutlass). + - block_shape: Optional block size for block-wise quantization. + Incompatible with per_act_token and per_out_ch quant. + - w1_scale: Optional scale to be used for w1. + - w2_scale: Optional scale to be used for w2. + - a1_scale: Optional scale to be used for a1. + - a2_scale: Optional scale to be used for a2. + - g1_alphas: Optional global quantization scales for w1 (for nvfp4). + - g2_alphas: Optional global quantization scales for w2 (for nvfp4). + - a1_gscale: Optional global quantization scales for a1 (for nvfp4). + - a2_gscale: Optional global quantization scales for a2 (for nvfp4). + - w1_bias: Optional biases for w1 (GPT OSS Triton). + - w2_bias: Optional biases for w2 (GPT OSS Triton). + - w1_zp: Optional w1 zero points for int4/int8 quantization. + - w2_zp: Optional w2 zero points for int4/int8 quantization.
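As a usage sketch (the layer attribute names below are hypothetical), a per-tensor fp8 config built through this builder behaves the same as the fp8_w8a8_moe_quant_config helper defined further down:

quant_config = FusedMoEQuantConfig.make(
    torch.float8_e4m3fn,
    w1_scale=layer.w13_weight_scale,
    w2_scale=layer.w2_weight_scale,
    a1_scale=layer.w13_input_scale,
    a2_scale=layer.w2_input_scale,
)
assert quant_config.use_fp8_w8a8 and quant_config.is_per_tensor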
+ """ + assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4" + or quant_dtype == "mxfp4") + a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, + per_act_token_quant, + per_out_ch_quant, + block_shape) + quant_config = FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale), + _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale), + _w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas, + w1_zp, w1_bias), + _w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas, + w2_zp, w2_bias), ) + assert quant_config.per_act_token_quant == per_act_token_quant + assert quant_config.per_out_ch_quant == per_out_ch_quant + assert quant_config.block_shape == block_shape + return quant_config + + +def fp8_w8a8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for fp8 activations and fp8 weights. + """ + return FusedMoEQuantConfig.make(torch.float8_e4m3fn, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape) + + +def int8_w8a8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + per_act_token_quant: bool = False, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for int8 activations and int8 weights. + """ + return FusedMoEQuantConfig.make( + torch.int8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=False, + block_shape=None, + ) + + +def mxfp4_w4a16_moe_quant_config( + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig: + """ + Construct a quant config for unquantized activations and mxfp4 weights. + """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(), + _a2=FusedMoEQuantDesc(), + _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), + _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + ) + + +def mxfp4_w4a4_moe_quant_config( + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for mxfp4 activations and mxfp4 weights. + """ + return FusedMoEQuantConfig.make( + "mxfp4", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + per_act_token_quant=False, + per_out_ch_quant=False, + block_shape=block_shape, + ) + + +# A FusedMoEQuantConfig constant for an unquantized MoE op. +FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make() @dataclass @@ -321,8 +678,6 @@ class FusedMoEConfig: # The activation type. 
in_dtype: torch.dtype - quant_config: Optional[FusedMoEQuantConfig] = None - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE has_bias: bool = False @@ -334,34 +689,6 @@ def __post_init__(self): assert self.max_num_tokens > 0 - @property - def quant_dtype(self) -> Union[torch.dtype, str, None]: - if self.quant_config is not None: - return self.quant_config.quant_dtype - else: - return None - - @property - def block_shape(self) -> Optional[list[int]]: - if self.quant_config is not None: - return self.quant_config.block_shape - else: - return None - - @property - def per_act_token_quant(self) -> bool: - if self.quant_config is not None: - return self.quant_config.per_act_token_quant - else: - return False - - @property - def per_out_ch_quant(self) -> bool: - if self.quant_config is not None: - return self.quant_config.per_out_ch_quant - else: - return False - @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -404,93 +731,9 @@ def use_deepep_ll_kernels(self): @property def use_flashinfer_cutlass_kernels(self): - return self.moe_parallel_config.use_flashinfer_cutlass_kernels - - @staticmethod - def make( - num_experts: int, - experts_per_token: int, - hidden_dim: int, - num_local_experts: int, - moe_parallel_config: FusedMoEParallelConfig, - in_dtype: torch.dtype, - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config: Optional[Union[FusedMoEQuantConfig, - QuantizationConfig]] = None, - has_bias: bool = False, - ) -> "FusedMoEConfig": - - _quant_config: Optional[FusedMoEQuantConfig] = None - - if quant_config is not None and isinstance(quant_config, - QuantizationConfig): - if hasattr(quant_config, 'weight_block_size'): - block_shape = quant_config.weight_block_size - else: - block_shape = None - per_act_token_quant = False - per_out_ch_quant = False - quant_dtype: Union[torch.dtype, str, None] = None - - input_quant = get_quant_config_input_quant(quant_config) - weight_quant = get_quant_config_weight_quant(quant_config) - - if input_quant is not None: - per_act_token_quant = (input_quant.strategy - == QuantizationStrategy.TOKEN - if input_quant is not None else False) - - if input_quant.num_bits == 8: - if input_quant.type == QuantizationType.FLOAT: - quant_dtype = torch.float8_e4m3fn - elif input_quant.type == QuantizationType.INT: - quant_dtype = torch.int8 - - from vllm.model_executor.layers.quantization.fp8 import Fp8Config - if quant_dtype is None and isinstance(quant_config, Fp8Config): - quant_dtype = torch.float8_e4m3fn - - from vllm.model_executor.layers.quantization.mxfp4 import ( - Mxfp4Config) - if (quant_dtype is None and isinstance(quant_config, Mxfp4Config) - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): - quant_dtype = "mxfp8" - - from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptNvFp4Config) - if quant_dtype is None and isinstance(quant_config, - ModelOptNvFp4Config): - quant_dtype = "nvfp4" - - if weight_quant is not None: - per_out_ch_quant = ( - weight_quant.strategy == QuantizationStrategy.CHANNEL) - - if quant_dtype is not None: - _quant_config = FusedMoEQuantConfig( - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - ) - else: - _quant_config = FusedMoEQuantConfig() - if moe_parallel_config.dp_size > 1: - logger.warning_once("MoE DP setup unable to determine " - "quantization scheme or unsupported " - "quantization type. 
This model will " - "not run with DP enabled.") - else: - _quant_config = quant_config - - return FusedMoEConfig( - num_experts=num_experts, - experts_per_token=experts_per_token, - hidden_dim=hidden_dim, - num_local_experts=num_local_experts, - moe_parallel_config=moe_parallel_config, - in_dtype=in_dtype, - quant_config=_quant_config, - max_num_tokens=max_num_tokens, - has_bias=has_bias, - ) + """ + Whether to use FlashInfer cutlass kernels for NVFP4 MoE. + """ + return (envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 312befe2c1d7..71bb0e5b8168 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -6,9 +6,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.utils import has_triton_kernels +from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4, downcast_to_static_fp8 + logger = init_logger(__name__) @@ -35,20 +39,10 @@ def triton_kernel_moe_forward( topk: int, renormalize: bool, activation: str = "silu", + quant_config: Optional[FusedMoEQuantConfig] = None, apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_precision: Optional["PrecisionConfig"] = None, - w2_precision: Optional["PrecisionConfig"] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, ) -> torch.Tensor: routing_data, gather_idx, scatter_idx = routing(gating_output, @@ -64,20 +58,10 @@ def triton_kernel_moe_forward( gather_idx, scatter_idx, activation=activation, + quant_config=quant_config, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=use_fp8_w8a8, - per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=w1_bias, - w2_bias=w2_bias, - w1_precision=w1_precision, - w2_precision=w2_precision, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) + expert_map=expert_map) # This is a triton implementation of the fused_experts function @@ -90,28 +74,23 @@ def triton_kernel_fused_experts( gather_indx, # GatherIndx scatter_indx, # ScatterIndx activation: str = "silu", + quant_config: Optional[FusedMoEQuantConfig] = None, swiglu_alpha: float = 1.702, swiglu_limit: float = 7.0, apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_precision: Optional["PrecisionConfig"] = None, - 
w2_precision: Optional["PrecisionConfig"] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, + a1q_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 - assert w1_bias is None or w1_bias.dtype == torch.float32 - assert w2_bias is None or w2_bias.dtype == torch.float32 + assert (quant_config.w1_bias is None + or quant_config.w1_bias.dtype == torch.float32) + assert (quant_config.w2_bias is None + or quant_config.w2_bias.dtype == torch.float32) # Shape check, only check non-mxfp4 assert hidden_states.shape[-1] == w1.shape[-2] @@ -127,26 +106,35 @@ def triton_kernel_fused_experts( (swiglu_alpha, swiglu_limit), 2) gammas = routing_data.gate_scal if routing_data else None - intermediate_cache1 = matmul_ogs( - hidden_states, - w1, - w1_bias, - routing_data, - gather_indx=gather_indx, - precision_config=w1_precision, - gammas=gammas if apply_router_weight_on_input else None, - fused_activation=act) + hidden_states = downcast_to_static_fp8(hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale) + + intermediate_cache1 = moe_gemm_a8w4(hidden_states, + w1.storage.data, + None, + quant_config.w1_precision.weight_scale.storage.data, + quant_config.w1_precision.flex_ctx.lhs_data.scale, + quant_config.w2_precision.flex_ctx.lhs_data.scale, + quant_config.w1_bias, routing_data, + gather_indx=gather_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + out_dtype=torch.float8_e4m3fn, + apply_swiglu=True, + alpha=swiglu_alpha, + limit=swiglu_limit) + + intermediate_cache3 = moe_gemm_a8w4(intermediate_cache1, + w2.storage.data, + None, + quant_config.w2_precision.weight_scale.storage.data, + quant_config.w2_precision.flex_ctx.lhs_data.scale, + None, + quant_config.w2_bias, + routing_data, + scatter_indx=scatter_indx, + gammas=None if apply_router_weight_on_input else gammas, + swizzle_mx_scale="CDNA4_SCALE") - intermediate_cache3 = matmul_ogs( - intermediate_cache1, - w2, - w2_bias, - routing_data, - scatter_indx=scatter_indx, - precision_config=w2_precision, - gammas=None if apply_router_weight_on_input else gammas, - y=output_tensor, - ) return intermediate_cache3 @@ -154,21 +142,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - quant_config, max_num_tokens: int, num_dispatchers: int, - w1_precision: "PrecisionConfig", - w2_precision: "PrecisionConfig", - w1_bias: Optional[torch.Tensor], - w2_bias: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, ): super().__init__(quant_config) self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers - self.w1_precision = w1_precision - self.w2_precision = w2_precision - self.w1_bias = w1_bias - self.w2_bias = w2_bias @property def activation_formats( @@ -212,12 +192,7 @@ def apply( activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -228,20 +203,12 @@ def apply( hidden_states, w1, w2, - None, - None, - None, + routing_data=None, + gather_indx=None, + 
scatter_indx=None, activation=activation, + quant_config=self.quant_config, apply_router_weight_on_input=False, - use_fp8_w8a8=False, - per_channel_quant=False, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=self.w1_bias, - w2_bias=self.w2_bias, - w1_precision=self.w1_precision, - w2_precision=self.w2_precision, - a1_scale=a1q_scale, - a2_scale=a2_scale) + a1q_scale=a1q_scale) \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9a19888f2ee1..26ff4f1888c5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Iterable from enum import Enum -from typing import Callable, Literal, Optional, overload +from typing import Callable, Literal, Optional, Union, overload import torch import torch.nn.functional as F @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm.config import get_current_vllm_config +from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import (get_dp_group, get_ep_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -21,7 +22,8 @@ from vllm.model_executor.custom_op import CustomOp # yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig) + FusedMoEConfig, FusedMoEParallelConfig, + FusedMoEQuantConfig) # yapf: enable from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, FusedMoEModularKernel, @@ -76,11 +78,11 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - # TODO(bnell): also pass quant_config? 
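A sketch of how the refactored BatchedOAITritonExperts above is now constructed, with the loose precision/bias arguments folded into a single quant config (the variable names here are illustrative):

experts = BatchedOAITritonExperts(
    max_num_tokens=moe.max_num_tokens,
    num_dispatchers=prepare_finalize.num_dispatchers(),
    quant_config=mxfp4_w4a16_moe_quant_config(
        w1_scale=w13_precision,  # triton_kernels PrecisionConfig
        w2_scale=w2_precision,
        w1_bias=layer.w13_bias,
        w2_bias=layer.w2_bias,
    ),
)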
def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe = moe - self.fused_experts: Optional[Callable] = None + self.moe_quant_config: Optional[FusedMoEQuantConfig] = None + self.fused_experts: Optional[FusedMoEModularKernel] = None self.topk_indices_dtype = None @abstractmethod @@ -101,23 +103,28 @@ def uses_weight_scale_2_pattern(self) -> bool: @staticmethod def _maybe_make_prepare_finalize( - moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]: + moe: FusedMoEConfig, + quant_config: Optional[FusedMoEQuantConfig], + ) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + # TODO: could allow this now assert not moe.use_flashinfer_cutlass_kernels, \ "Must be created in modelopt.py" if moe.use_pplx_kernels: + assert quant_config is not None + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, moe.hidden_dim, moe.in_dtype, - moe.quant_dtype, - per_act_token_quant=moe.per_act_token_quant, - block_shape=moe.block_shape, + quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, ) all_to_all_args = dict( @@ -163,6 +170,7 @@ def _maybe_make_prepare_finalize( ) elif moe.use_deepep_ll_kernels: + assert quant_config is not None all_to_all_args = dict( max_num_tokens_per_dp_rank=moe.max_num_tokens, token_hidden_size=moe.hidden_dim, @@ -172,13 +180,11 @@ def _maybe_make_prepare_finalize( all2all_manager.world_size) handle = all2all_manager.get_handle(all_to_all_args) - # Note : We may want to use FP8 dispatch even otherwise just to - # reduce datamovement - use_fp8_dispatch = (moe.quant_config is not None - and moe.quant_config.quant_dtype - == current_platform.fp8_dtype() - and moe.quant_config.block_shape - == DEEPEP_QUANT_BLOCK_SHAPE) + # Note: We may want to use FP8 dispatch just to reduce + # data movement. + use_fp8_dispatch = ( + quant_config.quant_dtype == current_platform.fp8_dtype() + and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE) prepare_finalize = DeepEPLLPrepareAndFinalize( handle, @@ -190,11 +196,10 @@ def _maybe_make_prepare_finalize( return prepare_finalize def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[FusedMoEPrepareAndFinalize]: - if moe.moe_parallel_config.use_all2all_kernels: - return FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + self) -> Optional[FusedMoEPrepareAndFinalize]: + if self.moe.moe_parallel_config.use_all2all_kernels: + return FusedMoEMethodBase._maybe_make_prepare_finalize( + self.moe, self.moe_quant_config) else: return None @@ -202,7 +207,13 @@ def maybe_make_prepare_finalize( # prepare_communication_buffer_for_model. def init_prepare_finalize(self, layer: torch.nn.Module): assert self.moe is not None - prepare_finalize = self.maybe_make_prepare_finalize(self.moe) + + # We must get the quant config here so that the layer is + # completely initialized, i.e. all weights loaded and post + # processed. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + + prepare_finalize = self.maybe_make_prepare_finalize() if prepare_finalize is not None: logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__, @@ -211,16 +222,16 @@ def init_prepare_finalize(self, layer: torch.nn.Module): assert self.fused_experts is None, \ f"Attempt to override experts for {id(self)}!" 
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, self.moe, layer) + experts = self.select_gemm_impl(prepare_finalize, layer) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, + layer.shared_experts, ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate @@ -229,6 +240,11 @@ def select_gemm_impl( f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") + @abstractmethod + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + raise NotImplementedError + @abstractmethod def apply( self, @@ -244,6 +260,7 @@ def apply( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -251,7 +268,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: raise NotImplementedError @@ -261,7 +278,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.has_bias = self.moe.has_bias self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -269,23 +285,30 @@ def __init__(self, moe: FusedMoEConfig): else: self.rocm_aiter_fused_experts = None # type: ignore + def maybe_make_prepare_finalize( + self) -> Optional[FusedMoEPrepareAndFinalize]: + if self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - # TODO(bnell): Remove. Every layer should have an moe config object. 
- moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, ) else: logger.debug("TritonExperts %s", self.moe) - return TritonExperts() + return TritonExperts(self.moe_quant_config) def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -299,7 +322,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - if self.has_bias: + if self.moe.has_bias: w13_bias = torch.nn.Parameter(torch.zeros( num_experts, 2 * intermediate_size_per_partition, @@ -316,7 +339,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - if self.has_bias: + if self.moe.has_bias: w2_bias = torch.nn.Parameter(torch.zeros(num_experts, hidden_size, dtype=params_dtype), @@ -400,6 +423,7 @@ def apply( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -407,7 +431,7 @@ def apply( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None @@ -427,6 +451,7 @@ def apply( expert_map=expert_map, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -436,6 +461,16 @@ def apply( logical_replica_count=logical_replica_count, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.moe.has_bias: + return biased_moe_quant_config( + layer.w13_bias, + layer.w2_bias, + ) + else: + return FUSED_MOE_UNQUANTIZED_CONFIG + def forward_cuda( self, layer: torch.nn.Module, @@ -450,6 +485,7 @@ def forward_cuda( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -457,7 +493,7 @@ def forward_cuda( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -469,6 +505,7 @@ def forward_cuda( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + 
routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, enable_eplb=enable_eplb, @@ -478,6 +515,7 @@ def forward_cuda( logical_replica_count=logical_replica_count) if self.rocm_aiter_moe_enabled: + assert self.fused_experts is None return self.rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -488,7 +526,7 @@ def forward_cuda( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) elif self.fused_experts is not None: - if self.has_bias: + if self.moe.has_bias: raise ValueError( "FusedMoEModularKernel does not support bias.") return self.fused_experts( @@ -509,12 +547,11 @@ def forward_cuda( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_bias=layer.w13_bias if self.has_bias else None, - w2_bias=layer.w2_bias if self.has_bias else None, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, + quant_config=self.moe_quant_config, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, @@ -534,6 +571,7 @@ def forward_cpu( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -541,7 +579,7 @@ def forward_cpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -560,6 +598,7 @@ def forward_cpu( expert_map, custom_routing_function, scoring_func, + routed_scaling_factor, e_score_correction_bias, apply_router_weight_on_input, activation, @@ -579,6 +618,7 @@ def forward_xpu( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -586,7 +626,7 @@ def forward_xpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -617,6 +657,7 @@ def forward_tpu( expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -624,7 +665,7 @@ def forward_tpu( expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None @@ -637,6 +678,9 @@ def forward_tpu( raise NotImplementedError( "Expert score correction bias 
is not supported for TPU.") assert activation == "silu", f"{activation} is not supported for TPU." + assert routed_scaling_factor == 1.0, \ + f"routed_scaling_factor {routed_scaling_factor} is not supported " \ + f"for TPU." if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -655,10 +699,75 @@ def forward_tpu( forward_native = forward_tpu elif current_platform.is_cpu(): forward_native = forward_cpu + elif current_platform.is_xpu(): + forward_native = forward_xpu else: forward_native = forward_cuda +def determine_expert_map( + ep_size: int, + ep_rank: int, + global_num_experts: int, + expert_placement_strategy: ExpertPlacementStrategy = "linear", +) -> tuple[int, Optional[torch.Tensor]]: + """ + Calculates how many experts should be assigned to each rank for EP and + creates a mapping from global to local expert index. Experts are + distributed evenly across ranks. Any remaining are assigned to the + last rank. + + Args: + ep_size: The size of the expert parallel group + ep_rank: The rank of the current process in the expert parallel + group + global_num_experts: The total number of experts in the model. + expert_placement_strategy: The expert placement strategy. + + Returns: + tuple[int, Optional[torch.Tensor]]: A tuple containing: + - local_num_experts (int): The number of experts assigned + to the current rank. + - expert_map (Optional[torch.Tensor]): A tensor of shape + (global_num_experts,) mapping from global to local index. + Contains -1 for experts not assigned to the current rank. + Returns None if ep_size is 1. + """ + assert ep_size > 0 + if ep_size == 1: + return (global_num_experts, None) + + # Distribute experts as evenly as possible to each rank. + base_experts = global_num_experts // ep_size + remainder = global_num_experts % ep_size + if ep_rank < remainder: + local_num_experts = base_experts + 1 + else: + local_num_experts = base_experts + + # Create a tensor of size num_experts filled with -1 + expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) + # Create an expert map for the local experts + if expert_placement_strategy == "linear": + start_idx = ep_rank * base_experts + min(ep_rank, remainder) + expert_map[start_idx:start_idx + local_num_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32) + elif expert_placement_strategy == "round_robin": + local_log_experts = torch.arange(ep_rank, + global_num_experts, + ep_size, + dtype=torch.int32) + + expert_map[local_log_experts] = torch.arange(0, + local_num_experts, + dtype=torch.int32) + else: + raise ValueError("Unsupported expert placement strategy " + f"'{expert_placement_strategy}', expected one of " + f"{get_args(ExpertPlacementStrategy)}") + return (local_num_experts, expert_map) + + def determine_expert_map( ep_size: int, ep_rank: int, global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]: @@ -740,7 +849,7 @@ class FusedMoE(CustomOp): intermediate_size: Intermediate size of the experts params_dtype: Data type for the parameters. reduce_results: Whether to all all_reduce on the output of the layer - renomalize: Whether to renormalize the logits in the fused_moe kernel + renormalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. enable_eplb: Whether to enable expert parallelism load balancer. 
""" @@ -764,12 +873,14 @@ def __init__( prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, + is_sequence_parallel=False, ): super().__init__() if params_dtype is None: @@ -781,6 +892,10 @@ def __init__( dp_size_ = (dp_size if dp_size is not None else get_dp_group().world_size) + self.is_sequence_parallel = is_sequence_parallel + if self.is_sequence_parallel: + self.sp_size = tp_size_ + vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( @@ -790,13 +905,19 @@ def __init__( self.global_num_experts = num_experts + num_redundant_experts - # we padding globally so EP buffer allocation works - if quant_config and quant_config.get_name() == "mxfp4": - from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 - should_use_flashinfer_mxfp4) - if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: - hidden_pad = round_up(hidden_size, 256) - hidden_size - if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): + # we are padding globally so EP buffer allocation works + self._is_mxfp4 = self.is_mxfp4_quant(quant_config=quant_config) + if self._is_mxfp4: + from vllm.model_executor.layers.quantization.mxfp4 import ( + Mxfp4Backend, get_mxfp4_backend) + current_mxfp4_backend = get_mxfp4_backend() + if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + or current_mxfp4_backend + == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS): + hidden_size = round_up(hidden_size, 128) + elif (current_platform.is_rocm() or current_mxfp4_backend + == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or + current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): hidden_size = round_up(hidden_size, 256) # For smuggling this layer into the fused moe custom op @@ -820,15 +941,36 @@ def __init__( else: assert num_redundant_experts == 0, \ "Redundant experts are only supported with EPLB." + + expert_placement_strategy = ( + vllm_config.parallel_config.expert_placement_strategy) + if expert_placement_strategy == "round_robin": + # TODO(Bruce): will support round robin expert placement with + # EPLB enabled in the future. + round_robin_supported = ((num_expert_group is not None + and num_expert_group > 1) + and num_redundant_experts == 0 + and not self.enable_eplb) + + if not round_robin_supported: + logger.warning( + "Round-robin expert placement is only supported for " + "models with multiple expert groups and no redundant " + "experts. Falling back to linear expert placement.") + expert_placement_strategy = "linear" + self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, - global_num_experts=self.global_num_experts) + global_num_experts=self.global_num_experts, + expert_placement_strategy=expert_placement_strategy, + ) logger.info_once( - "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + "[EP Rank %s/%s] Expert parallelism is enabled. Expert " + "placement strategy: %s. Local/global" " number of experts: %s/%s. 
Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, self.local_num_experts, - self.global_num_experts, + " %s.", self.ep_rank, self.ep_size, expert_placement_strategy, + self.local_num_experts, self.global_num_experts, get_compressed_expert_map(self.expert_map)) else: self.local_num_experts, self.expert_map = (self.global_num_experts, @@ -848,6 +990,7 @@ def __init__( self.topk_group = topk_group self.custom_routing_function = custom_routing_function self.scoring_func = scoring_func + self.routed_scaling_factor = routed_scaling_factor self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation @@ -863,16 +1006,18 @@ def __init__( # since model_config is not set in the pytest test. model_dtype = params_dtype - moe = FusedMoEConfig.make(num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=model_dtype, - max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config=quant_config, - has_bias=has_bias) + moe = FusedMoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=model_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + has_bias=has_bias, + ) self.moe_config = moe + self.moe_quant_config: Optional[FusedMoEQuantConfig] = None self.quant_config = quant_config # Note: get_quant_method will look at the layer's local_num_experts @@ -914,30 +1059,44 @@ def __init__( "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): moe_quant_params["intermediate_size_full"] = intermediate_size - - # need pad hidden_size for ROCM mxfp4 - if (self.quant_method.__class__.__name__ == "MXFP4MoEMethod" - and current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4): - moe_quant_params["hidden_pad"] = hidden_pad self.quant_method.create_weights(layer=self, **moe_quant_params) # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None + + # TODO(bnell): flashinfer uses non-batched format. + # Does it really need a batched buffer? 
if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels - or self.moe_parallel_config.use_flashinfer_cutlass_kernels): - self.batched_hidden_states = torch.zeros( - (moe.max_num_tokens, self.hidden_size), - dtype=moe.in_dtype, - device=torch.cuda.current_device()) - - # Note here we use `num_experts` which is logical expert count - self.batched_router_logits = torch.zeros( - (moe.max_num_tokens, num_experts), - dtype=moe.in_dtype, - device=torch.cuda.current_device()) + or self.moe_config.use_flashinfer_cutlass_kernels): + if vllm_config.parallel_config.enable_dbo: + self.batched_hidden_states = torch.zeros( + (2, moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + + # Note here we use `num_experts` which is logical expert count + self.batched_router_logits = torch.zeros( + (2, moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + else: + self.batched_hidden_states = torch.zeros( + (moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + + # Note here we use `num_experts` which is logical expert count + self.batched_router_logits = torch.zeros( + (moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return None @property def tp_size(self): @@ -981,7 +1140,9 @@ def use_deepep_ll_kernels(self): @property def use_flashinfer_cutlass_kernels(self): - return self.moe_parallel_config.use_flashinfer_cutlass_kernels + return (self.moe_quant_config is not None + and self.moe_quant_config.quant_dtype == "nvfp4" + and self.moe_config.use_flashinfer_cutlass_kernels) def update_expert_map(self): # ep_size and ep_rank should already be updated @@ -1153,16 +1314,47 @@ def weight_loader(self, expert_id: int, return_success: bool = False) -> Optional[bool]: - if self.quant_config and self.quant_config.get_name() == "mxfp4": - # (FIXME) for gpt-oss all experts are combined - if "bias" in weight_name: - dim1 = loaded_weight.shape[1] - param.data[:, :dim1].copy_(loaded_weight) - else: - dim1 = loaded_weight.shape[1] - dim2 = loaded_weight.shape[2] - param.data[:, :dim1, :dim2].copy_(loaded_weight) - return True if return_success else None + if self._is_mxfp4: + if self.quant_config.get_name() == "mxfp4": + # (FIXME) for gpt-oss all experts are combined + if "bias" in weight_name: + dim1 = loaded_weight.shape[1] + param.data[:, :dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[1] + dim2 = loaded_weight.shape[2] + param.data[:, :dim1, :dim2].copy_(loaded_weight) + return True if return_success else None + elif self.quant_config.get_name() == "quark": + # When self._is_mxfp4 is true, model_dtype must be gpt_oss + expert_data = param.data[expert_id] + if "input_scale" in weight_name: + assert loaded_weight.numel() == 1 + expert_data.data.copy_(loaded_weight) + return True if return_success else None + + shard_dim = 0 if shard_id in ( + "w1", "w3") or "bias" in weight_name else 1 + if shard_id == "w2": + shard_size = loaded_weight.shape[shard_dim] // self.tp_size + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * self.tp_rank, shard_size) + if "bias" in weight_name: + dim1 = loaded_weight.shape[0] + expert_data.data[:dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[0] + dim2 = loaded_weight.shape[1] + expert_data.data[:dim1, :dim2].copy_(loaded_weight) + elif shard_id is None: + if 
"bias" in weight_name: + dim1 = loaded_weight.shape[0] + expert_data.data[:dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[0] + dim2 = loaded_weight.shape[1] + expert_data.data[:dim1, :dim2].copy_(loaded_weight) + return True if return_success else None expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: @@ -1390,7 +1582,8 @@ def get_expert_weights(self) -> Iterable[torch.Tensor]: return [ weight.view(self.local_num_experts, -1) for name, weight in weights - if name not in NON_EXPERT_WEIGHTS + if name not in NON_EXPERT_WEIGHTS and weight.shape != torch.Size( + []) and not name.startswith("_shared_experts.") ] def set_eplb_state( @@ -1410,6 +1603,11 @@ def set_eplb_state( self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.logical_replica_count = logical_replica_count[moe_layer_idx] + def ensure_moe_quant_config(self): + if self.quant_method.moe_quant_config is None: + self.quant_method.moe_quant_config = ( + self.quant_method.get_fused_moe_quant_config(self)) + @staticmethod def select_experts( hidden_states: torch.Tensor, @@ -1421,6 +1619,7 @@ def select_experts( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, enable_eplb: bool = False, @@ -1465,6 +1664,7 @@ def select_experts( num_expert_group=num_expert_group, topk_group=topk_group, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) @@ -1571,25 +1771,52 @@ def maybe_all_reduce_tensor_model_parallel( else: return tensor_model_parallel_all_reduce(final_hidden_states) - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_native( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: og_hidden_states = hidden_states.shape[-1] if self.hidden_size != og_hidden_states: hidden_states = F.pad(hidden_states, (0, self.hidden_size - og_hidden_states), mode='constant', value=0.0) - # TODO: Once the OOM issue for the TPU backend is resolved, we will - # switch to using the moe_forward custom op. - if current_platform.is_tpu(): - return self.forward_impl(hidden_states, router_logits) + + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) + else: + fused_output = torch.ops.vllm.moe_forward( + hidden_states, router_logits, self.layer_name) + return fused_output[..., :og_hidden_states] else: - return torch.ops.vllm.moe_forward( - hidden_states, router_logits, - self.layer_name)[..., :og_hidden_states] + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. 
+ shared_output, fused_output = self.forward_impl( + hidden_states, router_logits) + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name) + return (shared_output[..., :og_hidden_states], + fused_output[..., :og_hidden_states]) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(hidden_states, router_logits) - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): + def forward_impl_chunked( + self, + full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.batched_hidden_states is not None assert self.batched_router_logits is not None assert self.batched_hidden_states.dtype == full_hidden_states.dtype @@ -1600,21 +1827,41 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, assert ( self.batched_router_logits.size(-1) == full_router_logits.size(-1)) - full_final_hidden_states = torch.empty_like(full_hidden_states) + self.ensure_moe_quant_config() + + full_fused_final_hidden_states = torch.empty_like(full_hidden_states) + if self.shared_experts is not None: + full_shared_final_hidden_states = torch.empty_like( + full_hidden_states) def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_size = chunk_end - chunk_start hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - assert (self.batched_hidden_states.size(0) # type: ignore + assert self.batched_hidden_states is not None + assert self.batched_router_logits is not None + # This is only true when DBO has been enabled in the config. 
+ # Both tensors will have an outer dimension for the ubatch id + if self.batched_hidden_states.dim() == 3: + assert self.batched_router_logits.dim() == 3 + batch_buffer_idx = dbo_current_ubatch_id() + batched_hidden_states = self.batched_hidden_states[ + batch_buffer_idx, :] + batched_router_logits = self.batched_router_logits[ + batch_buffer_idx, :] + else: + batched_hidden_states = self.batched_hidden_states + batched_router_logits = self.batched_router_logits + + assert (batched_hidden_states.size(0) # type: ignore >= chunk_size) - assert (self.batched_router_logits.size(0) # type: ignore + assert (batched_router_logits.size(0) # type: ignore >= chunk_size) - staged_hidden_states = self.batched_hidden_states[: - chunk_size, :] # type: ignore - staged_router_logits = self.batched_router_logits[: - chunk_size, :] # type: ignore + staged_hidden_states = batched_hidden_states[: + chunk_size, :] # type: ignore + staged_router_logits = batched_router_logits[: + chunk_size, :] # type: ignore staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) @@ -1632,6 +1879,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, enable_eplb=self.enable_eplb, @@ -1640,20 +1888,40 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): logical_replica_count=self.logical_replica_count, ) + assert self.shared_experts is None or isinstance( + final_hidden_states, tuple) + if not skip_result_store: - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states, non_blocking=True) + if self.shared_experts is None: + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states, + non_blocking=True) + else: + full_shared_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[0], + non_blocking=True) + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[1], + non_blocking=True) ctx = get_forward_context() # flashinfer_cutlass_kernels can handle: optional DP + TP/EP - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens + + # If the input to the MoE is sequence parallel then divide by sp_size + # to find the maximum number of tokens for any individual dispatcher. 
+ if self.is_sequence_parallel: + max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers, + self.sp_size) + num_tokens = full_hidden_states.size(0) for chunk_idx, chunk_start_ in enumerate( - range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)): + range(0, max_tokens_across_dispatchers, + moe_dp_chunk_size_per_rank)): chunk_start = chunk_start_ chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, - max_tokens_across_dp) + max_tokens_across_dispatchers) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) @@ -1663,25 +1931,45 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_end, skip_result_store=chunk_start_ >= num_tokens) - return full_final_hidden_states + if self.shared_experts is None: + return full_fused_final_hidden_states + else: + return (full_shared_final_hidden_states, + full_fused_final_hidden_states) - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.quant_method is not None + + self.ensure_moe_quant_config() + # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. - use_flashinfer_cutlass_kernels = ( - self.dp_size > 1 - and self.moe_parallel_config.use_flashinfer_cutlass_kernels) + _use_flashinfer_cutlass_kernels = (self.dp_size > 1 and + self.use_flashinfer_cutlass_kernels) + if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels - or use_flashinfer_cutlass_kernels): + or _use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels - and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) + and not self.moe_config.use_flashinfer_cutlass_kernels) + + # If there are shared experts but we are not using a modular kernel, the + # shared experts must be called here + if (not isinstance(self.quant_method.fused_experts, + FusedMoEModularKernel) + and self.shared_experts is not None): + shared_output = self.shared_experts(hidden_states) + else: + shared_output = None + if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) @@ -1700,6 +1988,7 @@ def forward_impl(self, hidden_states: torch.Tensor, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, @@ -1709,14 +1998,32 @@ def forward_impl(self, hidden_states: torch.Tensor, logical_replica_count=self.logical_replica_count, ) - if do_naive_dispatch_combine: - final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs. 
- final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( - final_hidden_states) + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) - return final_hidden_states + def reduce_output(states: torch.Tensor, + do_combine: bool = True) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + states = self.maybe_all_reduce_tensor_model_parallel(states) + + return states + + if self.shared_experts is None: + assert not isinstance(final_hidden_states, tuple) + return reduce_output(final_hidden_states) + else: + return ( + reduce_output(final_hidden_states[0], do_combine=False), + reduce_output(final_hidden_states[1]), + ) @classmethod def make_expert_params_mapping( @@ -1770,6 +2077,22 @@ def extra_repr(self) -> str: return s + def is_mxfp4_quant(self, + quant_config: Optional[QuantizationConfig] = None + ) -> bool: + name = quant_config.get_name() if quant_config else None + if name == "mxfp4": + return True + elif name == "quark": + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + model_type = getattr(vllm_config.model_config.hf_config, + "model_type", None) + # Padding for triton kernel only is enabled when it is gpt_oss + return quant_config.is_global_mxfp4 and model_type == "gpt_oss" + else: + return False + def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, layer_name: str) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 9d09c46245aa..290a39610848 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum from typing import Callable, Optional import torch @@ -11,6 +12,9 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.fused_moe import modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config, + mxfp4_w4a16_moe_quant_config) from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -33,6 +37,75 @@ logger = init_logger(__name__) +# enum for mxfp4 backend +class Mxfp4Backend(Enum): + NONE = 0 + + # FlashInfer Backend + SM100_FI_MXFP4_MXFP8_TRTLLM = 1 + SM100_FI_MXFP4_MXFP8_CUTLASS = 2 + SM100_FI_MXFP4_BF16 = 3 + SM90_FI_MXFP4_BF16 = 4 + + # Marlin Backend + MARLIN = 5 + + # Triton Backend + TRITON = 6 + + +def get_mxfp4_backend(): + # Backend Selection + if current_platform.is_cuda(): + if (current_platform.is_device_capability(90) and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") + return Mxfp4Backend.SM90_FI_MXFP4_BF16 + elif (current_platform.is_device_capability(100) and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS): + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + elif (current_platform.is_device_capability(100) and 
has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, " + "for high concurrency throughput workloads consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better " + "performance") + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + elif current_platform.is_device_capability(100) and has_flashinfer(): + logger.info_once( + "Using FlashInfer MXFP4 BF16 backend for SM100, " + "For faster performance on SM100, consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " + "accuracy.") + return Mxfp4Backend.SM100_FI_MXFP4_BF16 + elif ((current_platform.is_device_capability(100) + or current_platform.is_device_capability(90)) + and not has_flashinfer()): + logger.warning_once( + "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer " + "is not available. This may result in degraded performance. " + "Please `pip install vllm[flashinfer]` for best results.") + + # If FlashInfer is not available, try either Marlin or Triton + if current_platform.get_device_capability( + )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer( + "2.8.0"): + logger.info_once("Using Marlin backend") + return Mxfp4Backend.MARLIN + else: + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON + elif current_platform.is_rocm() and has_triton_kernels(): + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON + + return Mxfp4Backend.NONE + + + def _should_use_flashinfer_mxfp4_bf16(): """Determine if FlashInfer MXFP4 BF16 should be used.""" # If explicitly set, respect the setting @@ -468,6 +541,31 @@ def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): return tile_tokens_dim + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + + if self.mxfp4_backend == Mxfp4Backend.MARLIN: + return None + + if self.mxfp4_backend == Mxfp4Backend.TRITON: + w1_scale = self.w13_precision_config + w2_scale = self.w2_precision_config + return mxfp4_w4a16_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + else: + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 56224cbbcdf0..49892f9604ae 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -41,6 +41,17 @@ def __init__(self, self.kv_cache_group = kv_cache_group self.kv_cache_config = kv_cache_config self.pack_method = pack_method + self._is_global_mxfp4() + + def _is_global_mxfp4(self): + # Check if it is MXFP4 to determine if pre-padding should be applied. + # This must be created during the initialization of moe. 
+ global_quant_config = cast( + dict[str, Any], self.quant_config.get("global_quant_config")) + weight_quant = global_quant_config.get("weight") + input_quant = global_quant_config.get("input_tensors") + self.is_global_mxfp4 = self._is_mx_fp4(weight_quant=weight_quant, + input_quant=input_quant) def get_linear_method(self) -> "QuarkLinearMethod": return QuarkLinearMethod(self) @@ -225,42 +236,39 @@ def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]], def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]], input_quant: Optional[dict[str, Any]]) -> bool: - # Confirm weights and input quantized. - if weight_quant is None or input_quant is None: + # Confirm weights quantized. + if weight_quant is None: logger.debug("Quark model is not in MX-FP4 format: " - "weight_quant or input_quant not set") + "weight_quant not set") return False # Input and weight dtype needs to be fp4. - if weight_quant.get("dtype") != "fp4" or input_quant.get( - "dtype") != "fp4": - logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") + if weight_quant.get("dtype") != "fp4": + logger.debug("Quark model is not in MX-FP4 format: weight dtype not fp4") return False # Input and weight qscheme needs to be per group. - if weight_quant.get("qscheme") != "per_group" or input_quant.get( - "qscheme") != "per_group": + if weight_quant.get("qscheme") != "per_group": logger.debug("Quark model is not in MX-FP4 format: not per_group") return False # Input and weight group size needs to be 32. - if weight_quant.get("group_size") != 32 or input_quant.get( - "group_size") != 32: + if weight_quant.get("group_size") != 32: logger.debug( "Quark model is not in MX-FP4 format: not group_size=32") return False - # Activations need to use dynamic quantization. - if input_quant.get("is_dynamic") is False: + # Activations and weight scales need to be in e8m0 format. + if weight_quant.get("scale_format") != "e8m0": logger.debug( - "Quark model is not in MX-FP4 format: not activation dynamic") + "Quark model is not in MX-FP4 format: not scale_format e8m0") return False - # Activations and weight scales need to be in e8m0 format. - if weight_quant.get("scale_format") != "e8m0" or input_quant.get( - "scale_format") != "e8m0": + # Input dtype needs to be one of {'fp4', 'fp6_e2m3', 'fp8_e4m3'}. 
+ if input_quant.get("dtype") not in ("fp4", "fp6_e2m3", "fp8_e4m3"): logger.debug( - "Quark model is not in MX-FP4 format: not scale_format e8m0") + "Quark model is not in MX-FP4 format: expected input dtype " + "to be one of {'fp4', 'fp6_e2m3', 'fp8_e4m3'}") return False return True diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index e4771056cc4e..335fb3eb518b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,21 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch +from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + mxfp4_w4a4_moe_quant_config) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - OCP_MX_BLOCK_SIZE) + OCP_MX_BLOCK_SIZE, _swizzle_mxfp4) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.utils import round_up logger = init_logger(__name__) @@ -49,8 +54,16 @@ def get_moe_method( return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) elif quant_config._is_mx_fp4(weight_config, input_config): - return QuarkW4A4MXFp4MoEMethod(weight_config, input_config, - module.moe_config) + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + model_type = getattr(vllm_config.model_config.hf_config, + "model_type", None) + if model_type == "gpt_oss": + return QuarkW4MXFp4MoEMethod_OSS(weight_config, input_config, + module.moe_config) + else: + return QuarkW4MXFp4MoEMethod(weight_config, input_config, + module.moe_config) else: raise RuntimeError("Unsupported FusedMoe scheme") @@ -265,7 +278,7 @@ def apply( activation=activation) -class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): +class QuarkW4MXFp4MoEMethod(QuarkMoEMethod): def __init__( self, @@ -290,7 +303,7 @@ def __init__( if self.static_input_scales: raise NotImplementedError( - "QuarkW4A4MXFp4MoEMethod with static input scales is currently " + "QuarkW4MXFp4MoEMethod with static input scales is currently " "not implemented. 
Please open an issue.") if not current_platform.supports_mx(): @@ -386,7 +399,7 @@ def apply( if enable_eplb: raise NotImplementedError( - "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.") + "EPLB not supported for `QuarkW4MXFp4MoEMethod` yet.") from vllm.model_executor.layers.fused_moe import fused_experts @@ -422,3 +435,294 @@ def apply( activation=activation, ) return out + + +class QuarkW4MXFp4MoEMethod_OSS(QuarkMoEMethod): + + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(moe) + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_group"): + raise ValueError( + "For MX(FP4) Fused MoE layers, only per-group scales " + "for weights and activations are supported. Found " + f"{weight_qscheme}, {input_qscheme}") # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + if not current_platform.supports_mx(): + self.emulate = True + logger.warning_once( + "The current platform does not support native MXFP4 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + else: + self.emulate = True + logger.warning_once( + "The current platform supports native MXFP4 " + "computation, but kernels are not yet integrated in vLLM. " + "Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + self.num_experts = num_experts + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + mxfp4_block = 32 + weight_dtype = torch.uint8 + weight_scale_dtype = torch.uint8 + per_tensor_fp8_act_scale_dtype = torch.bfloat16 + self.intermediate_size_per_partition = intermediate_size_per_partition + intermediate_size_per_partition_after_pad = \ + intermediate_size_per_partition + + if current_platform.is_rocm(): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 256) # 2880 -> 2944 + else: + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 64) + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // mxfp4_block, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + 
set_weight_attrs(w13_bias, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // mxfp4_block, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + dtype=per_tensor_fp8_act_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + w2_input_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + dtype=per_tensor_fp8_act_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def process_weights_after_loading(self, layer): + + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + w13_bias = layer.w13_bias.to(torch.float32) + w2_bias = layer.w2_bias.to(torch.float32) + + layer.w13_bias = Parameter(w13_bias, requires_grad=False) + layer.w2_bias = Parameter(w2_bias, requires_grad=False) + + # FIXME warp need to be adjusted based on batch size + # only apply to batched mode + if self.moe.use_ep: + num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 + else: + num_warps = 8 + + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, layer.w13_weight_scale, num_warps) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(layer.w2_weight, + layer.w2_weight_scale, + num_warps) + + self.w13_weight_triton_tensor = w13_weight + self.w2_weight_triton_tensor = w2_weight + + # need to delete the original weights to save memory on single GPU + del layer.w13_weight + del layer.w2_weight + layer.w13_weight = None + layer.w2_weight = None + torch.cuda.empty_cache() + + if self.static_input_scales: # wmxfp4 a fp8 pertensor + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. 
Using the maximum across experts " + "for each layer.") + # layer.w13_input_scale = torch.nn.Parameter( + # layer.w13_input_scale.max(), requires_grad=False) + # layer.w2_input_scale = torch.nn.Parameter( + # layer.w2_input_scale.max(), requires_grad=False) + + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max().to(torch.float32), + requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max().to(torch.float32), + requires_grad=False) + + from triton_kernels.numerics import InFlexData + lhs_data13 = InFlexData(scale=layer.w13_input_scale) + lhs_data2 = InFlexData(scale=layer.w2_input_scale) + + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, + flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13)) + self.w2_precision_config = PrecisionConfig(weight_scale=w2_scale, + flex_ctx=FlexCtx( + rhs_data=w2_flex, + lhs_data=lhs_data2)) + else: + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)) + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + w1_scale = self.w13_precision_config + w2_scale = self.w2_precision_config + + if self.static_input_scales: + # TODO: how to set scale? + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + + else: + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + assert self.fused_experts is None + + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet.") + + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 + triton_kernel_moe_forward) + + return triton_kernel_moe_forward( + hidden_states=x, + w1=self.w13_weight_triton_tensor, + w2=self.w2_weight_triton_tensor, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index a652d4f47b62..c83bd3f2283e 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from collections.abc import Iterable +import typing +from 
collections.abc import Callable, Iterable from typing import Optional import torch @@ -31,6 +32,7 @@ from vllm.utils import cdiv from vllm.platforms import current_platform from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, + is_pp_missing_parameter, maybe_prefix) import os @@ -587,6 +589,165 @@ def _load_weights_other( loaded_params.add(name) return loaded_params + def _load_weights_quark( + self, + ep_rank_start: int, + ep_rank_end: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + use_ep = self.parallel_config.enable_expert_parallel + if use_ep: + pass # TODO: error + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + per_rank_intermediate_size = cdiv(intermediate_size, tp_size) + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = loaded_weight.narrow(0, head_start, + heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + + # mapping to convert individual experts input_scale into fused_moe. + if "input_scale" in name: # w2 w13 input_scale + parts = name.split(".") + expert_id = int(parts[-2]) + name = ".".join(parts[:-2] + parts[-1:]) + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + loaded_weight, + weight_name=name, + shard_id=None, + expert_id=expert_id) + loaded_params.add(name) + continue + + # mapping to convert weight and bias of individual + # experts gate_up_proj into fused_moe. + if ".w13" in name: + parts = name.split(".") + expert_id = int(parts[-2]) + name = ".".join(parts[:-2] + parts[-1:]) + if use_ep: + narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, + ...] + else: + narrow_weight = loaded_weight[2 * tp_rank_start:2 * + tp_rank_end, ...] + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=expert_id) + loaded_params.add(name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." 
in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_local_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -616,6 +777,10 @@ def load_weights(self, weights: Iterable[tuple[str, return self._load_weights_mxfp4(ep_rank_end, ep_rank_start, heads_per_rank, head_start, weights, stacked_params_mapping) + elif quant_method == "quark": + return self._load_weights_quark(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) else: return self._load_weights_other(ep_rank_end, ep_rank_start, heads_per_rank, head_start, @@ -648,6 +813,13 @@ class GptOssForCausalLM(nn.Module): # MoE Bias ".gate_up_proj_bias": ".w13_bias", ".down_proj_bias": ".w2_bias", + + # For quark format + ".gate_up_proj.weight": ".w13_weight", + ".gate_up_proj.weight_scale": ".w13_weight_scale", + ".gate_up_proj.bias": ".w13_bias", + ".gate_up_proj.input_scale": ".w13_input_scale", + ".down_proj.input_scale": ".w2_input_scale" }, ) @@ -692,4 +864,7 @@ def load_weights(self, weights: Iterable[tuple[str, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) \ No newline at end of file + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() \ No newline at end of file
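
Editor's note: the chunked MoE forward path in this diff has to cope with two staging-buffer layouts — a 2-D buffer when dual-batch overlap (DBO) is off, and a 3-D buffer with a leading micro-batch dimension when it is on. The snippet below is a minimal, standalone sketch of that selection logic only; select_staging_buffers and the current_ubatch_id argument are hypothetical stand-ins for the layer state and for vLLM's dbo_current_ubatch_id(), not the actual implementation.

# Standalone sketch (not the vLLM code): picking the per-micro-batch slice of
# a double-buffered staging tensor before copying a chunk into it.
import torch

def select_staging_buffers(batched_hidden_states: torch.Tensor,
                           batched_router_logits: torch.Tensor,
                           current_ubatch_id: int,
                           chunk_size: int):
    # A 3-D buffer means DBO allocated one slice per micro-batch:
    # shape (2, max_num_tokens, hidden_size). A 2-D buffer is the
    # single-batch case: shape (max_num_tokens, hidden_size).
    if batched_hidden_states.dim() == 3:
        hidden = batched_hidden_states[current_ubatch_id]
        logits = batched_router_logits[current_ubatch_id]
    else:
        hidden = batched_hidden_states
        logits = batched_router_logits
    assert hidden.size(0) >= chunk_size
    assert logits.size(0) >= chunk_size
    # Only the first `chunk_size` rows are staged for this chunk.
    return hidden[:chunk_size], logits[:chunk_size]

# Example: DBO enabled, max_num_tokens=8, hidden_size=16, 4 logical experts.
hs = torch.zeros(2, 8, 16)
rl = torch.zeros(2, 8, 4)
staged_h, staged_l = select_staging_buffers(hs, rl, current_ubatch_id=1, chunk_size=5)
print(staged_h.shape, staged_l.shape)  # torch.Size([5, 16]) torch.Size([5, 4])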
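
Editor's note: the relaxed Quark MX-FP4 detection in this diff only constrains the weight config (fp4, per-group, group size 32, e8m0 scales) and accepts fp4, fp6_e2m3 or fp8_e4m3 activations. Below is a minimal standalone predicate that captures that rule for illustration; unlike the patched method it also guards against a missing input config, which is an added assumption on my part rather than something the diff does.

# Standalone sketch of the relaxed MX-FP4 check implied by the quark.py change.
from typing import Any, Optional

def is_mx_fp4(weight_quant: Optional[dict[str, Any]],
              input_quant: Optional[dict[str, Any]]) -> bool:
    # Weights must be OCP MX-FP4: fp4 dtype, per-group, group size 32,
    # with e8m0 scales.
    if weight_quant is None:
        return False
    if weight_quant.get("dtype") != "fp4":
        return False
    if weight_quant.get("qscheme") != "per_group":
        return False
    if weight_quant.get("group_size") != 32:
        return False
    if weight_quant.get("scale_format") != "e8m0":
        return False
    # Activations may be fp4, fp6_e2m3 or fp8_e4m3 (extra None guard here).
    if input_quant is None or input_quant.get("dtype") not in (
            "fp4", "fp6_e2m3", "fp8_e4m3"):
        return False
    return True

# e.g. MXFP4 weights with static FP8 activations (gpt-oss style Quark config)
w = {"dtype": "fp4", "qscheme": "per_group", "group_size": 32,
     "scale_format": "e8m0"}
a = {"dtype": "fp8_e4m3", "is_dynamic": False}
assert is_mx_fp4(w, a)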
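
Editor's note: in the Quark gpt-oss loader added above, the fused gate_up_proj (w13) checkpoint weight has a leading dimension of 2 * intermediate_size, so each tensor-parallel rank copies a slice spanning twice its per-rank intermediate width. The sketch below illustrates only that index arithmetic, under the assumption of the same 2*start:2*end slicing as the diff; slice_w13_for_rank and the local cdiv are hypothetical helpers, and expert parallelism and padding are ignored.

# Standalone sketch of the w13 tensor-parallel slicing arithmetic.
import torch

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

def slice_w13_for_rank(w13: torch.Tensor, intermediate_size: int,
                       tp_rank: int, tp_size: int) -> torch.Tensor:
    per_rank = cdiv(intermediate_size, tp_size)
    start = tp_rank * per_rank
    end = min((tp_rank + 1) * per_rank, intermediate_size)
    # Dim 0 of the fused gate/up weight is 2 * intermediate_size, so the
    # rank's slice covers 2x its share of intermediate channels.
    return w13[2 * start:2 * end, ...]

# e.g. intermediate_size=2880, tp_size=2: rank 1 takes rows 2880..5759,
# i.e. 2880 rows of the fused weight for its 1440 intermediate channels.
w13 = torch.randn(2 * 2880, 64)
print(slice_w13_for_rank(w13, 2880, tp_rank=1, tp_size=2).shape)  # (2880, 64)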