diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 1e170e80d2f7..e3c18ce5a50b 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,8 @@ #include +#include "core/scalar_type.hpp" + template inline std::string str(T x) { return std::to_string(x); @@ -131,11 +133,26 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -156,6 +173,28 @@ __device__ inline FragB dequant(int q) { return frag_b; } +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -296,7 +335,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; @@ -840,10 +893,19 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - FragB frag_b0 = dequant(b_quant); + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -855,8 +917,6 @@ __device__ inline void MarlinMoESingle( } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -881,13 +941,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1035,8 +1095,10 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1169,25 +1231,67 @@ __device__ inline void MarlinMoESingle( // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } } } + if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1227,7 +1331,8 @@ __device__ inline void MarlinMoESingle( } } -template ( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1342,7 +1447,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ @@ -1494,42 +1601,43 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool replicate_input, + bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1611,10 +1719,13 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } + int pack_factor = 32 / q_type.size_bits(); + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1645,10 +1756,14 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, if (false) { } - CALL_IF_MOE(16, 4, 256) - CALL_IF_MOE(8, 8, 256) - CALL_IF_MOE(8, 4, 128) - CALL_IF_MOE(4, 8, 128) + CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1670,9 +1785,15 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + + int pack_factor = 32 / b_q_type->size_bits(); + int max_par = 4; int dev = a.get_device(); @@ -1733,8 +1854,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - has_act_order, is_k_full, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, + topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; -} \ No newline at end of file +} diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 01ba8ff69850..adee8399a4d6 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,11 +2,14 @@ #include +#include "core/scalar_type.hpp" + torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); \ No newline at end of file + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d4d43e2c601b..cd65a8ee92b9 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,10 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " - "bool replicate_input, bool apply_weights) -> Tensor"); - + "g_idx, Tensor! perm, Tensor! workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f526c381b333..f7642bf02b05 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,6 +2,8 @@ Run `pytest tests/kernels/test_moe.py`. """ +from typing import List + import pytest import torch from transformers import MixtralConfig @@ -9,7 +11,12 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_moe_marlin) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.scalar_type import scalar_types def torch_moe(a, w1, w2, score, topk): @@ -29,6 +36,20 @@ def torch_moe(a, w1, w2, score, topk): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch_moe_single(a, w, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + _, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.view(-1) + for i in range(w.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = a[mask] @ w[i].transpose(0, 1) + return (out.view(B, -1, w.shape[1])).sum(dim=1) + + @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -43,11 +64,11 @@ def test_fused_moe( topk: int, dtype: torch.dtype, ): - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0) @@ -99,3 +120,199 @@ def test_mixtral_moe(dtype: torch.dtype): vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_fused_marlin_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, + num_bits: int, +): + torch.manual_seed(7) + + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size in (k, n): + return + + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + for i in range(w2.shape[0]): + w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) + + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref1_l.append(w_ref1) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + g_idx1 = stack_and_dev(g_idx1_l) + sort_indices1 = stack_and_dev(sort_indices1_l) + + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref2_l.append(w_ref2) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + g_idx2 = stack_and_dev(g_idx2_l) + sort_indices2 = stack_and_dev(sort_indices2_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + triton_output = fused_moe( + a, + w_ref1.transpose(1, 2).contiguous(), + w_ref2.transpose(1, 2).contiguous(), + score, + topk, + renormalize=False, + ) + marlin_output = fused_moe_marlin( + a, + qweight1, + qweight2, + score, + g_idx1, + g_idx2, + sort_indices1, + sort_indices2, + topk, + renormalize=False, + w1_scale=scales1, + w2_scale=scales2, + num_bits=num_bits, + ) + + assert compute_max_diff(marlin_output, triton_output) < 4e-2 + + +@pytest.mark.skip("This test is here for the sake of debugging, " + "don't run it in automated tests.") +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_marlin_moe_mmm( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, + num_bits: int, +): + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == k: + return + + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + + w_ref_l = [] + qweights_l = [] + scales_l = [] + g_idx_l = [] + sort_indices_l = [] + + for i in range(w.shape[0]): + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) + w_ref_l.append(w_ref) + qweights_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweights_l).contiguous() + scales = stack_and_dev(scales_l) + g_idx = stack_and_dev(g_idx_l) + sort_indices = stack_and_dev(sort_indices_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + marlin_output = single_moe_marlin(a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False, + num_bits=num_bits) + torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e5e7bb696397..35a3a32470f0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -314,7 +314,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index fd6f41b90042..3e9440607248 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,17 +1,24 @@ +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, GPTQFusedMoE) from vllm.triton_utils import HAS_TRITON -__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", + "FusedMoeWeightScaleSupported", + "GPTQFusedMoE", + "fused_moe_marlin", + "single_moe_marlin", +] if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_marlin_moe, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) __all__ += [ - "fused_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d2b152320e11..613d67e64bff 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config(M: int, E: int, N: int, K: int, topk: int, - dtype: Optional[str], - is_marlin: bool) -> Dict[str, int]: +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } + # A heuristic: fused marlin works faster with this config for small M if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, @@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int, return config -def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, - Any]] = None, - is_marlin: bool = False): +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, + is_marlin: bool = False, +): if override_config: config = override_config else: @@ -391,6 +399,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) + ops.topk_softmax( topk_weights, topk_ids, @@ -437,108 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - - #TODO fp8 is not implemented yet - assert not use_fp8 - - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) - - def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py new file mode 100644 index 000000000000..40f9f66f1706 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -0,0 +1,245 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Any, Dict, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.scalar_type import scalar_types + +from .fused_moe import (fused_topk, moe_align_block_size, + try_get_optimal_moe_config) + + +def single_moe_marlin( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + rand_perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + num_bits: int = 8, +) -> torch.Tensor: + """ + This function computes a Marlin MoE MMM using weights w + and top-k gating mechanism. It is meant for testing and debugging. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w (torch.Tensor): The first set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + product for w. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert num_bits in [4, 8] + # TODO support this + assert not use_fp8 + + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // (num_bits // 2) + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + g_idx, rand_perm, workspace, scalar_type, M, N, K, True, E, topk, + block_size_m, True, False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + + +def fused_moe_marlin( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert num_bits in [4, 8] + # TODO support this + assert not use_fp8 + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + g_idx1, + rand_perm1, + workspace, + scalar_type, + M, + 2 * N, + K, + True, + E, + topk, + block_size_m, + True, + False, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + g_idx2, + rand_perm2, + workspace, + scalar_type, + M, + K, + N, + True, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 61ebef5e11f4..9882410af1f2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -25,15 +25,29 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): raise NotImplementedError @abstractmethod - def apply(self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, top_k: int, renormalize: bool, - use_grouped_topk: bool) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: raise NotImplementedError @@ -43,22 +57,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -67,10 +84,10 @@ def apply(self, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -182,6 +199,7 @@ def __init__( get_tensor_model_parallel_world_size()) self.top_k = top_k self.num_experts = num_experts + self.intermediate_size = intermediate_size self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize @@ -192,8 +210,8 @@ def __init__( self.topk_group = topk_group if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod()) + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedFusedMoEMethod() else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None @@ -204,7 +222,8 @@ def __init__( hidden_size=hidden_size, intermediate_size=self.intermediate_size_per_partition, params_dtype=params_dtype, - weight_loader=self.weight_loader) + weight_loader=self.weight_loader, + ) def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, @@ -476,4 +495,160 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight \ No newline at end of file + param_data[expert_id] = loaded_weight + + +class GPTQFusedMoE(torch.nn.Module): + """GPTQFusedMoE layer for GPTQ MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size = intermediate_size + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + assert (not use_grouped_topk and num_expert_group is None + and topk_group is None) + + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedFusedMoEMethod() + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int) -> None: + + if ("_qweight" in weight_name or "_scales" in weight_name + or "_qzeros" in weight_name): + if "w13" in weight_name: + shard_size = loaded_weight.size()[-1] + if shard_id == "w1": + param.data[expert_id, :, :shard_size] = loaded_weight + elif shard_id == "w2" or shard_id == "w3": + param.data[expert_id, :, shard_size:] = loaded_weight + else: + raise ValueError(f"Invalid shard_id: {shard_id}: " + "must be w1, w2, or w3.") + elif "w2" in weight_name: + param.data[expert_id][:] = loaded_weight + else: + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") + elif "_g_idx" in weight_name: + if "w13" not in weight_name and "w2" not in weight_name: + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") + param.data[expert_id] = loaded_weight + else: + raise ValueError(f"Invalid weight name: {weight_name}.") + + @staticmethod + def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None): + assert (not use_grouped_topk and topk_group is None + and num_expert_group is None) + from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + + topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=False, + topk_group=False, + num_expert_group=False) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 0e0ab9ce9169..ba4f719a3f97 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -266,18 +266,21 @@ def apply(self, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_marlin_moe) - - return fused_marlin_moe(x, - layer.w13_weight_packed, - layer.w2_weight_packed, - router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, - top_k, - renormalize=renormalize, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale) + from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin) + + return fused_moe_marlin( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + top_k, + renormalize=renormalize, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + num_bits=self.num_bits, + ) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index dabf17df78fe..6a2ed7704d13 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -88,13 +88,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, 2 * intermediate_size, dtype=torch.float32), requires_grad=False) - layer.register_parameter("w13_scale", w13_scale) + layer.register_parameter("w13_weight_scale", w13_scale) w2_scale = torch.nn.Parameter(torch.zeros(num_experts, hidden_size, dtype=torch.float32), requires_grad=False) - layer.register_parameter("w2_scale", w2_scale) + layer.register_parameter("w2_weight_scale", w2_scale) def apply(self, layer: torch.nn.Module, @@ -157,7 +157,7 @@ def quantize_and_call_weight_loader(param: torch.nn.Parameter, layer.w2_scale.data[expert_id, :].copy_(scales[:, 0]) else: raise ValueError( - f"Shard id must be in [0,1,2] but got {shard_id}") + f"Shard id must be in ['w1','w2','w3'] but got {shard_id}") weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1817dbcb023a..14f62868b7bd 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -473,10 +473,10 @@ def apply(self, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94eb3f301541..ceaca8981ea0 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,18 +1,26 @@ -from typing import Any, Dict, List, Optional +import enum +from enum import Enum +from typing import Any, Dict, List, Optional, Union import torch from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin) +from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, + GPTQFusedMoE) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -24,6 +32,11 @@ logger = init_logger(__name__) +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" @@ -33,8 +46,14 @@ class GPTQMarlinConfig(QuantizationConfig): (8, True): scalar_types.uint8b128, } - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -109,11 +128,14 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) + elif isinstance(layer, GPTQFusedMoE): + return GPTQMarlinMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -179,7 +201,8 @@ def create_weights( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, @@ -299,7 +322,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -308,7 +332,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=(layer.input_size if self.quant_config.desc_act else layer.input_size_per_partition), size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "scales", marlin_scales) def apply( @@ -329,4 +354,249 @@ def apply( output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, - bias=bias) + bias=bias, + ) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Currently assuming is_k_full is always True + # (input size per partition is the same as full input size) + # Supports only sym for now (no zp) + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size + else: + scales_size13 = 1 + scales_size2 = 1 + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) + replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) + replace_tensor(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=(layer.intermediate_size if self.quant_config.desc_act else + layer.intermediate_size_per_partition), + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + layer.marlin_state = GPTQMarlinState.READY + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + return fused_moe_marlin( + x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + top_k, + renormalize=renormalize, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + num_bits=self.quant_config.quant_type.size_bits, + ) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0ec68ac5b0f2..699d5f184414 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 7d08ac6f8746..4a06c5d63d52 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -1,6 +1,6 @@ """Utility functions used for tests and benchmarks""" -from typing import List +from typing import List, Optional import numpy as np import torch @@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, - act_order: bool): +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, size_n = w.shape num_bits = quant_type.size_bits @@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order) + w, quant_type, group_size, act_order, test_perm) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 33f24ff5d54d..bdfda31de852 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,5 +1,5 @@ """This file is used for /tests and /benchmarks""" -from typing import List +from typing import List, Optional import numpy import torch @@ -53,7 +53,10 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): assert q_w.shape == w_ref.shape orig_device = q_w.device @@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): g_idx[i] = i // group_size # Simulate act_order by doing a random permutation on K - rand_perm = torch.randperm(k_size) + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() @@ -164,8 +167,11 @@ def reshape_w(w): ) -def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, - group_size: int, act_order: bool): +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, _ = w.shape assert w.is_floating_point(), "w must be float" @@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) return w_ref, w_q, w_s, g_idx, rand_perm diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 8bdd52b34317..b5b91c02e0ac 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" +import logging from typing import Iterable, List, Optional, Tuple import numpy as np @@ -34,6 +35,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import GPTQFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, @@ -49,6 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +logger = logging.getLogger(__name__) + class MixtralMLP(nn.Module): @@ -94,10 +98,13 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, + use_fused_moe: bool, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config + self.use_fused_moe = use_fused_moe + self.quant_config = quant_config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() self.num_total_experts = config.num_local_experts @@ -113,14 +120,28 @@ def __init__( raise ValueError( f"Rank {self.rank} has no experts assigned to it.") - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) + if self.use_fused_moe: + params_dtype = torch.float16 + self.experts = GPTQFusedMoE( + num_experts=self.num_total_experts, + top_k=self.top_k, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=self.tp_size) + else: + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) + self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, bias=False, @@ -129,31 +150,36 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) + if self.use_fused_moe: + ret = self.experts(hidden_states.half(), router_logits) + return ret.bfloat16() + else: + routing_weights = F.softmax(router_logits, + dim=1, + dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum( + dim=-1, keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states).view( + num_tokens, hidden_dim) class MixtralAttention(nn.Module): @@ -238,6 +264,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + use_fused_moe: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -254,6 +281,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, + use_fused_moe=use_fused_moe, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -294,6 +322,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + use_fused_moe: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -307,6 +336,7 @@ def __init__( ) self.layers = nn.ModuleList([ MixtralDecoderLayer(config, + use_fused_moe, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) @@ -341,14 +371,15 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() + + self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, cache_config, quant_config) + self.model = MixtralModel(config, self.use_fused_moe, cache_config, + quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -389,6 +420,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = GPTQFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -408,11 +447,31 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + + if self.use_fused_moe: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, name, shard_id, + expert_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + else: + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)