@@ -405,33 +405,141 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
405405 NUM_THREADS, true )
406406
407407template <typename scalar_t>
408- MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
409- int thread_m_blocks, int thread_n_blocks,
410- int thread_k_blocks, bool m_block_size_8,
411- bool has_act_order, bool has_zp,
412- int group_blocks, int num_threads,
413- bool is_zp_float) {
414- int num_bits = q_type.size_bits ();
415- auto kernel = MarlinDefault;
408+ bool gptq_marlin_m1_u4b8(
409+ MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
410+ int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
411+ bool has_act_order, bool has_zp, int group_blocks, int num_threads,
412+ bool is_zp_float) {
413+ bool skipped = false ;
416414 if (false ) {
417415 }
418416 GPTQ_GET_IF_M1 (vllm::kU4B8 , 8 , 8 , 256 )
419417 GPTQ_GET_IF_M1 (vllm::kU4B8 , 8 , 4 , 128 )
418+ else {
419+ skipped = true ;
420+ }
421+ return skipped;
422+ }
420423
424+ template <typename scalar_t >
425+ bool gptq_marlin_m234_u4b8 (
426+ MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
427+ int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
428+ bool has_act_order, bool has_zp, int group_blocks, int num_threads,
429+ bool is_zp_float) {
430+ bool skipped = false ;
431+ if (false ) {
432+ }
421433 GPTQ_GET_IF_M234 (vllm::kU4B8 , 16 , 4 , 256 )
422434 GPTQ_GET_IF_M234 (vllm::kU4B8 , 8 , 4 , 128 )
435+ else {
436+ skipped = true ;
437+ }
438+ return skipped;
439+ }
423440
441+ template <typename scalar_t >
442+ bool gptq_marlin_m1_u8b128 (
443+ MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
444+ int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
445+ bool has_act_order, bool has_zp, int group_blocks, int num_threads,
446+ bool is_zp_float) {
447+ bool skipped = false ;
448+ if (false ) {
449+ }
424450 GPTQ_GET_IF_M1 (vllm::kU8B128 , 8 , 8 , 256 )
425451 GPTQ_GET_IF_M1 (vllm::kU8B128 , 8 , 4 , 128 )
452+ else {
453+ skipped = true ;
454+ }
455+ return skipped;
456+ }
426457
458+ template <typename scalar_t >
459+ bool gptq_marlin_m234_u8b128 (
460+ MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
461+ int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
462+ bool has_act_order, bool has_zp, int group_blocks, int num_threads,
463+ bool is_zp_float) {
464+ bool skipped = false ;
465+ if (false ) {
466+ }
427467 GPTQ_GET_IF_M234 (vllm::kU8B128 , 16 , 4 , 256 )
428468 GPTQ_GET_IF_M234 (vllm::kU8B128 , 8 , 4 , 128 )
469+ else {
470+ skipped = true ;
471+ }
472+ return skipped;
473+ }
429474
475+ template <typename scalar_t >
476+ bool awq_marlin_m1_u4 (
477+ MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
478+ int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
479+ bool has_act_order, bool has_zp, int group_blocks, int num_threads,
480+ bool is_zp_float) {
481+ bool skipped = false ;
482+ if (false ) {
483+ }
430484 AWQ_GET_IF_M1 (vllm::kU4 , 8 , 8 , 256 )
431485 AWQ_GET_IF_M1 (vllm::kU4 , 8 , 4 , 128 )
486+ else {
487+ skipped = true ;
488+ }
489+ return skipped;
490+ }
432491
492+ template <typename scalar_t >
493+ bool awq_marlin_m234_u4 (
494+ MarlinFuncPtr& kernel, const vllm::ScalarType q_type, int thread_m_blocks,
495+ int thread_n_blocks, int thread_k_blocks, bool m_block_size_8,
496+ bool has_act_order, bool has_zp, int group_blocks, int num_threads,
497+ bool is_zp_float) {
498+ bool skipped = false ;
499+ if (false ) {
500+ }
433501 AWQ_GET_IF_M234 (vllm::kU4 , 16 , 4 , 256 )
434502 AWQ_GET_IF_M234 (vllm::kU4 , 8 , 4 , 128 )
503+ else {
504+ skipped = true ;
505+ }
506+ return skipped;
507+ }
508+
509+ template <typename scalar_t >
510+ MarlinFuncPtr get_marlin_kernel (const vllm::ScalarType q_type,
511+ int thread_m_blocks, int thread_n_blocks,
512+ int thread_k_blocks, bool m_block_size_8,
513+ bool has_act_order, bool has_zp,
514+ int group_blocks, int num_threads,
515+ bool is_zp_float) {
516+ int num_bits = q_type.size_bits ();
517+ auto kernel = MarlinDefault;
518+
519+ bool skipped = gptq_marlin_m1_u4b8<scalar_t >(
520+ kernel, q_type, thread_m_blocks, thread_n_blocks,
521+ thread_k_blocks, m_block_size_8, has_act_order,
522+ has_zp, group_blocks, num_threads, is_zp_float) &&
523+ gptq_marlin_m234_u4b8<scalar_t >(
524+ kernel, q_type, thread_m_blocks, thread_n_blocks,
525+ thread_k_blocks, m_block_size_8, has_act_order,
526+ has_zp, group_blocks, num_threads, is_zp_float) &&
527+ gptq_marlin_m1_u8b128<scalar_t >(
528+ kernel, q_type, thread_m_blocks, thread_n_blocks,
529+ thread_k_blocks, m_block_size_8, has_act_order,
530+ has_zp, group_blocks, num_threads, is_zp_float) &&
531+ gptq_marlin_m234_u8b128<scalar_t >(
532+ kernel, q_type, thread_m_blocks, thread_n_blocks,
533+ thread_k_blocks, m_block_size_8, has_act_order,
534+ has_zp, group_blocks, num_threads, is_zp_float) &&
535+ awq_marlin_m1_u4<scalar_t >(
536+ kernel, q_type, thread_m_blocks, thread_n_blocks,
537+ thread_k_blocks, m_block_size_8, has_act_order,
538+ has_zp, group_blocks, num_threads, is_zp_float) &&
539+ awq_marlin_m234_u4<scalar_t >(
540+ kernel, q_type, thread_m_blocks, thread_n_blocks,
541+ thread_k_blocks, m_block_size_8, has_act_order,
542+ has_zp, group_blocks, num_threads, is_zp_float);
435543
436544 return kernel;
437545}
0 commit comments