Skip to content

Commit d9fec81

Browse files
committed
vLLM v0.8.5
Signed-off-by: Javier <[email protected]>
1 parent 997e9e0 commit d9fec81

File tree

4 files changed

+141
-8
lines changed

4 files changed

+141
-8
lines changed

cmake/external_projects/flashmla.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ endif()
3030
FetchContent_MakeAvailable(flashmla)
3131
message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
3232

if(WIN32)
  # The CUTLASS headers vendored inside the fetched FlashMLA tree guard C++17
  # code with a bare `#if (201703L <= __cplusplus)` check, which MSVC fails
  # unless /Zc:__cplusplus is passed; fix_cutlass_msvc.py rewrites that guard
  # to also accept _MSC_VER.  The script is idempotent (it checks for the
  # un-patched pattern first), so re-running configure is safe.
  #
  # NOTE(review): assumes this file is include()d from the top-level
  # CMakeLists so CMAKE_CURRENT_SOURCE_DIR is the repository root where
  # fix_cutlass_msvc.py lives — confirm.
  #
  # FindPythonInterp is deprecated since CMake 3.12; use FindPython3 and make
  # the interpreter REQUIRED instead of failing later with an empty variable.
  find_package(Python3 COMPONENTS Interpreter REQUIRED)
  execute_process(
    COMMAND ${Python3_EXECUTABLE}
            ${CMAKE_CURRENT_SOURCE_DIR}/fix_cutlass_msvc.py
            ${flashmla_SOURCE_DIR}
    RESULT_VARIABLE _flashmla_msvc_fix_result)
  # The original invocation ignored failures; surface them at configure time.
  if(NOT _flashmla_msvc_fix_result EQUAL 0)
    message(FATAL_ERROR
            "fix_cutlass_msvc.py failed for ${flashmla_SOURCE_DIR}")
  endif()
endif()
3339
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
3440
# Only build FlashMLA kernels if we are building for something compatible with
3541
# sm90a

cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ endif()
5050
FetchContent_MakeAvailable(vllm-flash-attn)
5151
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
5252

if(WIN32)
  # Patch the CUTLASS headers fetched with vllm-flash-attn so their C++17
  # `__cplusplus` guard also passes under MSVC (which reports an old value
  # without /Zc:__cplusplus).  fix_cutlass_msvc.py only rewrites the header
  # when the un-patched pattern is present, so this step is idempotent.
  #
  # NOTE(review): assumes CMAKE_CURRENT_SOURCE_DIR is the repo root where
  # fix_cutlass_msvc.py lives (this file being include()d from the top-level
  # CMakeLists) — confirm.
  #
  # Replaces the deprecated find_package(PythonInterp) / redundant
  # find_package(Python) pair with FindPython3, and checks the exit status
  # instead of silently ignoring a failed patch.
  find_package(Python3 COMPONENTS Interpreter REQUIRED)
  execute_process(
    COMMAND ${Python3_EXECUTABLE}
            ${CMAKE_CURRENT_SOURCE_DIR}/fix_cutlass_msvc.py
            ${vllm-flash-attn_SOURCE_DIR}
    RESULT_VARIABLE _vllm_fa_msvc_fix_result)
  if(NOT _vllm_fa_msvc_fix_result EQUAL 0)
    message(FATAL_ERROR
            "fix_cutlass_msvc.py failed for ${vllm-flash-attn_SOURCE_DIR}")
  endif()
endif()
5359
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
5460
# case only one is built, in the case both are built redundant work is done)
5561
install(

csrc/moe/marlin_moe_wna16/ops.cu

Lines changed: 116 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -405,33 +405,141 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
405405
NUM_THREADS, true)
406406

407407
// Probe the GPTQ kU4B8 kernel table for the single-row (m == 1) case.
// Returns true ("skipped") when no table entry matched, so the caller can
// try the next table; returns false when a kernel was selected.
// NOTE(review): the GPTQ_GET_IF_M1 macros are defined elsewhere in this
// file; from the syntax here they must expand to `else if (...) {...}`
// branches, and presumably assign the selected kernel through `kernel` —
// confirm against the macro definition.
template <typename scalar_t>
bool gptq_marlin_m1_u4b8(
    MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
    int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
    bool has_act_order, bool has_zp, int group_blocks, int num_threads,
    bool is_zp_float) {
  bool skipped = false;
  // Empty `if (false)` anchors the `else if` chain emitted by the macros.
  if (false) {
  }
  GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
  GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
  else {
    // No (thread_n, thread_k, num_threads) configuration matched.
    skipped = true;
  }
  return skipped;
}
420423

424+
// Probe the GPTQ kU4B8 kernel table for the multi-row (m in {2,3,4} blocks)
// case.  Mirrors gptq_marlin_m1_u4b8 but uses the M234 macro table.
// Returns true ("skipped") when nothing matched; false when a kernel was
// selected.  NOTE(review): GPTQ_GET_IF_M234 is defined elsewhere — by the
// surrounding syntax it must emit `else if` branches; presumably it writes
// the match through `kernel`.  Confirm against the macro definition.
template <typename scalar_t>
bool gptq_marlin_m234_u4b8(
    MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
    int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
    bool has_act_order, bool has_zp, int group_blocks, int num_threads,
    bool is_zp_float) {
  bool skipped = false;
  // Anchor for the macro-generated `else if` chain.
  if (false) {
  }
  GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
  GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
  else {
    skipped = true;
  }
  return skipped;
}
423440

