diff --git a/vllm/envs.py b/vllm/envs.py
index 69186ec3c695..4fc2c4fa889d 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -84,6 +84,7 @@
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
+    VLLM_ROCM_USE_AITER_2STAGE_MOE: bool = True
     VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_FA: bool = True
@@ -591,6 +592,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
              ("true", "1")),
 
+    # use aiter ck fused moe op if aiter ops are enabled
+    "VLLM_ROCM_USE_AITER_2STAGE_MOE":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_2STAGE_MOE", "True").lower() in
+             ("true", "1")),
+
     # Whether to use aiter block scaled moe kernel.
     # By default this is disabled.
     "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 5cbbe49bbba4..15be14ac6a44 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -119,11 +119,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                              requires_grad=False)
         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_moe_enabled, shuffle_weights)
+            is_rocm_aiter_2stage_moe_enabled, is_rocm_aiter_moe_enabled,
+            shuffle_weights)
 
         if is_rocm_aiter_moe_enabled():
+            layout = (32, 32) if is_rocm_aiter_2stage_moe_enabled() else (16,
+                                                                          16)
             # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
+                                                        layer.w2_weight.data,
+                                                        layout=layout)
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index ac158a7eee53..a3c4ffaa818d 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -13,6 +13,12 @@ def is_rocm_aiter_moe_enabled() -> bool:
         and envs.VLLM_ROCM_USE_AITER \
 
 
+def is_rocm_aiter_2stage_moe_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE \
+        and envs.VLLM_ROCM_USE_AITER
+
+
 def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
     return is_rocm_aiter_moe_enabled() and \
         envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
@@ -30,6 +36,8 @@ def rocm_aiter_fused_experts(
         w2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None,
         expert_mask: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
         **kwagrs  # Ignore additional keyword arguments
 ) -> torch.Tensor:
 
@@ -90,6 +98,17 @@ def rocm_aiter_fused_experts(
         return out_asm
 
     elif use_fp8_w8a8:
+        if is_rocm_aiter_2stage_moe_enabled():
+            from aiter.fused_moe_bf16_asm import ck_moe_2stages
+            return ck_moe_2stages(a1=hidden_states,
+                                  w1=w1,
+                                  w2=w2,
+                                  topk_weight=topk_weights,
+                                  topk_ids=topk_ids,
+                                  fc1_scale=w1_scale,
+                                  fc2_scale=w2_scale,
+                                  a1_scale=a1_scale,
+                                  a2_scale=a2_scale)
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,
@@ -101,6 +120,14 @@ def rocm_aiter_fused_experts(
                                            fc2_smooth_scale=None,
                                            a16=False)
 
+    if is_rocm_aiter_2stage_moe_enabled():
+        from aiter.fused_moe_bf16_asm import ck_moe_2stages
+        return ck_moe_2stages(a1=hidden_states,
+                              w1=w1,
+                              w2=w2,
+                              topk_weight=topk_weights,
+                              topk_ids=topk_ids)
+
     return rocm_aiter.ck_moe(hidden_states=hidden_states,
                              w1=w1,
                              w2=w2,
@@ -120,7 +147,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
     return topk_weights, topk_indices
 
 
-def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
+def shuffle_weights(*tensors: torch.Tensor,
+                    layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each input
     tensor and returns them.
@@ -133,7 +161,7 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
     """
     from aiter.ops.shuffle import shuffle_weight
 
-    return tuple(shuffle_weight(tensor) for tensor in tensors)
+    return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
 
 
 def expand_weights(*tensors: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 24648582f476..b10cc322b51f 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -583,8 +583,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
-            is_rocm_aiter_moe_enabled, shuffle_weights)
+            expand_weights, is_rocm_aiter_2stage_moe_enabled,
+            is_rocm_aiter_block_scaled_moe_enabled, is_rocm_aiter_moe_enabled,
+            shuffle_weights)
 
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
@@ -614,7 +615,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if is_rocm_aiter_block_scaled_moe_enabled():
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight.data, layer.w2_weight.data)
+                    layer.w13_weight.data,
+                    layer.w2_weight.data,
+                    layout=(16, 16))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -672,9 +675,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w13_scales.contiguous(), requires_grad=False)
             layer.w2_weight_scale = torch.nn.Parameter(
                 w2_scales.contiguous(), requires_grad=False)
-
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight, layer.w2_weight)
+            layout = (32,
+                      32) if is_rocm_aiter_2stage_moe_enabled() else (16,
+                                                                      16)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                        layer.w2_weight,
+                                                        layout=layout)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
@@ -759,8 +765,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.w2_weight_scale = torch.nn.Parameter(
                 w2_scales.contiguous(), requires_grad=False)
 
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight, layer.w2_weight)
+            layout = (32,
+                      32) if is_rocm_aiter_2stage_moe_enabled() else (16,
+                                                                      16)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                        layer.w2_weight,
+                                                        layout=layout)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
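Reviewer note: the patch wires one new switch through three layers (env var, weight shuffling at load time, kernel dispatch at run time). Below is a minimal standalone sketch of that gating and layout selection. `_env_flag`, `pick_shuffle_layout`, and the `is_rocm` stub are illustrative names only; the real code reads `vllm.envs` and checks `current_platform.is_rocm()`. The defaults and the (32, 32) vs (16, 16) layouts mirror the diff above.

import os

def _env_flag(name: str, default: str) -> bool:
    # Hypothetical stand-in for the lambdas registered in vllm/envs.py.
    return os.getenv(name, default).lower() in ("true", "1")

def is_rocm_aiter_2stage_moe_enabled(is_rocm: bool = True) -> bool:
    # Mirrors the new helper in rocm_aiter_fused_moe.py: the CK 2-stage
    # path requires the platform check plus both the global AITER switch
    # and the 2-stage switch. VLLM_ROCM_USE_AITER defaults to off; the
    # 2-stage flag defaults to on.
    return (is_rocm
            and _env_flag("VLLM_ROCM_USE_AITER_2STAGE_MOE", "True")
            and _env_flag("VLLM_ROCM_USE_AITER", "False"))

def pick_shuffle_layout() -> tuple[int, int]:
    # ck_moe_2stages consumes weights shuffled into (32, 32) blocks;
    # the existing asm_moe/ck_moe paths keep the (16, 16) layout.
    return (32, 32) if is_rocm_aiter_2stage_moe_enabled() else (16, 16)

if __name__ == "__main__":
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_2STAGE_MOE"] = "0"
    print(pick_shuffle_layout())  # (16, 16): falls back to asm/ck path
    os.environ["VLLM_ROCM_USE_AITER_2STAGE_MOE"] = "1"
    print(pick_shuffle_layout())  # (32, 32): CK 2-stage path selected

Because the shuffle happens once at weight loading while dispatch happens on every forward pass, the two decisions must agree; deriving both from the same is_rocm_aiter_2stage_moe_enabled() check keeps them in sync.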