@@ -54,7 +54,7 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
-                                 create_moe)
+                                 WideEPMoE, create_moe)
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -513,7 +513,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
                                           self.mapping,
                                           dim=0,
                                           sizes=all_rank_num_tokens)
-            elif not isinstance(self.experts, CutlassFusedMoE) or (
+            elif not isinstance(self.experts, (CutlassFusedMoE, WideEPMoE)) or (
                     not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
                 # Use padding when not using the cutlass path or when x_sf in self.experts is not None
                 use_dp_padding = True
@@ -780,116 +780,70 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
                 do_finalize=do_finalize,
             )
 
-        cutlass_min_latency_mode = self._enable_min_latency_mode(
-            hidden_states.shape[0])
-
-        if cutlass_min_latency_mode:
-            assert self.fusion_config.PRE_MOE_FUSION and self.fusion_config.POST_MOE_FUSION
-            assert self.model_config.moe_backend == 'CUTLASS'
-
-            hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
+        if self.fusion_config.PRE_MOE_FUSION:
+            # moe_backend can be either CUTLASS or TRTLLM here
+            # TODO: unify the two min-latency MoE backends by enabling quant fusion
+            hidden_states, residual = self.allreduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.
-                    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                     residual=residual,
                     norm_weight=self.post_attention_layernorm.weight,
-                    scale=self.mlp.experts.fc31_input_scale,
                     eps=self.post_attention_layernorm.variance_epsilon,
+                    trigger_completion_at_end=False,
                 ))
-            hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
-                                                   hidden_states_sf)
-
-            hidden_states = _run_MoE(hidden_states,
-                                     hidden_states_fp4,
-                                     do_finalize=False)
+        else:
+            # No fusion
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
 
-            shared_output = hidden_states[0]
-            hidden_states_activated_experts = hidden_states[1]
-            num_activated_experts_per_node = hidden_states[2]
-            experts_to_token_score = hidden_states[3]
+        # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
+        do_finalize = not (hidden_states.shape[0]
+                           <= self.moe_allreduce.max_token
+                           and self.fusion_config.POST_MOE_FUSION
+                           and self.model_config.moe_backend == 'TRTLLM'
+                           and self.mlp.experts.has_nvfp4)
 
-            moe_all_reduce_params = MoEAllReduceParams(
-                residual=residual,
-                norm_weight=self.next_layer_layernorm.weight,
-                device_num_experts=num_activated_experts_per_node,
-                expert_scale_factor=experts_to_token_score,
-                shared_expert_output=shared_output,
-                eps=self.next_layer_layernorm.variance_epsilon,
-                is_cutlass_min_latency=True,
-            )
+        hidden_states = _run_MoE(hidden_states,
+                                 hidden_states_fp4=None,
+                                 do_finalize=do_finalize)
 
-            # MoE_finalize is fused into allreduce
-            hidden_states, residual = self.moe_allreduce(
-                hidden_states_activated_experts,
-                all_reduce_params=moe_all_reduce_params,
-            )
-        else:
-            if self.fusion_config.PRE_MOE_FUSION:
-                # moe_backend can be either CUTLASS or TRTLLM here
-                # TODO: unify the two min-latency MoE backends by enabling quant fusion
+        if self.fusion_config.POST_MOE_FUSION:
+            if do_finalize:
                 hidden_states, residual = self.allreduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
                         fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                         residual=residual,
-                        norm_weight=self.post_attention_layernorm.weight,
-                        eps=self.post_attention_layernorm.variance_epsilon,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        eps=self.next_layer_layernorm.variance_epsilon,
                         trigger_completion_at_end=False,
                     ))
             else:
-                # No fusion
-                hidden_states, residual = self.post_attention_layernorm(
+                assert len(
+                    hidden_states) == 4, "hidden_states must have 4 elements"
+
+                shared_output = hidden_states[0]
+                fc2_output = hidden_states[1]
+                expert_scale_factor = hidden_states[2]
+                expanded_idx_to_permuted_idx = hidden_states[3]
+
+                moe_all_reduce_params = MoEAllReduceParams(
+                    expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                    expert_scale_factor=expert_scale_factor,
+                    shared_expert_output=shared_output,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                    is_cutlass_min_latency=False,
+                )
+                hidden_states, residual = self.moe_allreduce(
+                    fc2_output, all_reduce_params=moe_all_reduce_params)
+        else:
+            if self.next_layer_layernorm is not None:
+                hidden_states, residual = self.next_layer_layernorm(
                     hidden_states, residual)
 
-            # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
-            do_finalize = not (hidden_states.shape[0]
-                               <= self.moe_allreduce.max_token
-                               and self.fusion_config.POST_MOE_FUSION
-                               and self.model_config.moe_backend == 'TRTLLM'
-                               and self.mlp.experts.has_nvfp4)
-
-            hidden_states = _run_MoE(hidden_states,
-                                     hidden_states_fp4=None,
-                                     do_finalize=do_finalize)
-
-            if self.fusion_config.POST_MOE_FUSION:
-                if do_finalize:
-                    hidden_states, residual = self.allreduce(
-                        hidden_states,
-                        all_reduce_params=AllReduceParams(
-                            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                            residual=residual,
-                            norm_weight=self.next_layer_layernorm.weight,
-                            eps=self.next_layer_layernorm.variance_epsilon,
-                            trigger_completion_at_end=False,
-                        ))
-                else:
-                    assert len(hidden_states
-                               ) == 4, "hidden_states must have 4 elements"
-
-                    shared_output = hidden_states[0]
-                    fc2_output = hidden_states[1]
-                    expert_scale_factor = hidden_states[2]
-                    expanded_idx_to_permuted_idx = hidden_states[3]
-
-                    moe_all_reduce_params = MoEAllReduceParams(
-                        expanded_idx_to_permuted_idx=
-                        expanded_idx_to_permuted_idx,
-                        expert_scale_factor=expert_scale_factor,
-                        shared_expert_output=shared_output,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                        is_cutlass_min_latency=False,
-                    )
-                    hidden_states, residual = self.moe_allreduce(
-                        fc2_output, all_reduce_params=moe_all_reduce_params)
-            else:
-                if self.next_layer_layernorm is not None:
-                    hidden_states, residual = self.next_layer_layernorm(
-                        hidden_states, residual)
-
         return hidden_states, residual
 
     def forward_mlp(
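For reference, the deferred-finalize condition introduced in the last hunk can be read in isolation as follows. This is an illustrative sketch, not part of the change: the helper name should_finalize and the literal values in the usage example are hypothetical, while the condition itself mirrors the do_finalize expression above (TRTLLM backend, nvfp4 experts, post-MoE fusion enabled, and a token count within the fused MoE-allreduce limit).

# Illustrative sketch only -- should_finalize is a hypothetical helper mirroring the
# do_finalize expression from the diff; it is not part of the modeling code.
def should_finalize(num_tokens: int, max_fused_tokens: int,
                    post_moe_fusion: bool, moe_backend: str,
                    has_nvfp4: bool) -> bool:
    # Finalize stays inside the MoE op unless every fusion precondition holds;
    # when this returns False, the MoE emits the 4-tensor output consumed by moe_allreduce.
    return not (num_tokens <= max_fused_tokens and post_moe_fusion
                and moe_backend == 'TRTLLM' and has_nvfp4)

# Hypothetical example values: a small batch on the TRTLLM nvfp4 path defers finalize
# into the fused MoE allreduce; any other backend keeps finalize inside the MoE op.
assert should_finalize(8, 128, True, 'TRTLLM', True) is False
assert should_finalize(8, 128, True, 'CUTLASS', True) is True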