#include "ggml.h"
#include "topk-moe.cuh"

+#include <initializer_list>
+
/*
    This kernel does the following:
    1. softmax over the logits per token [n_experts, n_tokens]
    2. argmax reduce over the top-k (n_experts_used) logits
    3. write weights + ids to global memory
+    4. optionally normalize the weights

    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
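// [Editorial sketch, not part of this patch] A scalar CPU reference of what the fused kernel
// computes for a single token. Names are illustrative only; the fixed 512-expert cap mirrors
// the largest case dispatched by the launcher below and is an assumption of this sketch.
#include <math.h>
#include <stdint.h>

static void topk_moe_reference(const float * logits, float * weights, int32_t * ids,
                               int n_experts, int n_expert_used, bool with_norm) {
    // 1. softmax over the logits (stabilized by subtracting the max logit)
    float max_logit = logits[0];
    for (int i = 1; i < n_experts; i++) {
        max_logit = logits[i] > max_logit ? logits[i] : max_logit;
    }
    float p[512];  // assumes n_experts <= 512
    float sum = 0.f;
    for (int i = 0; i < n_experts; i++) {
        p[i] = expf(logits[i] - max_logit);
        sum += p[i];
    }
    for (int i = 0; i < n_experts; i++) {
        p[i] /= sum;
    }
    // 2. iterative argmax with -inf masking (top-k selection)
    float wt_sum = 0.f;
    for (int k = 0; k < n_expert_used; k++) {
        int best = 0;
        for (int i = 1; i < n_experts; i++) {
            best = p[i] > p[best] ? i : best;
        }
        weights[k] = p[best];  // 3. write weights + ids
        ids[k]     = best;
        wt_sum    += p[best];
        p[best]    = -INFINITY;
    }
    // 4. optionally normalize the selected weights so they sum to 1
    if (with_norm) {
        for (int k = 0; k < n_expert_used; k++) {
            weights[k] /= wt_sum;
        }
    }
}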
-template <size_t n_experts>
+template <size_t n_experts, bool with_norm>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float * weights,
                                                                  int32_t * ids,
@@ -68,6 +71,11 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
    // we do the argmax reduce over n_expert_used, each time marking
    // the expert weight as -inf to exclude from the next iteration

+    float wt_sum = 0.f;
+
+    extern __shared__ float data_topk_shared[];
+    float * wt_shared_ptr = data_topk_shared + row * n_expert_used;
+
    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        int max_expert = threadIdx.x;
@@ -94,12 +102,33 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

-            weights[k] = max_val;
-            ids[k]     = max_expert;
+            wt_shared_ptr[k] = max_val;
+            ids[k]           = max_expert;
+            if constexpr (with_norm) {
+                wt_sum += max_val;
+            }
+        }
+    }
+
+    if constexpr (with_norm) {
+        wt_sum = warp_reduce_sum(wt_sum);
+        const float inv_sum = 1.0f / wt_sum;
+
+        if (threadIdx.x == 0) {
+            for (int i = 0; i < n_expert_used; i++) {
+                wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+            }
+        }
+    }
+
+    if (threadIdx.x == 0) {
+        for (int i = 0; i < n_expert_used; i++) {
+            weights[i] = wt_shared_ptr[i];
        }
    }
}

+template <bool with_norm>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float * logits,
                                 float * weights,
@@ -112,36 +141,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

+    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
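// [Editorial note, not part of this patch] This matches the kernel's scratch layout: each of the
// rows_per_block tokens handled by a block keeps its n_expert_used selected weights in
// data_topk_shared (indexed as row * n_expert_used + k), hence n_expert_used * rows_per_block
// floats of dynamic shared memory per block.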
+
    switch (n_expert) {
        case 1:
-            topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<1, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
-            topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<2, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
-            topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<4, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
-            topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<8, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
-            topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<16, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
-            topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<32, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
-            topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<64, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
-            topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<128, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
-            topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<256, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
-            topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<512, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
@@ -152,7 +193,8 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor * logits,
                           ggml_tensor * weights,
-                           ggml_tensor * ids) {
+                           ggml_tensor * ids,
+                           const bool with_norm) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -170,7 +212,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,

    const int n_expert_used = weights->ne[1];

-    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    if (with_norm) {
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    } else {
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    }
}

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
@@ -201,3 +247,17 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso

    return true;
}
+
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
+
+    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };
+
+    if (norm) {
+        return norm_ops;
+    }
+    return no_norm_ops;
+}
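For readers wiring this into a fusion pass, here is a minimal, hypothetical sketch of how the op list returned by ggml_cuda_topk_moe_ops could be matched against a run of consecutive graph nodes. The helper name topk_moe_pattern_matches and the matching strategy are assumptions for illustration only; a real backend would also apply shape and layout checks such as ggml_cuda_should_use_topk_moe.

// Hypothetical helper (not part of this patch): does the run of graph nodes starting at
// index `start` carry exactly the fused top-k MoE op sequence?
static bool topk_moe_pattern_matches(const struct ggml_cgraph * cgraph, int start, bool with_norm) {
    const std::initializer_list<enum ggml_op> ops = ggml_cuda_topk_moe_ops(with_norm);

    if (start + (int) ops.size() > cgraph->n_nodes) {
        return false;
    }

    int i = start;
    for (const enum ggml_op op : ops) {
        // every node in the candidate window must carry the expected op, in order
        if (cgraph->nodes[i++]->op != op) {
            return false;
        }
    }
    return true;
}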