diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index d130461b27e2..6143729b237f 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -10,6 +10,11 @@ #define __HIP__MI300_MI250__ #endif +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ +#endif + #if defined(NDEBUG) #undef NDEBUG #include @@ -357,7 +362,7 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, return rtn; }*/ -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__MI300__) // TODO: Add NAVI support template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, @@ -534,7 +539,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) n += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__MI300__) TODO: Add NAVI support template __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, @@ -544,9 +549,9 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#endif // defined(__HIP__MI300__) TODO: Add NAVI support -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__MI300__) // TODO: Add NAVI support template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, @@ -722,7 +727,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) n += CuCount * _WvPrGrp * YTILE; } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__MI300__) TODO: Add NAVI support template __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, const DTYPE* __restrict__ A, @@ -731,7 +736,7 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, const int Otp, const int CuCount) { UNREACHABLE_CODE } -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#endif // defined(__HIP__MI300__) TODO: Add NAVI support #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index cf3caebf3201..ce3ab80985bd 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -10,7 +10,7 @@ from vllm import _custom_ops as ops from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM from vllm.platforms import current_platform -from vllm.utils import is_navi +from vllm.utils import is_mi250, is_navi support_tuned_gemms = False if current_platform.is_rocm(): @@ -102,7 +102,8 @@ def scaled_mm( bias: Optional[torch.Tensor], ) -> torch.Tensor: n = inp.shape[0] - if n != 1: + if (not VLLM_USE_ROCM_SKINNY_GEMM or n != 1 + or not current_platform.is_rocm() or is_mi250() or is_navi()): return torch._scaled_mm(inp, weight, out_dtype=out_dtype,