@@ -301,9 +301,11 @@ __global__ void Marlin(
     int4* __restrict__ C_tmp,            // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
                                          // (k/groupsize)xn
-    const int4* __restrict__ zp_ptr,     // 4bit packed zero-points of shape
-                                         // (k/groupsize)x(n/pack_factor)
-    const int* __restrict__ g_idx,       // int32 group indices of shape k
+    const uint16_t* __restrict__ scale2_ptr,  // fp16 global scale (for nvfp4
+                                              // only)
+    const int4* __restrict__ zp_ptr,     // 4bit packed zero-points of shape
+                                         // (k/groupsize)x(n/pack_factor)
+    const int* __restrict__ g_idx,       // int32 group indices of shape k
     const int32_t* __restrict__ sorted_token_ids_ptr,        // moe sorted_ids
     const int32_t* __restrict__ expert_ids_ptr,              // moe expert ids
     const int32_t* __restrict__ num_tokens_past_padded_ptr,  // moe num tokens
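
The first hunk adds `scale2_ptr`, a per-expert fp16 global scale that only the nvfp4 (`kFE2M1f`, i.e. FP4 E2M1) path uses on top of the per-group block scales. As a reference for the math only (the kernel works on packed tensor-core fragments, never on single values), a minimal sketch; `decode_e2m1` and `dequant_nvfp4` are hypothetical helpers, not functions from this file:

    #include <cstdint>

    // Reference decode of one 4-bit FE2M1 code: 1 sign, 2 exponent, 1 mantissa bit.
    float decode_e2m1(uint8_t nibble) {
      float sign = (nibble & 0x8) ? -1.0f : 1.0f;
      int exp = (nibble >> 1) & 0x3;
      int man = nibble & 0x1;
      float mag = (exp == 0) ? 0.5f * man                              // subnormal: 0 or 0.5
                             : (1.0f + 0.5f * man) * (1 << (exp - 1)); // 1 .. 6
      return sign * mag;
    }

    // Full dequant of one weight under the assumed NVFP4-style two-level scaling:
    // a per-group block scale (scales_ptr) and a per-expert global scale (scale2_ptr).
    float dequant_nvfp4(uint8_t nibble, float block_scale, float global_scale) {
      return decode_e2m1(nibble) * block_scale * global_scale;
    }
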
@@ -341,14 +343,25 @@ __global__ void Marlin(
   extern __shared__ int4 sh[];
   static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
   constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
+  constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
+                               w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
+  // see comments of dequant.h for more details
+  constexpr bool dequant_skip_flop =
+      !is_int_type ||
+      has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
+      has_zp && !is_zp_float && !(w_type == vllm::kU8);
+
+  scalar_t2 global_scale;
+
   constexpr bool has_act_order = group_blocks == 0;
 
   constexpr int pack_factor = 32 / w_type.size_bits();
   static_assert(thread_m_blocks == 1 || !m_block_size_8);
   constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
   const int group_size =
       (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
-  const int scales_expert_stride = prob_n * prob_k / group_size / 8;
+  const int scales_expert_stride =
+      prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
   const int zp_expert_stride =
       is_zp_float ? prob_n * prob_k / group_size / 8
                   : prob_n * prob_k / group_size / (pack_factor * 4);
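
The new `dequant_skip_flop` constant mixes `&&` and `||` without parentheses; since `&&` binds tighter, it reads as the disjunction below. A hedged restatement (hypothetical helper, not part of the patch) for readers checking the precedence:

    // Same value as the kernel's constexpr, with explicit parentheses.
    constexpr bool dequant_skip_flop_of(bool is_int_type, bool has_zp,
                                        bool is_zp_float, bool is_bf16,
                                        bool is_u8) {
      return !is_int_type                              // non-integer weights (e.g. fp4)
             || (has_zp && !is_zp_float && !is_bf16)   // u4/u8 + int zp, non-bf16 compute
             || (has_zp && !is_zp_float && !is_u8);    // u4 + int zp, any compute dtype
    }

    static_assert(dequant_skip_flop_of(false, false, false, false, false),
                  "non-int weight types always skip");
    static_assert(!dequant_skip_flop_of(true, false, false, false, false),
                  "plain int weights without zp do not");
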
@@ -460,9 +473,16 @@ __global__ void Marlin(
       if (mul_topk_weights) {
   #pragma unroll
         for (int i = 0; i < 4; i++) {
-          sh_block_topk_weights[tid4 * 4 + i] =
-              Dtype::num2num2(Dtype::float2num(
-                  topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+          if constexpr (w_type == vllm::kFE2M1f) {
+            sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
+                global_scale,
+                Dtype::num2num2(Dtype::float2num(
+                    topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
+          } else {
+            sh_block_topk_weights[tid4 * 4 + i] =
+                Dtype::num2num2(Dtype::float2num(
+                    topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
+          }
         }
       }
     }
@@ -493,6 +513,11 @@ __global__ void Marlin(
       expert_id = expert_ids_ptr[block_id];
     }
 
+    if constexpr (w_type == vllm::kFE2M1f) {
+      uint16_t val = scale2_ptr[expert_id];
+      global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
+    }
+
     B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
     scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
     if constexpr (has_zp) {
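
For `scalar_t == half`, the load above amounts to reinterpreting the raw 16-bit pattern from `scale2_ptr` as an fp16 value and broadcasting it to both lanes of a `half2`. A standalone sketch of that idea (hypothetical helper, assuming the fp16 specialization of `Dtype::num2num2`):

    #include <cstdint>
    #include <cuda_fp16.h>

    __device__ __half2 load_global_scale_half2(const uint16_t* scale2_ptr,
                                               int expert_id) {
      // Bit reinterpretation only, no numeric conversion.
      __half h = __ushort_as_half(scale2_ptr[expert_id]);
      // Broadcast to {h, h} so it can multiply half2 fragments directly.
      return __half2half2(h);
    }
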
@@ -606,7 +631,7 @@ __global__ void Marlin(
   constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
   constexpr int s_tb_groups =
       !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
-          ? thread_k_blocks / group_blocks
+          ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
           : 1;
   constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
   int s_gl_rd_delta = s_gl_stride;
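
The repeated `(w_type == vllm::kFE2M1f ? 2 : 1)` divisor here and below, like the `/ 16` instead of `/ 8` in `scales_expert_stride` earlier, is consistent with the nvfp4 block scales being stored one byte wide rather than two, so twice as many of them fit into each int4-sized unit that the strides are counted in. A small sketch of that arithmetic, stated as an assumption about the storage width rather than a fact from this diff:

    constexpr int kInt4Bytes = 16;  // sizeof(int4) on the device
    constexpr int scales_per_int4(int scale_bytes) { return kInt4Bytes / scale_bytes; }

    static_assert(scales_per_int4(2) == 8,   // fp16 scales: the existing "/ 8" divisor
                  "two-byte scales");
    static_assert(scales_per_int4(1) == 16,  // one-byte scales: the new "/ 16" (or extra "/ 2")
                  "one-byte scales");
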
@@ -664,7 +689,8 @@ __global__ void Marlin(
     if constexpr (group_blocks == -1) {
       s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
     } else {
-      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
+                    (w_type == vllm::kFE2M1f ? 2 : 1) +
                 s_sh_stride * slice_col + threadIdx.x;
     }
   }
@@ -688,10 +714,20 @@ __global__ void Marlin(
   // we scale a `half2` tile in column-major layout in the former and in
   // row-major in the latter case.
   int s_sh_rd;
-  if constexpr (group_blocks != -1)
+  if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
+    auto warp_id = threadIdx.x / 32;
+    int n_warps = thread_n_blocks / 4;
+    int warp_row = warp_id / n_warps;
+
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) / 4;
-  else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
+    s_sh_rd = s_sh_rd * 2 + warp_row % 2;
+
+  } else if constexpr (group_blocks != -1)
+    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+              (threadIdx.x % 32) / 4;
+  else if constexpr (group_blocks == -1 &&
+                     (m_block_size_8 || (has_zp && !dequant_skip_flop)))
     s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
               (threadIdx.x % 32) / 8;
   else
@@ -801,7 +837,7 @@ __global__ void Marlin(
     sh_first_group_id = first_group_id;
     sh_num_groups = last_group_id - first_group_id + 1;
 
-    if (sh_num_groups < act_s_max_num_groups) {
+    if (sh_num_groups > act_s_max_num_groups) {
       sh_num_groups = act_s_max_num_groups;
     }
 
@@ -1021,12 +1057,19 @@ __global__ void Marlin(
       cur_k += k_iter_size * (k % b_sh_wr_iters);
 
       int k_blocks = cur_k / 16;
-      int cur_group_id = k_blocks / group_blocks;
+      int cur_group_id =
+          k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
 
       int4* sh_s_stage = sh_s + s_sh_stage * pipe;
 
-      reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
-          sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+      if constexpr (w_type_id != vllm::kFE2M1f.id()) {
+        reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
+            sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+      } else {
+        reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
+            reinterpret_cast<int2*>(
+                sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
+      }
     }
   }
 
@@ -1199,22 +1242,7 @@ __global__ void Marlin(
   };
 
   auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
-    if constexpr (has_zp && is_zp_float || !has_zp) {
-      dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
-    } else {
-      static_assert(has_zp && !is_zp_float);
-      static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
-      // If (has_zp && !is_zp_float),
-      // we use not-zp version `dequant` function
-      // to improve numerical accuracy.
-      // Since both weight and zero point are dequanted using this logic,
-      // the final dequanted weight would be correct.
-      if constexpr (w_type_id == vllm::kU4.id()) {
-        dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
-      } else if constexpr (w_type_id == vllm::kU8.id()) {
-        dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
-      }
-    }
+    dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
   };
 
   // Execute the actual tensor core matmul of a sub-tile.
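
The deleted branch encoded a trick that the new `dequant<..., dequant_skip_flop>` template parameter now carries: dequantize unsigned weights with the biased (kU4B8/kU8B128) routine and dequantize the zero point the same way, so the bias cancels when the zero point is subtracted. A worked integer check of why that is safe (illustrative only, not kernel code):

    // kU4 weight q in [0,15], zero point zp in [0,15], scale s:
    //   reference:                w = (q - zp) * s
    //   biased dequant of both:   dq(q) = q - 8,  dq(zp) = zp - 8
    //   subtracting:              (dq(q) - dq(zp)) * s == (q - zp) * s
    #include <cassert>

    int main() {
      for (int q = 0; q < 16; ++q)
        for (int zp = 0; zp < 16; ++zp)
          assert(((q - 8) - (zp - 8)) == (q - zp));  // the +8 bias cancels exactly
      return 0;
    }
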
@@ -1244,13 +1272,23 @@ __global__ void Marlin(
         dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
       }
     }
-    if constexpr (has_zp && is_zp_float) {
+    if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
       if (is_new_zp) {
         reinterpret_cast<int4*>(&frag_zp)[0] =
             reinterpret_cast<int4*>(&frag_zpf[k2])[0];
       }
     }
 
+    if constexpr (w_type == vllm::kFE2M1f) {
+      int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
+      int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
+
+      dequant_fp8_scales<scalar_t2>(s_quant_0,
+                                    reinterpret_cast<scalar_t2*>(&frag_s[k2]));
+      dequant_fp8_scales<scalar_t2>(
+          s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
+    }
+
     // We have the m dimension as the inner loop in order to encourage overlapping
     // dequantization and matmul operations.
   #pragma unroll
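
`dequant_fp8_scales` expands the packed 8-bit group scales of the nvfp4 path into `scalar_t2` pairs right before they are applied; the fast in-register conversion lives in dequant.h. As a plain reference for the value being recovered, a scalar E4M3 decode, assuming the block scales follow the usual NVFP4 E4M3 layout (which this diff does not state explicitly):

    #include <cmath>
    #include <cstdint>

    // Reference decode of one FP8 E4M3 byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits.
    // E4M3 has no infinities; exponent=15 with mantissa=7 encodes NaN.
    float decode_e4m3(uint8_t b) {
      int sign = (b >> 7) & 0x1;
      int exp  = (b >> 3) & 0xF;
      int man  = b & 0x7;
      if (exp == 0xF && man == 0x7) return NAN;
      float mag = (exp == 0) ? std::ldexp(man / 8.0f, -6)              // subnormal
                             : std::ldexp(1.0f + man / 8.0f, exp - 7); // normal
      return sign ? -mag : mag;
    }
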
@@ -1259,7 +1297,10 @@ __global__ void Marlin(
       FragB frag_b1;
       int b_quant_0, b_quant_1;
 
-      if constexpr (w_type.size_bits() == 4) {
+      if constexpr (w_type_id == vllm::kFE2M1f.id()) {
+        b_quant_1 = frag_b_quant[k2][0][j];
+        b_quant_0 = b_quant_1 << 8;
+      } else if constexpr (w_type.size_bits() == 4) {
         b_quant_0 = frag_b_quant[k2][0][j];
         b_quant_1 = b_quant_0 >> 8;
       } else {
@@ -1272,22 +1313,28 @@ __global__ void Marlin(
       dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
       dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
 
+      if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
+        sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
+        sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
+      }
+
       // Apply scale to frag_b0
       if constexpr (has_act_order) {
         static_assert(group_blocks != -1);
         scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
                          act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
         scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
                          act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
-      } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
+      } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
+                           group_blocks == -1) {
         int idx = (threadIdx.x / 4) % 2;
         scalar_t2 s2 = Dtype::nums2num2(
             reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
             reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
         if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
         scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
         scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
-      } else if constexpr (has_zp && group_blocks != -1) {
+      } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
         if (is_new_zp)
           frag_zp[j] = __hmul2(frag_zp[j],
                                *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
@@ -1554,10 +1601,17 @@ __global__ void Marlin(
       // For per-column quantization we finally apply the scale here (only for
       // 4-bit)
       if constexpr (!has_act_order && group_blocks == -1 &&
-                    w_type.size_bits() == 4 && !has_zp) {
+                    w_type.size_bits() == 4 &&
+                    (has_zp && dequant_skip_flop || !has_zp)) {
         res = __hmul2(res, s[0]);
       }
 
+      if constexpr (w_type == vllm::kFE2M1f) {
+        if (!mul_topk_weights) {
+          res = __hmul2(res, global_scale);
+        }
+      }
+
       if constexpr (m_block_size_8) {
         ((scalar_t*)sh_red)[idx] = res.x;
         ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
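
The epilogue multiplies by `global_scale` only when `mul_topk_weights` is false because an earlier hunk already folded the global scale into `sh_block_topk_weights`; applying it in both places would scale the output twice. A one-line sanity check of that refactoring, with plain floats standing in for the half2 math (illustrative only):

    #include <cassert>

    int main() {
      float w = 0.75f;  // accumulated dot product
      float g = 2.0f;   // nvfp4 global scale
      float t = 0.25f;  // top-k routing weight
      float epilogue_path = (w * g) * t;  // scale result, then apply topk weight
      float fused_path    = w * (g * t);  // global scale pre-folded into topk weight
      assert(epilogue_path == fused_path);  // exact for these values
      return 0;
    }
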
@@ -1648,7 +1702,9 @@ __global__ void Marlin(
     if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
       if (i == 0) {
         fetch_col_zp_to_shared();
-        fetch_col_scale_to_shared();
+        if constexpr (!dequant_skip_flop) {
+          fetch_col_scale_to_shared();
+        }
       }
     }
     fetch_to_shared(i, i, i < slice_iters, i);
@@ -1737,7 +1793,8 @@ __global__ void Marlin(
     bool last = slice_idx == slice_count - 1;
     // For per-column scales, we only fetch them here in the final step before
     // write-out
-    if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
         if (s_sh_wr_pred) {
           cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@@ -1747,7 +1804,8 @@ __global__ void Marlin(
     }
 
     thread_block_reduce();
-    if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
+    if constexpr (!has_act_order && group_blocks == -1 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
         cp_async_wait<0>();
         __syncthreads();
@@ -1771,7 +1829,8 @@ __global__ void Marlin(
     // that converts the fp32 results to fp16 (so that we avoid possible
     // overflow in fp16)
     if constexpr (!has_act_order && group_blocks == -1 &&
-                  w_type.size_bits() == 8 && !has_zp) {
+                  w_type.size_bits() == 8 &&
+                  (has_zp && dequant_skip_flop || !has_zp)) {
       if (threadIdx.x / 32 < thread_n_blocks / 4) {
   #pragma unroll
         for (int i = 0; i < thread_m_blocks; i++) {