3232 PerTensorScaleParameter )
3333from vllm .model_executor .utils import set_weight_attrs
3434from vllm .platforms import current_platform
35- from vllm .utils import aiter_2stage_moe_enabled , aiter_moe_enabled , is_navi
35+ from vllm .utils import is_navi
3636
37- if aiter_moe_enabled () :
37+ if envs . VLLM_USE_AITER_MOE :
3838 from aiter .fused_moe_bf16_asm import asm_moe
39- if aiter_2stage_moe_enabled () :
39+ if envs . VLLM_USE_AITER_2STAGE_MOE :
4040 from aiter .fused_moe_bf16_asm import ck_moe_2stages
4141 from aiter .ops .shuffle import shuffle_weight
4242
@@ -621,7 +621,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
621621 requires_grad = False )
622622 layer .w2_weight = torch .nn .Parameter (w2_weight ,
623623 requires_grad = False )
624- if aiter_moe_enabled () :
624+ if envs . VLLM_USE_AITER_MOE :
625625 w13_scales = layer .w13_weight_scale .data .unsqueeze (
626626 - 1 ).unsqueeze (- 1 ).expand (
627627 (- 1 , layer .w13_weight .shape [1 ], - 1 ))
@@ -632,13 +632,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
632632 layer .w13_weight_scale = torch .nn .Parameter (
633633 w13_scales .contiguous (), requires_grad = False )
634634
635- if aiter_2stage_moe_enabled () :
636- layer .w13_weight = torch .nn .Parameter (shuffle_weight (
637- layer .w13_weight , layout = (32 , 32 )),
638- requires_grad = False )
639- layer .w2_weight = torch .nn .Parameter (shuffle_weight (
640- layer .w2_weight , layout = (32 , 32 )),
641- requires_grad = False )
635+ if envs . VLLM_USE_AITER_2STAGE_MOE :
636+ layer .w13_weight = torch .nn .Parameter (
637+ shuffle_weight ( layer .w13_weight , layout = (32 , 32 )),
638+ requires_grad = False )
639+ layer .w2_weight = torch .nn .Parameter (
640+ shuffle_weight ( layer .w2_weight , layout = (32 , 32 )),
641+ requires_grad = False )
642642 else :
643643 layer .w13_weight = torch .nn .Parameter (shuffle_weight (
644644 layer .w13_weight ),
@@ -715,31 +715,32 @@ def process_weights_after_loading(self, layer: Module) -> None:
715715 dq_weight , max_w13_scales [expert_id ])
716716 start += shard_size
717717
718- if aiter_moe_enabled () :
719- if aiter_2stage_moe_enabled () :
718+ if envs . VLLM_USE_AITER_MOE :
719+ if envs . VLLM_USE_AITER_2STAGE_MOE :
720720 max_w13_scales = max_w13_scales .unsqueeze (- 1 )
721721 w2_scales = layer .w2_weight_scale .data .unsqueeze (- 1 )
722- layer .w13_weight = torch .nn .Parameter (shuffle_weight (
723- layer .w13_weight , layout = (32 , 32 )),
724- requires_grad = False )
725- layer .w2_weight = torch .nn .Parameter (shuffle_weight (
726- layer .w2_weight , layout = (32 , 32 )),
727- requires_grad = False )
728722 else :
729723 max_w13_scales = max_w13_scales .unsqueeze (- 1 ).unsqueeze (
730724 - 1 ).expand ((- 1 , layer .w13_weight .shape [1 ], - 1 ))
731- w2_scales = layer .w2_weight_scale .data .unsqueeze (
732- - 1 ).unsqueeze (- 1 ).expand (
733- (- 1 , layer .w2_weight .shape [1 ], - 1 ))
725+ w2_scales = layer .w2_weight_scale .data .unsqueeze (- 1 ).unsqueeze (
726+ - 1 ).expand ((- 1 , layer .w2_weight .shape [1 ], - 1 ))
727+
728+ layer .w2_weight_scale = torch .nn .Parameter (
729+ w2_scales .contiguous (), requires_grad = False )
730+ if envs .VLLM_USE_AITER_2STAGE_MOE :
731+ layer .w13_weight = torch .nn .Parameter (
732+ shuffle_weight (layer .w13_weight , layout = (32 , 32 )),
733+ requires_grad = False )
734+ layer .w2_weight = torch .nn .Parameter (
735+ shuffle_weight (layer .w2_weight , layout = (32 , 32 )),
736+ requires_grad = False )
737+ else :
734738 layer .w13_weight = torch .nn .Parameter (shuffle_weight (
735739 layer .w13_weight ),
736740 requires_grad = False )
737741 layer .w2_weight = torch .nn .Parameter (shuffle_weight (
738742 layer .w2_weight ),
739743 requires_grad = False )
740-
741- layer .w2_weight_scale = torch .nn .Parameter (
742- w2_scales .contiguous (), requires_grad = False )
743744 layer .w13_weight_scale = torch .nn .Parameter (
744745 max_w13_scales .contiguous (), requires_grad = False )
745746 return
@@ -775,15 +776,15 @@ def apply(
775776 e_score_correction_bias = e_score_correction_bias ,
776777 )
777778
778- if aiter_moe_enabled () :
779- if aiter_2stage_moe_enabled () :
779+ if envs . VLLM_USE_AITER_MOE :
780+ if envs . VLLM_USE_AITER_2STAGE_MOE :
780781 return ck_moe_2stages (a1 = x ,
781- w1 = layer .w13_weight ,
782- w2 = layer .w2_weight ,
783- topk_weight = topk_weights ,
784- topk_ids = topk_ids ,
785- fc1_scale = layer .w13_weight_scale ,
786- fc2_scale = layer .w2_weight_scale )
782+ w1 = layer .w13_weight ,
783+ w2 = layer .w2_weight ,
784+ topk_weight = topk_weights ,
785+ topk_ids = topk_ids ,
786+ fc1_scale = layer .w13_weight_scale ,
787+ fc2_scale = layer .w2_weight_scale )
787788
788789 return asm_moe (
789790 hidden_states = x ,
0 commit comments