     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
-                                                        _fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache)
 from vllm.scalar_type import scalar_types
 
@@ -34,6 +33,10 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -152,27 +155,11 @@ def run_cutlass_moe_fp8(
                                 problem_sizes1, problem_sizes2, a_map,
                                 c_map, global_num_experts, N, K)
 
-    a1q = _fp8_perm(a1q, a_map)
-    a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+    a1q = ops.shuffle_rows(a1q, a_map)
+    a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
+                 if per_act_token else a1q_scale)
     expert_offsets = expert_offsets[:-1]
 
-    ab_strides1 = torch.full((w1.size(0), ),
-                             K,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides1 = torch.full((w1.size(0), ),
-                            2 * N,
-                            device=device,
-                            dtype=torch.int64)
-    ab_strides2 = torch.full((w1.size(0), ),
-                             N,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides2 = torch.full((w1.size(0), ),
-                            K,
-                            device=device,
-                            dtype=torch.int64)
-
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
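
Note on the gather above: `ops.shuffle_rows` takes over the row permutation that the removed `_fp8_perm` indexing expressed, and it works directly on fp8 tensors. A minimal PyTorch sketch of the assumed semantics (a reference only, not the actual CUDA kernel):

import torch

def shuffle_rows_ref(t: torch.Tensor, row_map: torch.Tensor) -> torch.Tensor:
    # Row i of the output is row row_map[i] of the input.
    if t.is_floating_point() and t.dtype.itemsize == 1:
        # fp8 dtypes don't support advanced indexing; reinterpret the
        # storage as uint8 first (this mirrors what _fp8_perm did).
        return t.view(dtype=torch.uint8)[row_map].view(dtype=t.dtype)
    return t[row_map]

The same op is reused below to un-permute the second gemm's output via `c_map`.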
@@ -209,7 +196,8 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
+        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
+                     non_blocking=True)
 
 
 # TODO (bnell): split class batched vs. non-batched?
@@ -222,6 +210,10 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -238,6 +230,10 @@ def __init__(
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
+        self.ab_strides1 = ab_strides1
+        self.ab_strides2 = ab_strides2
+        self.c_strides1 = c_strides1
+        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format
 
     @property
@@ -316,7 +312,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, workspace13, workspace2, expert_num_tokens,
+            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
+            self.c_strides2, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -330,6 +327,10 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -357,6 +358,17 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
+    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
+        Shape: [num_experts]
+    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
+        Shape: [num_experts]
+    - c_strides1 (torch.Tensor): The output strides for the first gemm.
+        Shape: [num_experts]
+    - c_strides2 (torch.Tensor): The output strides for the second gemm.
+        Shape: [num_experts]
+    - per_act_token (Optional[bool]): Whether the scale is per-token or
+        per-tensor.
+    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
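
Because the stride tensors are now arguments, callers can build them once (e.g. at weight-load time) instead of the kernel re-materializing them on every forward pass. A sketch of a hypothetical helper, using the same values the removed in-function code produced (K is the hidden size, N the intermediate size, and the expert count comes from the first-layer weights):

import torch

def make_fp8_moe_strides(num_experts: int, K: int, N: int,
                         device: torch.device):
    # gemm1 reads activations/weights with stride K and writes [*, 2N];
    # gemm2 reads with stride N and writes [*, K] -- identical to the
    # removed torch.full(...) calls above.
    ab_strides1 = torch.full((num_experts, ), K,
                             device=device, dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ), 2 * N,
                            device=device, dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ), N,
                             device=device, dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ), K,
                            device=device, dtype=torch.int64)
    return ab_strides1, ab_strides2, c_strides1, c_strides2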
@@ -389,6 +401,10 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
+            ab_strides1=ab_strides1,
+            ab_strides2=ab_strides2,
+            c_strides1=c_strides1,
+            c_strides2=c_strides2,
             use_batched_format=False,
         ),
     )
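
Putting it together, a call site would now look roughly like the following. This is a hypothetical sketch: the leading activation, weight, and routing arguments follow the existing cutlass_moe_fp8 signature, which is only partially visible in this diff, and make_fp8_moe_strides is the helper sketched above.

# w1_q is the quantized first-layer weight, shaped [num_experts, 2N, K].
ab_strides1, ab_strides2, c_strides1, c_strides2 = make_fp8_moe_strides(
    w1_q.size(0), K, N, a.device)

out = cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids,
                      w1_scale, w2_scale,
                      ab_strides1=ab_strides1, ab_strides2=ab_strides2,
                      c_strides1=c_strides1, c_strides2=c_strides2,
                      per_act_token=False, activation="silu")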