11# SPDX-License-Identifier: Apache-2.0 
22""" CUTLASS based Fused MoE kernels.""" 
3- import  os 
43from  typing  import  Optional 
54
65import  torch 
@@ -271,8 +270,6 @@ def cutlass_moe_fp8(
271270
272271FLOAT4_E2M1_MAX  =  scalar_types .float4_e2m1f .max ()
273272FLOAT8_E4M3_MAX  =  torch .finfo (torch .float8_e4m3fn ).max 
274- MAX_TOKENS_PER_EXPERT  =  int (
275-     os .environ .get ('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT' , '65536' ))
276273
277274
278275def  cutlass_moe_fp4 (a : torch .Tensor , a1_gscale : torch .Tensor ,
@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
330327    assert  a .dtype  in  [torch .half , torch .bfloat16 ], "Invalid input dtype" 
331328    assert  (topk_weights .shape [0 ] ==  m  and  topk_ids .shape [0 ]
332329            ==  m ), ("topk must be provided for each row of a" )
333-     assert  (m  <=  MAX_TOKENS_PER_EXPERT ), (
334-         f"m must be less than MAX_TOKENS_PER_EXPERT({ MAX_TOKENS_PER_EXPERT }  )" 
335-         f" for cutlass_moe_fp4, observed m = { m }  . Use" 
336-         f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value." )
330+ 
337331    out_dtype  =  a .dtype 
338332    num_topk  =  topk_ids .shape [1 ]
339333
@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
362356        expert_offsets ,
363357        blockscale_offsets ,
364358        num_topk ,
365-         expert_map = a_map ,
366-         MAX_TOKENS_PER_EXPERT = MAX_TOKENS_PER_EXPERT )
359+         expert_map = a_map )
367360
368361    c1  =  ops .cutlass_fp4_moe_mm (rep_a_fp4 , w1_fp4 , rep_a_blockscale ,
369362                                w1_blockscale , w1_alphas , problem_sizes1 ,
@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
378371    torch .ops ._C .silu_and_mul (intermediate , c1 )
379372
380373    int_fp4 , int_blockscale  =  ops .scaled_fp4_experts_quant (
381-         intermediate ,
382-         a2_gscale ,
383-         expert_offsets ,
384-         blockscale_offsets ,
385-         num_topk ,
386-         MAX_TOKENS_PER_EXPERT = MAX_TOKENS_PER_EXPERT )
374+         intermediate , a2_gscale , expert_offsets , blockscale_offsets , num_topk )
387375
388376    c2  =  ops .cutlass_fp4_moe_mm (int_fp4 , w2_fp4 , int_blockscale , w2_blockscale ,
389377                                w2_alphas , problem_sizes2 , expert_offsets [:- 1 ],
0 commit comments