6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -84,6 +84,7 @@
 VLLM_ROCM_USE_AITER: bool = False
 VLLM_ROCM_USE_AITER_LINEAR: bool = True
 VLLM_ROCM_USE_AITER_MOE: bool = True
+VLLM_ROCM_USE_AITER_2STAGE_MOE: bool = True
 VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
 VLLM_ROCM_USE_AITER_RMSNORM: bool = True
 VLLM_ROCM_USE_AITER_FA: bool = True
@@ -591,6 +592,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
              ("true", "1")),
 
+    # use the aiter ck 2-stage fused moe op if aiter ops are enabled
+    "VLLM_ROCM_USE_AITER_2STAGE_MOE":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_2STAGE_MOE", "True").lower() in
+             ("true", "1")),
+
     # Whether to use aiter block scaled moe kernel.
     # By default this is disabled.
     "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
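The new toggle follows the same truthy-string convention as the other AITER flags: case-insensitive "true" or "1", defaulting to on. A minimal sketch of opting out of the two-stage path at runtime; the model name is illustrative and not part of this PR:

import os

# Both VLLM_ROCM_USE_AITER and the new flag gate the two-stage CK path;
# setting the new flag to "0" falls back to the existing asm/ck_moe kernels.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_2STAGE_MOE"] = "0"

from vllm import LLM  # import after the environment is set

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1")  # illustrative MoE model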
10 changes: 7 additions & 3 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -119,11 +119,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                               requires_grad=False)
         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_moe_enabled, shuffle_weights)
+            is_rocm_aiter_2stage_moe_enabled, is_rocm_aiter_moe_enabled,
+            shuffle_weights)
         if is_rocm_aiter_moe_enabled():
+            layout = (32, 32) if is_rocm_aiter_2stage_moe_enabled() \
+                else (16, 16)
             # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data)
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data, layout=layout)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
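Since `layout` follows the bare `*tensors` in the new `shuffle_weights` signature, it is keyword-only and every call site must name it. A short sketch of the call pattern introduced here, with illustrative tensor shapes (divisible by both layouts):

import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_2stage_moe_enabled, shuffle_weights)

# Illustrative expert weights: 8 experts, hidden size 4096, FFN size 14336.
w13 = torch.randn(8, 28672, 4096, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(8, 4096, 14336, dtype=torch.bfloat16, device="cuda")

# The two-stage CK kernel expects a (32, 32) block layout; the asm and
# single-stage ck_moe paths keep the previous (16, 16) layout.
layout = (32, 32) if is_rocm_aiter_2stage_moe_enabled() else (16, 16)
shuffled_w13, shuffled_w2 = shuffle_weights(w13, w2, layout=layout)

Making the argument keyword-only keeps call sites explicit about which block layout a kernel expects, which is what the layer.py and fp8.py changes below rely on.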
32 changes: 30 additions & 2 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -13,6 +13,12 @@ def is_rocm_aiter_moe_enabled() -> bool:
         and envs.VLLM_ROCM_USE_AITER
 
 
+def is_rocm_aiter_2stage_moe_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_2STAGE_MOE \
+        and envs.VLLM_ROCM_USE_AITER
+
+
 def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
     return is_rocm_aiter_moe_enabled() and \
         envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
@@ -30,6 +36,8 @@ def rocm_aiter_fused_experts(
         w2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None,
         expert_mask: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
         **kwargs  # Ignore additional keyword arguments
 ) -> torch.Tensor:
 
@@ -90,6 +98,17 @@ def rocm_aiter_fused_experts(
         return out_asm
 
     elif use_fp8_w8a8:
+        if is_rocm_aiter_2stage_moe_enabled():
+            from aiter.fused_moe_bf16_asm import ck_moe_2stages
+            return ck_moe_2stages(a1=hidden_states,
+                                  w1=w1,
+                                  w2=w2,
+                                  topk_weight=topk_weights,
+                                  topk_ids=topk_ids,
+                                  fc1_scale=w1_scale,
+                                  fc2_scale=w2_scale,
+                                  a1_scale=a1_scale,
+                                  a2_scale=a2_scale)
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,
@@ -101,6 +120,14 @@ def rocm_aiter_fused_experts(
                                            fc2_smooth_scale=None,
                                            a16=False)
 
+    if is_rocm_aiter_2stage_moe_enabled():
+        from aiter.fused_moe_bf16_asm import ck_moe_2stages
+        return ck_moe_2stages(a1=hidden_states,
+                              w1=w1,
+                              w2=w2,
+                              topk_weight=topk_weights,
+                              topk_ids=topk_ids)
+
     return rocm_aiter.ck_moe(hidden_states=hidden_states,
                              w1=w1,
                              w2=w2,
@@ -120,7 +147,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
     return topk_weights, topk_indices
 
 
-def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
+def shuffle_weights(*tensors: torch.Tensor,
+                    layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each
     input tensor and returns them.
@@ -133,7 +161,7 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
     """
     from aiter.ops.shuffle import shuffle_weight
 
-    return tuple(shuffle_weight(tensor) for tensor in tensors)
+    return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
 
 
 def expand_weights(*tensors: torch.Tensor,
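With the two new early returns in place, `rocm_aiter_fused_experts` picks a kernel in a fixed order. A condensed, non-authoritative sketch of that control flow (paths summarized as strings, argument handling elided):

def kernel_selection(block_scaled: bool, use_fp8_w8a8: bool,
                     two_stage: bool) -> str:
    """Summarizes the dispatch order after this change."""
    if block_scaled:
        return "asm_moe (fp8 block-scaled)"    # unchanged path
    if use_fp8_w8a8:
        if two_stage:
            return "ck_moe_2stages (fp8 w8a8)"  # new in this PR
        return "asm_moe (fp8 w8a8)"             # previous default
    if two_stage:
        return "ck_moe_2stages (bf16)"          # new in this PR
    return "ck_moe (bf16)"                      # previous fallback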
26 changes: 18 additions & 8 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -583,8 +583,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
-            is_rocm_aiter_moe_enabled, shuffle_weights)
+            expand_weights, is_rocm_aiter_2stage_moe_enabled,
+            is_rocm_aiter_block_scaled_moe_enabled, is_rocm_aiter_moe_enabled,
+            shuffle_weights)
 
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
@@ -614,7 +615,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if is_rocm_aiter_block_scaled_moe_enabled():
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight.data, layer.w2_weight.data)
+                    layer.w13_weight.data,
+                    layer.w2_weight.data,
+                    layout=(16, 16))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -672,9 +675,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w13_scales.contiguous(), requires_grad=False)
             layer.w2_weight_scale = torch.nn.Parameter(
                 w2_scales.contiguous(), requires_grad=False)
-
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight, layer.w2_weight)
+            layout = (32, 32) if is_rocm_aiter_2stage_moe_enabled() \
+                else (16, 16)
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight, layer.w2_weight, layout=layout)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
@@ -759,8 +765,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
             layer.w2_weight_scale = torch.nn.Parameter(
                 w2_scales.contiguous(), requires_grad=False)
 
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight, layer.w2_weight)
+            layout = (32, 32) if is_rocm_aiter_2stage_moe_enabled() \
+                else (16, 16)
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight, layer.w2_weight, layout=layout)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
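Both FP8 branches above now end with the same tail: make the scales contiguous, pick the shuffle layout from the two-stage flag, shuffle, and re-register the weights as parameters. A hedged sketch of that shared tail; the helper name _shuffle_and_register is invented for illustration and does not exist in the PR:

import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_2stage_moe_enabled, shuffle_weights)


def _shuffle_and_register(layer: torch.nn.Module) -> None:
    # Hypothetical helper capturing the tail both FP8 branches share.
    layout = (32, 32) if is_rocm_aiter_2stage_moe_enabled() else (16, 16)
    shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
                                                layer.w2_weight,
                                                layout=layout)
    layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
    layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)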