@@ -18,6 +18,50 @@ def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
         envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
 
 
+def is_rocm_aiter_channel_scaled_moe_enabled() -> bool:
+    return is_rocm_aiter_moe_enabled() and \
+        envs.VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE
+
+
+def asm_moe_tkw1_impl(sorted_ids: torch.Tensor,
+                      sorted_weights: torch.Tensor,
+                      sorted_expert_ids: torch.Tensor,
+                      num_valid_ids: torch.Tensor,
+                      moe_buf: torch.Tensor,
+                      hidden_states: torch.Tensor,
+                      w1: torch.Tensor,
+                      w2: torch.Tensor,
+                      topk_weight: torch.Tensor,
+                      topk_ids: torch.Tensor,
+                      fc1_scale: Optional[torch.Tensor] = None,
+                      fc2_scale: Optional[torch.Tensor] = None,
+                      fc1_smooth_scale: Optional[torch.Tensor] = None,
+                      fc2_smooth_scale: Optional[torch.Tensor] = None,
+                      activation_str: str = "silu") -> None:
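+    """Dynamically quantize the activations per token and launch AITER's
+    ``fmoe_g1u1_tkw1`` fused-MoE kernel, which writes its output into
+    ``moe_buf`` in place (hence the ``None`` return)."""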
+    import aiter as rocm_aiter
+
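+    # Map the activation name onto the AITER enum; unrecognized names fall
+    # back to Silu.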
+    if activation_str == "silu":
+        activation = rocm_aiter.ActivationType.Silu
+    elif activation_str == "gelu":
+        activation = rocm_aiter.ActivationType.Gelu
+    else:
+        activation = rocm_aiter.ActivationType.Silu
+
+    E, model_dim, _ = w2.shape
+    M, topk = topk_ids.shape
+    device = topk_ids.device
+
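+    # Quantize the hidden states on the fly with a per-token scale; the
+    # weights already carry their own scales (fc1_scale / fc2_scale).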
+    a8_type = (w1.dtype if w1.dtype != torch.int32 and w1.dtype != torch.uint32
+               else torch.float8_e4m3fnuz)
+    a8 = torch.empty((M, model_dim), dtype=a8_type, device=device)
+    a8_scale = torch.empty(M, dtype=torch.float, device=device)
+    rocm_aiter.dynamic_per_token_scaled_fp8_quant(a8, hidden_states, a8_scale)
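+    # The tkw1 kernel consumes the sorted routing metadata, the quantized
+    # activations with their per-token scale, and the weight scales.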
+    fmoe_func = rocm_aiter.fmoe_g1u1_tkw1
+    fmoe_func(moe_buf, a8, w1, w2, sorted_ids, sorted_weights,
+              sorted_expert_ids, num_valid_ids, topk, a8_scale, fc1_scale,
+              fc2_scale, fc2_smooth_scale, activation)
+
+
 def rocm_aiter_fused_experts(
     *,
     hidden_states: torch.Tensor,
@@ -26,10 +70,12 @@ def rocm_aiter_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
+    apply_router_weight_on_input: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     expert_mask: Optional[torch.Tensor] = None,
+    activation: str = "silu",
     **kwagrs  # Ignore additional keyword arguments
 ) -> torch.Tensor:
 
@@ -38,8 +84,22 @@ def rocm_aiter_fused_experts(
 
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
+
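+    # With apply_router_weight_on_input the (single) top-k weight is folded
+    # into the activations up front, so the kernel is handed unit weights.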
+    if apply_router_weight_on_input:
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
+    if is_rocm_aiter_block_scaled_moe_enabled() and use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for block scaled moe"
+        )
 
-    if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
 
@@ -88,8 +148,53 @@ def rocm_aiter_fused_experts(
             None,
         )
         return out_asm
+
+    elif is_rocm_aiter_channel_scaled_moe_enabled() and use_fp8_w8a8:
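+        # Per-channel weight / per-token activation FP8 path: sort the tokens
+        # by expert, then run the tkw1 fused-MoE kernel defined above.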
+        topk_weights = topk_weights.to(torch.float32)
+        topk_ids = topk_ids.to(torch.int32)
+
+        E, model_dim, _ = w2.shape
+        dtype = hidden_states.dtype
+
+        if expert_mask is not None:
+            E = expert_mask.numel()
+
+        (
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            out_asm,
+        ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
+                                               topk_weights,
+                                               E,
+                                               model_dim,
+                                               dtype,
+                                               expert_mask=expert_mask)
+
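+        # out_asm comes back from moe_sorting_ck and is filled in place by
+        # the kernel call below.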
+        asm_moe_tkw1_impl(
+            sorted_ids=sorted_token_ids,
+            sorted_weights=sorted_weight_buf,
+            sorted_expert_ids=sorted_expert_ids,
+            num_valid_ids=num_valid_ids,
+            moe_buf=out_asm,
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            topk_weight=topk_weights,
+            topk_ids=topk_ids,
+            fc1_scale=w1_scale,
+            fc2_scale=w2_scale,
+            fc1_smooth_scale=None,
+            fc2_smooth_scale=None,
+            activation_str=activation)
+
+        return out_asm
 
     elif use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for fp8_w8a8")
+
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,