3232 PerTensorScaleParameter )
3333from vllm .model_executor .utils import set_weight_attrs
3434from vllm .platforms import current_platform
35- from vllm .utils import is_navi
35+ from vllm .utils import aiter_2stage_moe_enabled , aiter_moe_enabled , is_navi
3636
37- if envs . VLLM_USE_AITER_MOE :
37+ if aiter_moe_enabled () :
3838 from aiter .fused_moe_bf16_asm import asm_moe
39- if envs . VLLM_USE_AITER_2STAGE_MOE :
39+ if aiter_2stage_moe_enabled () :
4040 from aiter .fused_moe_bf16_asm import ck_moe_2stages
4141 from aiter .ops .shuffle import shuffle_weight
4242
@@ -621,7 +621,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
621621 requires_grad = False )
622622 layer .w2_weight = torch .nn .Parameter (w2_weight ,
623623 requires_grad = False )
624- if envs . VLLM_USE_AITER_MOE :
624+ if aiter_moe_enabled () :
625625 w13_scales = layer .w13_weight_scale .data .unsqueeze (
626626 - 1 ).unsqueeze (- 1 ).expand (
627627 (- 1 , layer .w13_weight .shape [1 ], - 1 ))
@@ -632,13 +632,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
632632 layer .w13_weight_scale = torch .nn .Parameter (
633633 w13_scales .contiguous (), requires_grad = False )
634634
635- if envs . VLLM_USE_AITER_2STAGE_MOE :
636- layer .w13_weight = torch .nn .Parameter (
637- shuffle_weight ( layer .w13_weight , layout = (32 , 32 )),
638- requires_grad = False )
639- layer .w2_weight = torch .nn .Parameter (
640- shuffle_weight ( layer .w2_weight , layout = (32 , 32 )),
641- requires_grad = False )
635+ if aiter_2stage_moe_enabled () :
636+ layer .w13_weight = torch .nn .Parameter (shuffle_weight (
637+ layer .w13_weight , layout = (32 , 32 )),
638+ requires_grad = False )
639+ layer .w2_weight = torch .nn .Parameter (shuffle_weight (
640+ layer .w2_weight , layout = (32 , 32 )),
641+ requires_grad = False )
642642 else :
643643 layer .w13_weight = torch .nn .Parameter (shuffle_weight (
644644 layer .w13_weight ),
@@ -715,32 +715,31 @@ def process_weights_after_loading(self, layer: Module) -> None:
715715 dq_weight , max_w13_scales [expert_id ])
716716 start += shard_size
717717
718- if envs . VLLM_USE_AITER_MOE :
719- if envs . VLLM_USE_AITER_2STAGE_MOE :
718+ if aiter_moe_enabled () :
719+ if aiter_2stage_moe_enabled () :
720720 max_w13_scales = max_w13_scales .unsqueeze (- 1 )
721721 w2_scales = layer .w2_weight_scale .data .unsqueeze (- 1 )
722+ layer .w13_weight = torch .nn .Parameter (shuffle_weight (
723+ layer .w13_weight , layout = (32 , 32 )),
724+ requires_grad = False )
725+ layer .w2_weight = torch .nn .Parameter (shuffle_weight (
726+ layer .w2_weight , layout = (32 , 32 )),
727+ requires_grad = False )
722728 else :
723729 max_w13_scales = max_w13_scales .unsqueeze (- 1 ).unsqueeze (
724730 - 1 ).expand ((- 1 , layer .w13_weight .shape [1 ], - 1 ))
725- w2_scales = layer .w2_weight_scale .data .unsqueeze (- 1 ).unsqueeze (
726- - 1 ).expand ((- 1 , layer .w2_weight .shape [1 ], - 1 ))
727-
728- layer .w2_weight_scale = torch .nn .Parameter (
729- w2_scales .contiguous (), requires_grad = False )
730- if envs .VLLM_USE_AITER_2STAGE_MOE :
731- layer .w13_weight = torch .nn .Parameter (
732- shuffle_weight (layer .w13_weight , layout = (32 , 32 )),
733- requires_grad = False )
734- layer .w2_weight = torch .nn .Parameter (
735- shuffle_weight (layer .w2_weight , layout = (32 , 32 )),
736- requires_grad = False )
737- else :
731+ w2_scales = layer .w2_weight_scale .data .unsqueeze (
732+ - 1 ).unsqueeze (- 1 ).expand (
733+ (- 1 , layer .w2_weight .shape [1 ], - 1 ))
738734 layer .w13_weight = torch .nn .Parameter (shuffle_weight (
739735 layer .w13_weight ),
740736 requires_grad = False )
741737 layer .w2_weight = torch .nn .Parameter (shuffle_weight (
742738 layer .w2_weight ),
743739 requires_grad = False )
740+
741+ layer .w2_weight_scale = torch .nn .Parameter (
742+ w2_scales .contiguous (), requires_grad = False )
744743 layer .w13_weight_scale = torch .nn .Parameter (
745744 max_w13_scales .contiguous (), requires_grad = False )
746745 return
@@ -776,15 +775,15 @@ def apply(
776775 e_score_correction_bias = e_score_correction_bias ,
777776 )
778777
779- if envs . VLLM_USE_AITER_MOE :
780- if envs . VLLM_USE_AITER_2STAGE_MOE :
778+ if aiter_moe_enabled () :
779+ if aiter_2stage_moe_enabled () :
781780 return ck_moe_2stages (a1 = x ,
782- w1 = layer .w13_weight ,
783- w2 = layer .w2_weight ,
784- topk_weight = topk_weights ,
785- topk_ids = topk_ids ,
786- fc1_scale = layer .w13_weight_scale ,
787- fc2_scale = layer .w2_weight_scale )
781+ w1 = layer .w13_weight ,
782+ w2 = layer .w2_weight ,
783+ topk_weight = topk_weights ,
784+ topk_ids = topk_ids ,
785+ fc1_scale = layer .w13_weight_scale ,
786+ fc2_scale = layer .w2_weight_scale )
788787
789788 return asm_moe (
790789 hidden_states = x ,
0 commit comments