Merged
Commits
40 commits
dacdf8c
fp4 marlin kernel
jinzhen-lin May 6, 2025
e2c0ad3
fix
jinzhen-lin May 6, 2025
0d5368b
fix
jinzhen-lin May 6, 2025
8d51e32
fix
jinzhen-lin May 6, 2025
c879e99
fix format
jinzhen-lin May 6, 2025
bb547a6
fix
jinzhen-lin May 6, 2025
4dddda5
fix
jinzhen-lin May 6, 2025
9aac76a
kFE2M1fn -> kFE2M1f
jinzhen-lin May 6, 2025
28d7f84
Merge remote-tracking branch 'origin/main' into fp4-marlin
jinzhen-lin May 6, 2025
8392d73
fix
jinzhen-lin May 6, 2025
5050d4b
fix
jinzhen-lin May 6, 2025
02576a9
fix comment
jinzhen-lin May 6, 2025
49978ad
fix
jinzhen-lin May 6, 2025
af12b22
fix
jinzhen-lin May 6, 2025
fa0d098
fix
jinzhen-lin May 6, 2025
e6265a6
update
jinzhen-lin May 9, 2025
fe3ea6e
fix for fp8
jinzhen-lin May 9, 2025
e6047e5
fix
jinzhen-lin May 9, 2025
dd53ce9
fix
jinzhen-lin May 9, 2025
810c95a
fix
jinzhen-lin May 9, 2025
0f07183
fix
jinzhen-lin May 9, 2025
7eb3f9b
fix
jinzhen-lin May 9, 2025
ed1db37
fix test
jinzhen-lin May 9, 2025
168fb3e
fix
jinzhen-lin May 9, 2025
25531eb
fix
jinzhen-lin May 9, 2025
ed95abb
fp4 moe marlin
jinzhen-lin May 9, 2025
d7b2ac7
fix
jinzhen-lin May 9, 2025
f09273b
fix
jinzhen-lin May 9, 2025
a82fcbf
add comment
jinzhen-lin May 9, 2025
7177a72
fix
jinzhen-lin May 9, 2025
4a6ac2a
fix
jinzhen-lin May 9, 2025
e6144ee
remove unused cuda kernel
jinzhen-lin May 9, 2025
45910c1
fix
jinzhen-lin May 9, 2025
027f6a3
Merge remote-tracking branch 'origin/main' into fp4-marlin
jinzhen-lin May 9, 2025
520149d
Merge remote-tracking branch 'origin/main' into fp4-marlin
jinzhen-lin May 9, 2025
7e0dbe8
fix test
jinzhen-lin May 10, 2025
9efad4e
Merge remote-tracking branch 'origin/main' into fp4-marlin
jinzhen-lin May 10, 2025
18df7ec
fp4 moe support
jinzhen-lin May 10, 2025
a21442d
fix moe support
jinzhen-lin May 10, 2025
660eb61
Merge remote-tracking branch 'origin/main' into fp4-marlin
jinzhen-lin May 10, 2025
3 changes: 3 additions & 0 deletions csrc/core/scalar_type.hpp
@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);

static inline constexpr auto kFE2M1f =
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;

static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
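For reference, the new kFE2M1f entry describes the NVFP4 E2M1 element type: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit, and no NaN/Inf encodings, which is exactly what ScalarType::float_(2, 1, true, ScalarType::NAN_NONE) encodes. Below is a minimal host-side C++ sketch of the decode; the helper name is made up for illustration and is not part of this diff.

#include <cstdint>
#include <cstdio>

// Decode one 4-bit E2M1 code into a float. Positive codes 0..7 map to
// 0, 0.5, 1, 1.5, 2, 3, 4, 6; bit 3 flips the sign.
static float decode_e2m1(uint8_t nibble) {
  int sign = (nibble >> 3) & 0x1;
  int exp = (nibble >> 1) & 0x3;
  int man = nibble & 0x1;
  float mag = (exp == 0)
                  ? 0.5f * man                                    // subnormal: 0 or 0.5
                  : (1.0f + 0.5f * man) * float(1 << (exp - 1));  // 2^(exp-1) * (1 + m/2)
  return sign ? -mag : mag;
}

int main() {
  for (uint8_t code = 0; code < 16; ++code)
    printf("0x%x -> %g\n", code, decode_e2m1(code));
}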
13 changes: 11 additions & 2 deletions csrc/moe/marlin_moe_wna16/generate_kernels.py
@@ -31,15 +31,18 @@

# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]


@@ -72,6 +75,12 @@ def generate_new_kernels():
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue

k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
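The filters above decide which kernel instantiations get emitted. As a quick sanity check, here is a standalone C++ sketch (illustrative only, not generated code) that mirrors the same rules and prints the resulting configurations, using this file's convention that group_size = 16 * group_blocks, 0 means act-order, and -1 means channelwise.

#include <cstdio>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> scalar_types = {
      "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
      "vllm::kFE2M1f"};
  const std::vector<int> group_blocks_list = {0, -1, 1, 2, 4, 8};

  for (const auto& st : scalar_types) {
    for (int gb : group_blocks_list) {
      const bool is_fp8 = (st == "vllm::kFE4M3fn");
      const bool is_fp4 = (st == "vllm::kFE2M1f");
      if (is_fp8 && gb != -1 && gb != 8) continue;  // fp8: channelwise or group_size=128
      if (is_fp4 && gb != 1 && gb != 2) continue;   // nvfp4: only group_blocks 1 and 2
      if (!is_fp4 && gb == 1) continue;             // group_size=16 is nvfp4-only
      if (gb > 0)
        printf("%-16s group_size=%d\n", st.c_str(), 16 * gb);
      else
        printf("%-16s %s\n", st.c_str(), gb == 0 ? "act-order" : "channelwise");
    }
  }
}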
23 changes: 12 additions & 11 deletions csrc/moe/marlin_moe_wna16/kernel.h
@@ -7,17 +7,18 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {
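The only signature change here is the new scale2_ptr argument: one 16-bit value per expert carrying the raw bit pattern of the half-precision global scale used on the nvfp4 path, which the kernel reinterprets as scalar_t. A host-side C++ sketch of reading such a bit pattern back as a float, assuming fp16 storage (names are illustrative, not from the PR):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Interpret a uint16_t as an IEEE fp16 value (1 sign bit, 5 exponent bits
// with bias 15, 10 mantissa bits) and widen it to float.
static float fp16_bits_to_float(uint16_t bits) {
  const int sign = (bits >> 15) & 0x1;
  const int exp = (bits >> 10) & 0x1f;
  const int man = bits & 0x3ff;
  float mag;
  if (exp == 0)
    mag = std::ldexp(man / 1024.0f, -14);  // subnormal
  else if (exp == 31)
    mag = man ? NAN : INFINITY;            // never used for scales
  else
    mag = std::ldexp(1.0f + man / 1024.0f, exp - 15);
  return sign ? -mag : mag;
}

int main() {
  const uint16_t scale2_bits = 0x3800;  // 0.5 in fp16
  printf("global scale = %g\n", fp16_bits_to_float(scale2_bits));
}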
139 changes: 99 additions & 40 deletions csrc/moe/marlin_moe_wna16/marlin_template.h
@@ -301,9 +301,11 @@ __global__ void Marlin(
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
// only)
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
@@ -341,14 +343,25 @@
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);

scalar_t2 global_scale;

constexpr bool has_act_order = group_blocks == 0;

constexpr int pack_factor = 32 / w_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8);
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
const int scales_expert_stride =
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
@@ -460,9 +473,16 @@
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
sh_block_topk_weights[tid4 * 4 + i] =
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
if constexpr (w_type == vllm::kFE2M1f) {
sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
global_scale,
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
} else {
sh_block_topk_weights[tid4 * 4 + i] =
Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
}
}
}
}
@@ -493,6 +513,11 @@
expert_id = expert_ids_ptr[block_id];
}

if constexpr (w_type == vllm::kFE2M1f) {
uint16_t val = scale2_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}

B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
if constexpr (has_zp) {
@@ -606,7 +631,7 @@
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
@@ -664,7 +689,8 @@
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
}
}
@@ -688,10 +714,20 @@
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;

s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
s_sh_rd = s_sh_rd * 2 + warp_row % 2;

} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
else
@@ -801,7 +837,7 @@
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;

if (sh_num_groups < act_s_max_num_groups) {
if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
}

@@ -1021,12 +1057,19 @@
cur_k += k_iter_size * (k % b_sh_wr_iters);

int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));

int4* sh_s_stage = sh_s + s_sh_stage * pipe;

reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
}
}

@@ -1199,22 +1242,7 @@
};

auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
if constexpr (has_zp && is_zp_float || !has_zp) {
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
} else {
static_assert(has_zp && !is_zp_float);
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if constexpr (w_type_id == vllm::kU4.id()) {
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
} else if constexpr (w_type_id == vllm::kU8.id()) {
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
}
}
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
};

// Execute the actual tensor core matmul of a sub-tile.
Expand Down Expand Up @@ -1244,13 +1272,23 @@ __global__ void Marlin(
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
}
}
if constexpr (has_zp && is_zp_float) {
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
}
}

if constexpr (w_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}

// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
Expand All @@ -1259,7 +1297,10 @@ __global__ void Marlin(
FragB frag_b1;
int b_quant_0, b_quant_1;

if constexpr (w_type.size_bits() == 4) {
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {
@@ -1272,22 +1313,28 @@
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));

if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}

// Apply scale to frag_b0
if constexpr (has_act_order) {
static_assert(group_blocks != -1);
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && group_blocks != -1) {
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
@@ -1554,10 +1601,17 @@
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && !has_zp) {
w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]);
}

if constexpr (w_type == vllm::kFE2M1f) {
if (!mul_topk_weights) {
res = __hmul2(res, global_scale);
}
}

if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
@@ -1648,7 +1702,9 @@
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_col_zp_to_shared();
fetch_col_scale_to_shared();
if constexpr (!dequant_skip_flop) {
fetch_col_scale_to_shared();
}
}
}
fetch_to_shared(i, i, i < slice_iters, i);
@@ -1737,7 +1793,8 @@
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
@@ -1747,7 +1804,8 @@
}

thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
@@ -1771,7 +1829,8 @@
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 && !has_zp) {
w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
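Putting the marlin_template.h changes together, the nvfp4 path dequantizes each weight in two levels: an FP8-E4M3 scale shared by a group of 16 E2M1 weights, then a per-expert fp16 global scale read from scale2_ptr (folded into the per-token top-k weights when mul_topk_weights is set). A scalar C++ sketch of that arithmetic, under the same assumptions as the earlier snippets; the decoded E2M1 value is written out directly (1.5 for code 0x3), and none of these helper names exist in the PR.

#include <cmath>
#include <cstdio>

// Decode an FP8-E4M3 byte (bias 7). The NaN encodings of e4m3fn are ignored
// here since quantization scales are finite.
static float decode_e4m3(unsigned char b) {
  const int sign = (b >> 7) & 0x1;
  const int exp = (b >> 3) & 0xf;
  const int man = b & 0x7;
  const float mag = (exp == 0) ? std::ldexp(man / 8.0f, -6)           // subnormal
                               : std::ldexp(1.0f + man / 8.0f, exp - 7);
  return sign ? -mag : mag;
}

int main() {
  const float w_e2m1 = 1.5f;                    // decode_e2m1(0x3)
  const float group_scale = decode_e4m3(0x28);  // 0.25 in e4m3
  const float global_scale = 0.5f;              // fp16 value from scale2_ptr[expert_id]

  // Effective weight fed into the tensor-core matmul:
  printf("w = %g\n", w_e2m1 * group_scale * global_scale);  // 0.1875
  return 0;
}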