441+
// Probe the GPTQ kU8B128 kernel table for the m == 1 case.  Same contract as
// gptq_marlin_m1_u4b8: returns true when skipped (no match), false when a
// kernel was selected via the GPTQ_GET_IF_M1 macro branches.
template <typename scalar_t>
bool gptq_marlin_m1_u8b128(
    MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
    int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
    bool has_act_order, bool has_zp, int group_blocks, int num_threads,
    bool is_zp_float) {
  bool skipped = false;
  // Anchor for the macro-generated `else if` chain.
  if (false) {
  }
  GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
  GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
  else {
    skipped = true;
  }
  return skipped;
}
426457

458+
// Probe the GPTQ kU8B128 kernel table for the multi-row case.  Same contract
// as the other probe helpers: true means "skipped" (no match), false means a
// kernel was selected by one of the GPTQ_GET_IF_M234 macro branches.
template <typename scalar_t>
bool gptq_marlin_m234_u8b128(
    MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
    int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
    bool has_act_order, bool has_zp, int group_blocks, int num_threads,
    bool is_zp_float) {
  bool skipped = false;
  // Anchor for the macro-generated `else if` chain.
  if (false) {
  }
  GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
  GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
  else {
    skipped = true;
  }
  return skipped;
}
429474

475+
// Probe the AWQ kU4 kernel table for the m == 1 case.  Same contract as the
// GPTQ probes: true means "skipped" (no match), false means a kernel was
// selected by an AWQ_GET_IF_M1 macro branch.
template <typename scalar_t>
bool awq_marlin_m1_u4(
    MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
    int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
    bool has_act_order, bool has_zp, int group_blocks, int num_threads,
    bool is_zp_float) {
  bool skipped = false;
  // Anchor for the macro-generated `else if` chain.
  if (false) {
  }
  AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
  AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
  else {
    skipped = true;
  }
  return skipped;
}
432491

492+
// Probe the AWQ kU4 kernel table for the multi-row case.  Same contract as
// the other probe helpers: true means "skipped" (no match), false means a
// kernel was selected by an AWQ_GET_IF_M234 macro branch.
template <typename scalar_t>
bool awq_marlin_m234_u4(
    MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
    int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
    bool has_act_order, bool has_zp, int group_blocks, int num_threads,
    bool is_zp_float) {
  bool skipped = false;
  // Anchor for the macro-generated `else if` chain.
  if (false) {
  }
  AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
  AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
  else {
    skipped = true;
  }
  return skipped;
}
508+
509+
// Select the marlin kernel function pointer for the given quantization type
// and launch configuration.
//
// Each probe helper (gptq_/awq_marlin_*) scans one kernel table: on a miss it
// returns true ("skipped"); on a hit it returns false after — per the GET_IF
// macros, defined elsewhere — writing the chosen kernel through its first
// (reference) argument.  The && chain below therefore relies on short-circuit
// evaluation: the first probe that matches stops the chain.  When every table
// is skipped, MarlinDefault is returned unchanged.
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
                                int thread_m_blocks, int thread_n_blocks,
                                int thread_k_blocks, bool m_block_size_8,
                                bool has_act_order, bool has_zp,
                                int group_blocks, int num_threads,
                                bool is_zp_float) {
  auto kernel = MarlinDefault;

  // Dropped the unused `int num_bits = q_type.size_bits();` local left over
  // from the pre-refactor version: the probes consume q_type directly, so it
  // only produced an unused-variable warning.
  bool all_skipped =
      gptq_marlin_m1_u4b8<scalar_t>(
          kernel, q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
          m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
          is_zp_float) &&
      gptq_marlin_m234_u4b8<scalar_t>(
          kernel, q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
          m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
          is_zp_float) &&
      gptq_marlin_m1_u8b128<scalar_t>(
          kernel, q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
          m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
          is_zp_float) &&
      gptq_marlin_m234_u8b128<scalar_t>(
          kernel, q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
          m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
          is_zp_float) &&
      awq_marlin_m1_u4<scalar_t>(
          kernel, q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
          m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
          is_zp_float) &&
      awq_marlin_m234_u4<scalar_t>(
          kernel, q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
          m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
          is_zp_float);
  // Only the probes' writes to `kernel` matter; silence -Wunused-variable.
  (void)all_skipped;

  return kernel;
}

fix_cutlass_msvc.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
import os
import sys


def patch_cutlass_platform_header(source_root):
    """Patch CUTLASS's platform.h in a fetched source tree for MSVC.

    CUTLASS guards C++17 code with ``#if (201703L <=__cplusplus)``; MSVC
    reports an old ``__cplusplus`` value unless ``/Zc:__cplusplus`` is set,
    so the guard is widened to also accept ``_MSC_VER``.

    Args:
        source_root: Root of the fetched dependency; the header is expected
            at ``csrc/cutlass/include/cutlass/platform/platform.h`` below it.

    Returns:
        True if the header was rewritten, False if it was absent or already
        patched (the un-patched pattern is gone after one rewrite, which
        makes this function idempotent).
    """
    platform_h_file = os.path.join(source_root, "csrc", "cutlass", "include",
                                   "cutlass", "platform", "platform.h")
    if not os.path.exists(platform_h_file):
        return False

    with open(platform_h_file, mode="r", encoding="utf-8") as file:
        header_content = file.read()

    if "\n#if (201703L <=__cplusplus)\n" not in header_content:
        return False

    header_content = header_content.replace(
        "#if (201703L <=__cplusplus)",
        "#if defined(_MSC_VER) || (201703L <=__cplusplus)")
    with open(platform_h_file, mode="w", encoding="utf-8") as file:
        file.write(header_content)
    return True


if __name__ == "__main__":
    # Invoked from CMake as: python fix_cutlass_msvc.py <source-tree-root>
    if len(sys.argv) < 2:
        sys.exit("usage: fix_cutlass_msvc.py <source-tree-root>")
    patch_cutlass_platform_header(sys.argv[1])

0 commit comments

Comments
 (0)