@@ -13,7 +13,8 @@
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
+                                                        _fp8_quantize,
                                                         _resize_cache,
                                                         extract_required_args)
 from vllm.scalar_type import scalar_types
@@ -34,10 +35,6 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -156,11 +153,27 @@ def run_cutlass_moe_fp8(
                                 problem_sizes1, problem_sizes2, a_map,
                                 c_map, global_num_experts, N, K)
 
-    a1q = ops.shuffle_rows(a1q, a_map)
-    a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
-                 if per_act_token else a1q_scale)
+    a1q = _fp8_perm(a1q, a_map)
+    a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
     expert_offsets = expert_offsets[:-1]
 
+    ab_strides1 = torch.full((w1.size(0), ),
+                             K,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides1 = torch.full((w1.size(0), ),
+                            2 * N,
+                            device=device,
+                            dtype=torch.int64)
+    ab_strides2 = torch.full((w1.size(0), ),
+                             N,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides2 = torch.full((w1.size(0), ),
+                            K,
+                            device=device,
+                            dtype=torch.int64)
+
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
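
Why constant fill values work here: assuming the expert weights stay contiguous and row-major (w1 of shape [E, 2N, K], w2 of shape [E, K, N]), every expert's GEMM shares one leading-dimension stride, so each stride tensor is just a scalar repeated `w1.size(0)` times. A minimal sketch under that assumption, with made-up sizes, checking the filled values against PyTorch's own strides:

```python
# Illustrative sketch, not part of the patch. E, N, K are hypothetical sizes.
import torch

E, N, K = 8, 512, 256
w1 = torch.empty((E, 2 * N, K), dtype=torch.int8)  # stand-in for fp8 weights
w2 = torch.empty((E, K, N), dtype=torch.int8)

# The row stride of each expert's matrix equals its number of columns.
assert w1.stride(1) == K  # matches ab_strides1 filled with K
assert w2.stride(1) == N  # matches ab_strides2 filled with N
# The first GEMM writes [*, 2N] rows and the second writes [*, K] rows,
# matching c_strides1 = 2 * N and c_strides2 = K.
```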
@@ -197,8 +210,7 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
-                     non_blocking=True)
+        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
 
 
 # TODO (bnell): split class batched vs. non-batched?
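
The two hunks above swap the custom `ops.shuffle_rows` kernel for plain tensor indexing; `_fp8_perm` appears to be the fp8-safe wrapper around the same gather. A small sketch of the equivalence being relied on (values are made up):

```python
# Illustrative sketch, not part of the patch: integer (fancy) indexing
# performs the same row gather that ops.shuffle_rows did, so c3[c_map]
# is a drop-in replacement.
import torch

c3 = torch.randn(6, 4)
c_map = torch.tensor([3, 0, 5, 1, 4, 2])
gathered = c3[c_map]  # row i of the result is c3[c_map[i]]
assert torch.equal(gathered[0], c3[3])
assert torch.equal(gathered[2], c3[5])
```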
@@ -211,10 +223,6 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -231,10 +239,6 @@ def __init__(
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
-        self.ab_strides1 = ab_strides1
-        self.ab_strides2 = ab_strides2
-        self.c_strides1 = c_strides1
-        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format
 
     @property
@@ -314,8 +318,7 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
-            self.c_strides2, workspace13, workspace2, expert_num_tokens,
+            a2_scale, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -329,10 +332,6 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -360,17 +359,6 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
-    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
-        Shape: [num_experts]
-    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
-        Shape: [num_experts]
-    - c_strides1 (torch.Tensor): The output strides for the first gemm.
-        Shape: [num_experts]
-    - c_strides2 (torch.Tensor): The output strides for the second gemm.
-        Shape: [num_experts]
-    - per_act_token (Optional[bool]): Whether the scale is per-token or
-        per-tensor.
-    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -403,10 +391,6 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
-            ab_strides1=ab_strides1,
-            ab_strides2=ab_strides2,
-            c_strides1=c_strides1,
-            c_strides2=c_strides2,
             use_batched_format=False,
         ),
     )
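
Net effect for callers: `cutlass_moe_fp8` no longer takes the four stride tensors, since `run_cutlass_moe_fp8` now derives them itself. A sketch of the setup code this lets call sites delete (names follow the removed docstring entries; sizes and device are hypothetical):

```python
# Illustrative sketch, not part of the patch: the caller-side boilerplate
# removed by this change. Use a CUDA device in practice.
import torch

E, N, K = 8, 512, 256
device = torch.device("cpu")
ab_strides1 = torch.full((E, ), K, device=device, dtype=torch.int64)
ab_strides2 = torch.full((E, ), N, device=device, dtype=torch.int64)
c_strides1 = torch.full((E, ), 2 * N, device=device, dtype=torch.int64)
c_strides2 = torch.full((E, ), K, device=device, dtype=torch.int64)
# After this change these four tensors are built inside run_cutlass_moe_fp8,
# so cutlass_moe_fp8 call sites drop four arguments.
```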