diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 9d09c46245aa..b2e34c286773 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -63,7 +63,7 @@ def should_use_flashinfer_mxfp4():
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4:
     import aiter
-    from aiter.fused_moe import fused_topk, moe_sorting
+    from aiter.fused_moe import fused_moe, fused_topk, moe_sorting
     from aiter.ops.shuffle import shuffle_mxfp4_weight, shuffle_mxfp4_scale
 
 class Mxfp4Config(QuantizationConfig):
@@ -690,51 +690,29 @@ def apply(
         token_num = x.shape[0]
         BLOCKM = 16 if token_num < 2048 else 32
         topk_weights, topk_ids = fused_topk(x, router_logits, top_k, True)
-        sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_out = moe_sorting(
-            topk_ids,
-            topk_weights,
-            self.num_experts,
-            x.shape[1],
-            torch.bfloat16,
-            BLOCKM
-        )
-        _, n1, k1 = self.w13_weight_aiter_tensor.shape
-        _, k2, n2 = self.w2_weight_aiter_tensor.shape
-        D = n2 if k2 == k1 else n2*2
-        cktile_moe_out1 = torch.empty((token_num, top_k, D), dtype=torch.bfloat16, device=x.device)
-        aiter.moe_cktile2stages_gemm1(
+        return fused_moe(
             x,
             self.w13_weight_aiter_tensor,
-            cktile_moe_out1,
-            sorted_ids,
-            sorted_expert_ids,
-            num_valid_ids,
-            top_k,
-            self.intermediate_pad // 64 * 64 * 2,
-            self.hidden_pad // 128 * 128,  # k_pad_zeros
-            None,  # sorted_weights
-            None,
-            self.w13_scale_aiter_tensor,
-            self.w13_bias_aiter_tensor,
-            BLOCKM,  # block_size
-        )
-        aiter.moe_cktile2stages_gemm2(
-            cktile_moe_out1,
             self.w2_weight_aiter_tensor,
-            moe_out,
-            sorted_ids,
-            sorted_expert_ids,
-            num_valid_ids,
-            top_k,
-            self.hidden_pad // 64 * 64,  # n_pad_zeros
-            self.intermediate_pad // 128 * 128,
-            sorted_weights,  # sorted_weights
-            None,
-            self.w2_scale_aiter_tensor,
-            layer.w2_bias,
-            BLOCKM,  # block_size
+            topk_weights,
+            topk_ids,
+            expert_mask=None,
+            activation=aiter.ActivationType.Swiglu,
+            quant_type=aiter.QuantType.per_1x32,
+            doweight_stage1=False,
+            w1_scale=self.w13_scale_aiter_tensor,
+            w2_scale=self.w2_scale_aiter_tensor,
+            a1_scale=None,
+            a2_scale=None,
+            block_size_M=BLOCKM,
+            num_local_tokens=None,
+            moe_sorting_dispatch_policy=0,
+            dtype=None,
+            hidden_pad=self.hidden_pad,
+            intermediate_pad=self.intermediate_pad,
+            bias1=self.w13_bias_aiter_tensor,
+            bias2=layer.w2_bias,
         )
-        return moe_out
 
         from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
             triton_kernel_moe_forward)
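
For reviewers unfamiliar with the consolidated path: the patch drops the hand-rolled pipeline (moe_sorting, an explicit intermediate allocation, and the two moe_cktile2stages_gemm calls) and lets aiter's fused_moe own sorting, both GEMM stages, the SwiGLU activation, and the weighted combine. Below is a minimal standalone sketch of the new call, built only from the arguments visible in this diff; the wrapper name mxfp4_moe_forward and its parameter list are illustrative, and the keyword set assumes an aiter build matching the one this patch targets.

import torch
import aiter
from aiter.fused_moe import fused_moe, fused_topk

def mxfp4_moe_forward(x, router_logits, top_k, w13, w2, w13_scale, w2_scale,
                      w13_bias, w2_bias, hidden_pad, intermediate_pad):
    # Tile-size heuristic copied from the patch: 16-row tiles for small
    # batches, 32-row tiles once the token count reaches 2048.
    block_m = 16 if x.shape[0] < 2048 else 32

    # Top-k routing; the trailing True matches the patch (the
    # renormalization flag in comparable fused_topk APIs).
    topk_weights, topk_ids = fused_topk(x, router_logits, top_k, True)

    # fused_moe replaces the removed moe_sorting + moe_cktile2stages_gemm1/2
    # sequence: it sorts tokens by expert, runs both GEMM stages, applies
    # SwiGLU, and reduces with the routing weights internally.
    return fused_moe(
        x,
        w13,                     # pre-shuffled MXFP4 gate/up projection weights
        w2,                      # pre-shuffled MXFP4 down projection weights
        topk_weights,
        topk_ids,
        expert_mask=None,
        activation=aiter.ActivationType.Swiglu,
        quant_type=aiter.QuantType.per_1x32,  # MXFP4: one scale per 32-elem block
        doweight_stage1=False,   # routing weights are not folded into stage 1
        w1_scale=w13_scale,
        w2_scale=w2_scale,
        a1_scale=None,           # activations stay bf16 (A16W4 path)
        a2_scale=None,
        block_size_M=block_m,
        num_local_tokens=None,
        moe_sorting_dispatch_policy=0,
        dtype=None,
        hidden_pad=hidden_pad,
        intermediate_pad=intermediate_pad,
        bias1=w13_bias,
        bias2=w2_bias,
    )

One consequence of the switch worth noting: hidden_pad and intermediate_pad are now passed through as-is, so the // 64 * 64 and // 128 * 128 rounding that the old code applied per GEMM stage is presumably handled inside fused_moe.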