@@ -666,181 +666,3 @@ def fused_moe(
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale)
-
-
-def single_marlin_moe(
-    hidden_states: torch.Tensor,
-    w: torch.Tensor,
-    scales: torch.Tensor,
-    gating_output: torch.Tensor,
-    g_idx: torch.Tensor,
-    rand_perm: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    override_config: Optional[Dict[str, Any]] = None,
-    use_fp8: bool = False,
-) -> torch.Tensor:
-    """
-    This function computes a Marlin MoE matrix multiplication using a single
-    set of expert weights, w, and a top-k gating mechanism. It is meant for
-    testing and debugging.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w (torch.Tensor): The set of expert weights.
-    - scales (torch.Tensor): The quantization scales for w.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - g_idx (torch.Tensor): The activation-reordering indices for w.
-    - rand_perm (torch.Tensor): The Marlin weight permutation for w.
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        product for w. Defaults to False.
-
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
-    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w.is_contiguous(), "Expert weights must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    M, K = hidden_states.shape
-    E = w.shape[0]
-    N = w.shape[2] // 2
-
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
-
-    # This might not be an optimal config for a single MMM
-    get_config_func = functools.partial(try_get_optimal_moe_config,
-                                        w.shape,
-                                        w.shape,
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None,
-                                        override_config=override_config,
-                                        is_marlin=True)
-    config = get_config_func(M)
-
-    block_size_m = config['BLOCK_SIZE_M']
-
-    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
-
-    max_workspace_size = (N // 64) * 16
-    workspace = torch.zeros(max_workspace_size,
-                            dtype=torch.int,
-                            device="cuda",
-                            requires_grad=False)
-
-    intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
-        hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m,
-        True, False)
-
-    return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
-
-
-def fused_marlin_moe(hidden_states: torch.Tensor,
-                     w1: torch.Tensor,
-                     w2: torch.Tensor,
-                     gating_output: torch.Tensor,
-                     g_idx1: torch.Tensor,
-                     g_idx2: torch.Tensor,
-                     rand_perm1: torch.Tensor,
-                     rand_perm2: torch.Tensor,
-                     topk: int,
-                     renormalize: bool,
-                     override_config: Optional[Dict[str, Any]] = None,
-                     use_fp8: bool = False,
-                     w1_scale: Optional[torch.Tensor] = None,
-                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and a top-k gating mechanism.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - g_idx1 (torch.Tensor): The activation-reordering indices for w1.
-    - g_idx2 (torch.Tensor): The activation-reordering indices for w2.
-    - rand_perm1 (torch.Tensor): The Marlin weight permutation for w1.
-    - rand_perm2 (torch.Tensor): The Marlin weight permutation for w2.
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
-
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[
-        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
-    assert hidden_states.shape[
-        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    M, K = hidden_states.shape
-    E = w1.shape[0]
-    N = w2.shape[1] * 16
-
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
-
-    get_config_func = functools.partial(try_get_optimal_moe_config,
-                                        w1.shape,
-                                        w2.shape,
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None,
-                                        override_config=override_config,
-                                        is_marlin=True)
-    config = get_config_func(M)
-
-    block_size_m = config['BLOCK_SIZE_M']
-
-    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
-
-    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
-    workspace = torch.zeros(max_workspace_size,
-                            dtype=torch.int,
-                            device="cuda",
-                            requires_grad=False)
-
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-
-    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
-        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
-        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
-        block_size_m, True, False)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
-
-    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
-        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
-        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
-        block_size_m, False, True)
-
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
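
For readers tracing this removal: the asserts in `fused_marlin_moe` encode Marlin's 4-bit packing conventions, where 16 input features collapse into one packed element along one weight dimension and two output features share a packed column along the other. A minimal, self-contained sketch of that shape contract (written for this note; not part of vLLM) is:

```python
import torch


def check_fused_marlin_moe_shapes(hidden_states: torch.Tensor,
                                  w1: torch.Tensor,
                                  w2: torch.Tensor,
                                  gating_output: torch.Tensor) -> int:
    """Mirror the shape asserts of the removed fused_marlin_moe.

    Marlin packs 4-bit weights: dim 1 of a weight tensor holds K // 16
    packed rows, and dim 2 holds twice the logical output columns.
    Returns the recovered intermediate (FFN) size N.
    """
    M, K = hidden_states.shape          # tokens x hidden size
    E = w1.shape[0]                     # number of experts
    N = w2.shape[1] * 16                # intermediate size from packed w2
    assert K == w1.shape[1] * 16, "hidden size vs. packed w1 mismatch"
    assert K == w2.shape[2] // 2, "hidden size vs. packed w2 mismatch"
    assert gating_output.shape == (M, E), "one router logit per expert"
    return N
```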
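And a hypothetical call site for the removed entry point, assuming a vLLM build from around this commit that still ships the `torch.ops._moe_C.marlin_gemm_moe` op and Marlin-repacked int4 weights, scales, `g_idx*`, and `rand_perm*` tensors produced by the GPTQ/Marlin quantization path. The import path below is an assumption based on the hunk header; this is a sketch, not a verified invocation:

```python
import torch

# Assumed module location at the time of this commit (the hunk touches
# fused_moe.py); the Marlin MoE helpers were later moved.
from vllm.model_executor.layers.fused_moe.fused_moe import fused_marlin_moe

out = fused_marlin_moe(
    hidden_states,           # (M, K) fp16/bf16/fp32, contiguous
    w1, w2,                  # Marlin-packed expert weights
    gating_output,           # (M, E) router logits, pre-softmax
    g_idx1, g_idx2,          # activation-reordering indices
    rand_perm1, rand_perm2,  # Marlin weight permutations
    topk=2,
    renormalize=True,
    w1_scale=w1_scale,
    w2_scale=w2_scale)
# out has shape (M, K): per-expert outputs weighted by the top-k routing
# weights and summed over the top-k dimension.
```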