diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index c2ae554c9f8..d0f85e23609 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8); static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); +static inline constexpr auto kFE2M1f = + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = @@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 902bcd9dfd2..15f008d4f61 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -31,7 +31,10 @@ # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. -SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] @@ -39,7 +42,7 @@ # = 0 : act order case # = -1 : channelwise quantization # > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 2, 4, 8] +GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] DTYPES = ["fp16", "bf16"] @@ -72,6 +75,12 @@ def generate_new_kernels(): # for fp8 if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index c40c33d01f3..537282aba8c 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,17 +7,18 @@ #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ - const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index c9e199bcea1..dedbe1b792f 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -301,9 +301,11 @@ __global__ void Marlin( int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids const int32_t* __restrict__ expert_ids_ptr, // moe expert ids const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens @@ -341,6 +343,16 @@ __global__ void Marlin( extern __shared__ int4 sh[]; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || + w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = + !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); + + scalar_t2 global_scale; + constexpr bool has_act_order = group_blocks == 0; constexpr int pack_factor = 32 / w_type.size_bits(); @@ -348,7 +360,8 @@ __global__ void Marlin( constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; - const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int scales_expert_stride = + prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); @@ -460,9 +473,16 @@ __global__ void Marlin( if (mul_topk_weights) { #pragma unroll for (int i = 0; i < 4; i++) { - sh_block_topk_weights[tid4 * 4 + i] = - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + if constexpr (w_type == vllm::kFE2M1f) { + sh_block_topk_weights[tid4 * 4 + i] = __hmul2( + global_scale, + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]))); + } else { + sh_block_topk_weights[tid4 * 4 + i] = + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + } } } } @@ -493,6 +513,11 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[expert_id]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; if constexpr (has_zp) { @@ -606,7 +631,7 @@ __global__ void Marlin( constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -664,7 +689,8 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } @@ -688,10 +714,20 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; else @@ -801,7 +837,7 @@ __global__ void Marlin( sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < act_s_max_num_groups) { + if (sh_num_groups > act_s_max_num_groups) { sh_num_groups = act_s_max_num_groups; } @@ -1021,12 +1057,19 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + if constexpr (w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } } } @@ -1199,22 +1242,7 @@ __global__ void Marlin( }; auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - if constexpr (has_zp && is_zp_float || !has_zp) { - dequant(q, frag_b_ptr); - } else { - static_assert(has_zp && !is_zp_float); - static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); - // If (has_zp && !is_zp_float), - // we use not-zp version `dequant` function - // to improve numerical accuracy. - // Since both weight and zero point are dequanted using this logic, - // the final dequanted weight would be correct. - if constexpr (w_type_id == vllm::kU4.id()) { - dequant(q, frag_b_ptr); - } else if constexpr (w_type_id == vllm::kU8.id()) { - dequant(q, frag_b_ptr); - } - } + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. @@ -1244,13 +1272,23 @@ __global__ void Marlin( dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); } } - if constexpr (has_zp && is_zp_float) { + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if (is_new_zp) { reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; } } + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -1259,7 +1297,10 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { @@ -1272,6 +1313,11 @@ __global__ void Marlin( dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { static_assert(group_blocks != -1); @@ -1279,7 +1325,8 @@ __global__ void Marlin( act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1287,7 +1334,7 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && group_blocks != -1) { + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); @@ -1554,10 +1601,17 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && !has_zp) { + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } + if constexpr (w_type == vllm::kFE2M1f) { + if (!mul_topk_weights) { + res = __hmul2(res, global_scale); + } + } + if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; @@ -1648,7 +1702,9 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - fetch_col_scale_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } } } fetch_to_shared(i, i, i < slice_iters, i); @@ -1737,7 +1793,8 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1747,7 +1804,8 @@ __global__ void Marlin( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1771,7 +1829,8 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && !has_zp) { + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 00b4e934cc3..2cff04f699b 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ @@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) + #define BIGGROUP_GET_IF(W_TYPE) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ @@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, BIGGROUP_GET_IF(vllm::kFE4M3fn) + FP4_GET_IF(vllm::kFE2M1f) + ACT_GET_IF(vllm::kU4B8) ACT_GET_IF(vllm::kU8B128) @@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* zp, void* g_idx, void* perm, void* a_tmp, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, @@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, bool m_block_size_8 = moe_block_size == 8; if (has_zp) { - TORCH_CHECK(q_type == vllm::kU4, - "q_type must be u4 when has_zp = True. Got = ", q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn, - "q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " - "False. Got = ", - q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); @@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm( } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + torch::Tensor b_zeros; if (b_zeros_or_none.has_value()) { b_zeros = b_zeros_or_none.value(); @@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm( if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn, - "b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " - "False. Got = ", + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. Got = ", b_q_type.str()); } @@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), @@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm( at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 2a8b9bb39ca..810026d034c 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," - "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? " + "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! num_tokens_past_padded," diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index 3c0d77ac345..ae0d6c0f200 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -1,3 +1,67 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accuracy issue, so we should not use this in most cases. + +If has not zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where the `scale(bias)` can be cached. But this may have accuracy issue, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case, it use byte_perm instead of flop. +We cannot fused byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. + +*/ #include "marlin_dtypes.cuh" @@ -27,7 +91,8 @@ __device__ inline uint32_t prmt(uint32_t a) { return res; } -template +template __device__ inline void dequant(int q, scalar_t2* frag_b); // @@ -40,7 +105,22 @@ __device__ inline void dequant(int q, scalar_t2* frag_b); // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 // template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -62,7 +142,14 @@ __device__ inline void dequant(int q, half2* frag_b) { } template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -84,7 +171,7 @@ __device__ inline void dequant(int q, half2* frag_b) { } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; @@ -96,39 +183,36 @@ __device__ inline void dequant( int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); // clang-format on - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; + dequant(q, frag_b); +} - // Guarantee that the `(a & b) | c` operations are LOP3s. - // clang-format off - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - // clang-format on +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC300C300; + static constexpr uint32_t SUB = 0x43004300; - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } // @@ -140,8 +224,8 @@ __device__ inline void dequant( // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 // template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -149,33 +233,42 @@ __device__ inline void dequant(int q, uint32_t lo = prmt(q); uint32_t hi = prmt(q); - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} - frag_b[0] = __hsub2(*reinterpret_cast(&lo), +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); +} - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - frag_b[0] = __hsub2(*reinterpret_cast(&lo), + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = @@ -200,7 +293,7 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = @@ -225,22 +318,30 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant( + int q, half2* frag_b) { // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 + constexpr int MASK = 0x7F007F00; // Extract and shift FP8 values to FP16 format int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; // Construct and apply exponent bias constexpr int BIAS_OFFSET = @@ -248,28 +349,36 @@ __device__ inline void dequant(int q, const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); // Convert to half2 and apply bias - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 + constexpr int MASK = 0x7F007F00; // Extract and shift FP8 values to BF16 format int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; // Construct and apply exponent bias constexpr int BIAS_OFFSET = @@ -281,9 +390,116 @@ __device__ inline void dequant( __float2bfloat162_rn(*reinterpret_cast(&BIAS)); // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, + nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } #endif diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 8b4b951f3d8..4ac7121ab4e 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -31,7 +31,10 @@ # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. -SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] @@ -40,7 +43,7 @@ # = 0 : act order case # = -1 : channelwise quantization # > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 2, 4, 8] +GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] DTYPES = ["fp16", "bf16"] @@ -73,6 +76,12 @@ def generate_new_kernels(): # for fp8 if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" and group_blocks != 1: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 02527a48166..4a242f2050d 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -258,6 +258,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ @@ -314,6 +315,23 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ @@ -366,6 +384,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) + FP4_GET_IF(vllm::kFE2M1f) + BIGGROUP_GET_IF(vllm::kFE4M3fn) ACT_GET_IF(vllm::kU4B8) @@ -434,8 +454,8 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, - int prob_n, int prob_k, int lda, void* workspace, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, + int prob_m, int prob_n, int prob_k, int lda, void* workspace, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k_init, @@ -446,11 +466,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, q_type == vllm::kU4 || q_type == vllm::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn, - "q_type must be uint4b8, uint8b128 or float8_e4m3fn when " - "has_zp = False. Got = ", - q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -483,6 +504,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -601,7 +623,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on @@ -617,6 +639,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -759,6 +782,17 @@ torch::Tensor gptq_marlin_gemm( } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + torch::Tensor b_zeros; if (b_zeros_or_none.has_value()) { b_zeros = b_zeros_or_none.value(); @@ -774,8 +808,9 @@ torch::Tensor gptq_marlin_gemm( "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn, - "b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when " + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " "has_zp = False. Got = ", b_q_type.str()); } @@ -820,22 +855,36 @@ torch::Tensor gptq_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, + has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else { diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index eb2700c95e8..f92056589d2 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -7,13 +7,14 @@ #include "marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ - const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \ - int prob_k, int lda, int *locks, bool use_atomic_add, \ - bool use_fp32_reduce, int max_shared_mem +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template ::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); + + scalar_t2 global_scale; + + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + constexpr bool has_act_order = group_blocks == 0; constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); @@ -481,7 +498,7 @@ __global__ void Marlin( constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -540,7 +557,8 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } @@ -564,10 +582,20 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; else @@ -681,7 +709,7 @@ __global__ void Marlin( sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < act_s_max_num_groups) { + if (sh_num_groups > act_s_max_num_groups) { sh_num_groups = act_s_max_num_groups; } @@ -887,12 +915,19 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + if constexpr (w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } } } @@ -1065,22 +1100,7 @@ __global__ void Marlin( }; auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - if constexpr (has_zp && is_zp_float || !has_zp) { - dequant(q, frag_b_ptr); - } else { - static_assert(has_zp && !is_zp_float); - static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); - // If (has_zp && !is_zp_float), - // we use not-zp version `dequant` function - // to improve numerical accuracy. - // Since both weight and zero point are dequanted using this logic, - // the final dequanted weight would be correct. - if constexpr (w_type_id == vllm::kU4.id()) { - dequant(q, frag_b_ptr); - } else if constexpr (w_type_id == vllm::kU8.id()) { - dequant(q, frag_b_ptr); - } - } + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. @@ -1110,13 +1130,23 @@ __global__ void Marlin( dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); } } - if constexpr (has_zp && is_zp_float) { + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if (is_new_zp) { reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; } } + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -1125,7 +1155,10 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { @@ -1138,6 +1171,11 @@ __global__ void Marlin( dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { static_assert(group_blocks != -1); @@ -1145,7 +1183,8 @@ __global__ void Marlin( act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1153,7 +1192,7 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && group_blocks != -1) { + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); @@ -1408,10 +1447,15 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && !has_zp) { + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } + if constexpr (w_type == vllm::kFE2M1f) { + res = __hmul2(res, global_scale); + } + if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; @@ -1488,7 +1532,9 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - fetch_col_scale_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } } } fetch_to_shared(i, i, i < slice_iters); @@ -1563,7 +1609,8 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1573,7 +1620,8 @@ __global__ void Marlin( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1597,7 +1645,8 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && !has_zp) { + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1dbd11f5f2a..2430641ea6e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -292,8 +292,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " - "Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, " - "Tensor? perm_or_none, Tensor workspace, int b_q_type, " + "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " + "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor", {stride_tag}); diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index d6831006038..c1d0940f26c 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -16,6 +16,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -286,21 +288,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, atol=mixtral_moe_tol[dtype]) +def marlin_moe_generate_valid_test_cases(): + import itertools + m_list = [1, 123, 666] + n_list = [128, 1024] + k_list = [256, 2048] + e_list = [4, 12] + topk_list = [2, 3] + ep_size_list = [1, 4] + dtype_list = [torch.half, torch.bfloat16] + group_size_list = [-1, 16, 32, 128] + act_order_list = [True, False] + quant_type_list = [ + scalar_types.float4_e2m1f, + scalar_types.float8_e4m3fn, + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.uint8b128, + ] + is_k_full_list = [True, False] + + all_combinations = itertools.product(m_list, n_list, k_list, e_list, + topk_list, ep_size_list, dtype_list, + group_size_list, act_order_list, + quant_type_list, is_k_full_list) + + def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, + quant_type, is_k_full): + + if quant_type == scalar_types.float8_e4m3fn and \ + group_size not in [-1, 128]: + return False + if quant_type == scalar_types.float4_e2m1f and group_size != 16: + return False + if quant_type != scalar_types.float4_e2m1f and group_size == 16: + return False + + # Filter act_order + if act_order: + if group_size in (-1, k, n): + return False + if quant_type not in [scalar_types.uint4b8]: + return False + elif not is_k_full: + return False + + return True + + cases = [] + for case in all_combinations: + if is_invalid(*case): + cases.append(case) + return cases + + @pytest.mark.flaky(reruns=2) -@pytest.mark.parametrize("m", [1, 123, 666]) -@pytest.mark.parametrize("n", [128, 1024]) -@pytest.mark.parametrize("k", [256, 2048]) -@pytest.mark.parametrize("e", [4, 12]) -@pytest.mark.parametrize("topk", [2, 3]) -@pytest.mark.parametrize("ep_size", [1, 4]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("group_size", [-1, 32, 128]) -@pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("quant_type", [ - scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - scalar_types.float8_e4m3fn -]) -@pytest.mark.parametrize("is_k_full", [True, False]) +@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size," + "act_order, quant_type, is_k_full"), + marlin_moe_generate_valid_test_cases()) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, @@ -338,6 +383,11 @@ def test_fused_marlin_moe( if not is_k_full: return + if quant_type == scalar_types.float4_e2m1f and group_size != 16: + return + if quant_type != scalar_types.float4_e2m1f and group_size == 16: + return + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 @@ -355,12 +405,27 @@ def test_fused_marlin_moe( w_ref1_l = [] qweight1_l = [] scales1_l = [] + global_scale1_l = [] zeros1_l = [] g_idx1_l = [] sort_indices1_l = [] for i in range(w1.shape[0]): - if has_zp: + if quant_type == scalar_types.float4_e2m1f: + w_ref1, qweight1, scales1, global_scale1 = \ + rand_marlin_weight_fp4_like(w1[i], group_size) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + global_scale1_l.append(global_scale1) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( + w1[i], group_size) + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + elif has_zp: w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( w1[i].transpose(1, 0), quant_type, group_size) @@ -368,7 +433,7 @@ def test_fused_marlin_moe( qweight1_l.append(qweight1) scales1_l.append(scales1) zeros1_l.append(zeros1) - elif quant_type != scalar_types.float8_e4m3fn: + else: test_perm = torch.randperm(k) w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ marlin_quantize(w1[i].transpose(1, 0), quant_type, @@ -379,16 +444,11 @@ def test_fused_marlin_moe( scales1_l.append(scales1) g_idx1_l.append(g_idx1) sort_indices1_l.append(sort_indices1) - else: - w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( - w1[i], group_size) - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) w_ref1 = stack_and_dev(w_ref1_l) qweight1 = stack_and_dev(qweight1_l).contiguous() scales1 = stack_and_dev(scales1_l) + global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None @@ -396,12 +456,27 @@ def test_fused_marlin_moe( w_ref2_l = [] qweight2_l = [] scales2_l = [] + global_scale2_l = [] zeros2_l = [] g_idx2_l = [] sort_indices2_l = [] for i in range(w2.shape[0]): - if has_zp: + if quant_type == scalar_types.float4_e2m1f: + w_ref2, qweight2, scales2, global_scale2 = \ + rand_marlin_weight_fp4_like(w2[i], group_size) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + global_scale2_l.append(global_scale2) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( + w2[i], group_size) + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + elif has_zp: w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( w2[i].transpose(1, 0), quant_type, group_size) @@ -409,7 +484,7 @@ def test_fused_marlin_moe( qweight2_l.append(qweight2) scales2_l.append(scales2) zeros2_l.append(zeros2) - elif quant_type != scalar_types.float8_e4m3fn: + else: test_perm = torch.randperm(n) w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ marlin_quantize(w2[i].transpose(1, 0), quant_type, @@ -420,24 +495,18 @@ def test_fused_marlin_moe( scales2_l.append(scales2) g_idx2_l.append(g_idx2) sort_indices2_l.append(sort_indices2) - else: - w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( - w2[i], group_size) - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) w_ref2 = stack_and_dev(w_ref2_l) qweight2 = stack_and_dev(qweight2_l).contiguous() scales2 = stack_and_dev(scales2_l) + global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, token_expert_indices = fused_topk( - a, score, topk, False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) @@ -452,6 +521,8 @@ def test_fused_marlin_moe( topk_ids, global_num_experts=e, expert_map=e_map, + global_scale1=global_scale1, + global_scale2=global_scale2, g_idx1=g_idx1, g_idx2=g_idx2, sort_indices1=sort_indices1, diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index c125e0b5ec7..52507b375c2 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -20,6 +20,8 @@ MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_scales, query_marlin_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -190,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) +@pytest.mark.parametrize( + "group_size", + set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @@ -210,6 +213,7 @@ def test_gptq_marlin_gemm( use_fp32_reduce, ): m_factor, n_factor, k_factor = mnk_factors + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] size_m = m_factor size_k = k_chunk * k_factor @@ -220,6 +224,8 @@ def test_gptq_marlin_gemm( return if group_size == size_k: return + if has_zp: + return if size_k % group_size != 0: return @@ -227,7 +233,15 @@ def test_gptq_marlin_gemm( a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - if quant_type == scalar_types.float8_e4m3fn: + if quant_type == scalar_types.float4_e2m1f: + if group_size != 16 or act_order: + return + w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( + b_weight.T, group_size) + g_idx = None + sort_indices = None + marlin_zp = None + elif quant_type == scalar_types.float8_e4m3fn: if group_size not in [-1, 128]: return if act_order: @@ -236,26 +250,39 @@ def test_gptq_marlin_gemm( b_weight.T, group_size) g_idx = None sort_indices = None + marlin_zp = None + marlin_s2 = None + elif has_zp: + if group_size == 16: + return + w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b_weight, quant_type, group_size) + g_idx = None + sort_indices = None + marlin_s2 = None else: + if group_size == 16: + return w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( b_weight, quant_type, group_size, act_order) - - marlin_zp = marlin_make_empty_g_idx(marlin_s.device) + marlin_zp = None + marlin_s2 = None workspace = marlin_make_workspace_new(w_ref.device) - opcheck( - torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], - a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck(torch.ops._C.gptq_marlin_gemm, + (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, + sort_indices, workspace, quant_type.id, a_input.shape[0], + b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, + use_fp32_reduce, False), + test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.gptq_marlin_gemm( a_input, None, marlin_q_w, marlin_s, + marlin_s2, marlin_zp, g_idx, sort_indices, @@ -339,67 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(True)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) -def test_awq_marlin_gemm( - k_chunk, - n_chunk, - quant_type, - group_size, - mnk_factors, - use_fp32_reduce, -): - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, quant_type, group_size) - - g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) - sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) - is_k_full = True - - workspace = marlin_make_workspace_new(a_input.device) - - output = ops.gptq_marlin_gemm( - a_input, - None, - marlin_q_w, - marlin_s, - marlin_zp, - g_idx, - sort_indices, - workspace, - quant_type, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - is_k_full=is_k_full, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - output_ref = torch.matmul(a_input, w_ref) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @@ -452,6 +418,7 @@ def test_hqq_marlin_gemm( None, marlin_w_q, marlin_s, + None, marlin_zp, g_idx, g_idx_sort_indices, @@ -564,6 +531,7 @@ def test_marlin_gemm_subset_input(): None, marlin_q_w, marlin_s, + None, marlin_zp, g_idx, sort_indices, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 80f54974521..9d920b64404 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -333,6 +333,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ -866,6 +867,7 @@ def gptq_marlin_gemm(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ -878,9 +880,10 @@ def gptq_marlin_gemm(a: torch.Tensor, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros, - g_idx, perm, workspace, b_q_type.id, - size_m, size_n, size_k, is_k_full, + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) @@ -1381,6 +1384,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], b_qweight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ -1395,11 +1399,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], use_fp32_reduce: bool, is_zp_float: bool) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace, - sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights, - moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m, - size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, - is_zp_float) + input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, + perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, + topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep, + b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, + use_fp32_reduce, is_zp_float) if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index b96d34ec2db..4c84dd53833 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -25,6 +25,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type_id: int, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None, @@ -64,11 +66,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type = ScalarType.from_id(quant_type_id) assert quant_type in [ scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - scalar_types.float8_e4m3fn + scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f ] - int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8] - num_bits = 4 if quant_type in int4_scalar_types else 8 + bit4_scalar_types = [ + scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f + ] + num_bits = 4 if quant_type in bit4_scalar_types else 8 # Check constraints. assert hidden_states.shape[0] == gating_output.shape[ @@ -133,6 +137,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache1, w1, w1_scale, + global_scale1, w1_zeros, g_idx1, sort_indices1, @@ -165,6 +170,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache3, w2, w2_scale, + global_scale2, w2_zeros, g_idx2, sort_indices2, @@ -202,6 +208,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, global_num_experts: int = -1, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 7bd398137e0..e7511f330ea 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -304,8 +304,10 @@ def apply( marlin_out = ops.gptq_marlin_gemm( x, + None, layer.marlin_qweight, scales, + None, zeros, layer.g_idx, layer.g_idx_sort_indices, @@ -315,7 +317,7 @@ def apply( self.output_size_per_partition, self.input_size_per_partition, True, # is_k_full - True, # has_zp + False, # use atomic add True, # use 32-bit reduce True, # use float zp ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e9b16b8a0ac..bd9daa7c608 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, is_fp4_marlin_supported, + prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -24,6 +27,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -196,7 +200,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 100 + return 80 @classmethod def get_config_filenames(cls) -> List[str]: @@ -278,9 +282,15 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config self.cutlass_nvfp4_supported = cutlass_fp4_supported() + self.use_marlin = False + if not self.cutlass_nvfp4_supported: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and above.") + if is_fp4_marlin_supported(): + self.use_marlin = True + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") def create_weights( self, @@ -392,12 +402,29 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, requires_grad=False) + if self.use_marlin: + prepare_fp4_layer_for_marlin(layer) + del layer.alpha + del layer.input_scale + del layer.weight_scale_swizzled + def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if self.use_marlin: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + output_dtype = x.dtype # for input only the contracting dimension has a constraint. @@ -434,6 +461,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config + self.cutlass_nvfp4_supported = cutlass_fp4_supported() + self.use_marlin = False + + if not self.cutlass_nvfp4_supported: + if is_fp4_marlin_supported(): + self.use_marlin = True + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -442,6 +479,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, raise ValueError("NVFP4 quantization was selected, " " dynamic quantization is not supported.") + layer.num_experts = num_experts + layer.params_dtype = params_dtype layer.quant_config = self.quant_config weight_dtype = torch.uint8 weight_scale_dtype = torch.float8_e4m3fn @@ -594,7 +633,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, requires_grad=False) - return + + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + del layer.g1_alphas + del layer.g2_alphas + del layer.w13_input_scale_quant + del layer.w2_input_scale_quant + del layer.w13_blockscale_swizzled + del layer.w2_blockscale_swizzled def apply( self, @@ -614,6 +661,35 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", ): + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=layer.w13_weight_scale_2, + global_scale2=layer.w2_weight_scale_2, + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + assert activation == "silu", "Only SiLU activation is supported." assert not apply_router_weight_on_input, ( "Router weight on input is not " diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index a2b1b7cb0e1..89268ef7a38 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -33,7 +33,7 @@ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, + has_zp: Optional[bool] = None, include_fp_type: bool = True, device_capability: Optional[int] = None, ): @@ -45,6 +45,16 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point return [scalar_types.uint4] @@ -52,7 +62,7 @@ def query_marlin_supported_quant_types( # GPTQ style, unsigned + symmetric bias res = [scalar_types.uint4b8, scalar_types.uint8b128] if include_fp_type: - res += [scalar_types.float8_e4m3fn] + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] return res @@ -394,6 +404,7 @@ def apply_gptq_marlin_linear( None, weight, weight_scale, + None, weight_zp, g_idx, g_idx_sort_indices, @@ -439,6 +450,7 @@ def apply_awq_marlin_linear( None, weight, weight_scale, + None, weight_zp, g_idx, g_idx_sort_indices, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 00000000000..15177af58ae --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + +logger = init_logger(__name__) + + +def is_fp4_marlin_supported(): + return current_platform.has_device_capability(80) + + +def fp4_marlin_process_scales(marlin_scales): + assert (marlin_scales >= 0).all() + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. + + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 1e0078e246b..3080d2a0da8 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -19,6 +19,20 @@ def is_fp8_marlin_supported(): return current_platform.has_device_capability(80) +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( input: torch.Tensor, weight: torch.Tensor, @@ -44,6 +58,7 @@ def apply_fp8_marlin_linear( c=None, b_q_weight=weight, b_scales=weight_scale, + global_scale=None, b_zeros=None, g_idx=None, perm=None, @@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, # block-wise quantization -> group-wise quantization # (size_k // block_size[1], ceil(size_n / block_size[0])) # =>(repeat)=> (size_k // block_size[1], size_n) + if not size_k_first: + scales = scales.T.contiguous() block_n = layer.weight_block_size[0] - scales = scales.T.repeat_interleave(block_n, 1) + scales = scales.repeat_interleave(block_n, 1) # size_n may not divisible by block_size[0] scales = scales[:, :part_size_n] @@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, size_k=part_size_k, size_n=part_size_n, group_size=group_size) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) @@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, # block-wise quantization -> group-wise quantization # (e, size_k // block_size[1], ceil(size_n / block_size[0])) # =>(repeat)=> (e, size_k // block_size[1], size_n) + if not size_k_first: + scales = scales.permute(0, 2, 1) block_n = layer.weight_block_size[0] - scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2) + scales = scales.repeat_interleave(block_n, 2) # size_n may not divisible by block_size[0] scales = scales[..., :size_n].contiguous() @@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size): size_n=size_n, group_size=group_size) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) + return weight_ref.T, marlin_qweight, marlin_scales