Merged
17 changes: 11 additions & 6 deletions csrc/rocm/custom_kernels.cu
@@ -10,6 +10,11 @@
 #define __HIP__MI300_MI250__
 #endif
 
+#if defined(__HIPCC__) && \
+    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
Collaborator:
There's an ask to remove GFX940 and GFX941 from all code; we only want gfx942 here.

Collaborator Author:
We'll need to do that in a separate effort, in upstream too; it's all over the place.

Collaborator:
If there are any, we do need to remove them in general as a separate effort, yes, but we shouldn't add more mentions in a new PR.

Collaborator Author:
I'll open a new PR today. Before that, I prefer to have uniformity across the codebase to avoid confusion.
+#define __HIP__MI300__
+#endif
+
 #if defined(NDEBUG)
 #undef NDEBUG
 #include <assert.h>
@@ -357,7 +362,7 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
   return rtn;
 }*/
 
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__MI300__) // TODO: Add NAVI support
 template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B,
@@ -534,7 +539,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       n += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else  // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else  // !defined(__HIP__MI300__) TODO: Add NAVI support
 template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
 __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N,
                                  const DTYPE* B, const DTYPE* __restrict__ A,
@@ -544,9 +549,9 @@ __global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N,
                                  const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
 
-#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
+#if defined(__HIP__MI300__) // TODO: Add NAVI support
 template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
 __global__ void __launch_bounds__(WvPrGrp* THRDS)
     wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B,
@@ -722,7 +727,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
       n += CuCount * _WvPrGrp * YTILE;
   }
 }
-#else  // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#else  // !defined(__HIP__MI300__) TODO: Add NAVI support
 template <int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, int UNRL, int M>
 __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N,
                              const DTYPE* B, const DTYPE* __restrict__ A,
@@ -731,7 +736,7 @@ __global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N,
                              const int Otp, const int CuCount) {
   UNREACHABLE_CODE
 }
-#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#endif  // defined(__HIP__MI300__) TODO: Add NAVI support
 
 #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
 // This version targets cases where A[] fits LDS capacity
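For readers skimming the hunks above, here is a minimal, self-contained sketch of the compile-time dispatch pattern they follow. The kernel name and the UNREACHABLE_CODE definition below are illustrative stand-ins, not the file's actual contents: the optimized body is compiled only on targets where __HIP__MI300__ is defined, while every other target gets a stub so the translation unit still compiles and links.

// Sketch only: illustrates the guard pattern used above, under the assumption
// that UNREACHABLE_CODE expands to a device-side trap (the real definition
// lives elsewhere in custom_kernels.cu).
#include <hip/hip_runtime.h>
#include <assert.h>

#if defined(__HIPCC__) && \
    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#endif

#define UNREACHABLE_CODE assert(false);

#if defined(__HIP__MI300__)
// Real kernel body: free to use MI300-only instructions here.
__global__ void example_copy(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}
#else
// Stub keeps the build green on unsupported targets; launching it is a bug.
__global__ void example_copy(const float* in, float* out, int n) {
  UNREACHABLE_CODE
}
#endif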
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/tuned_gemm.py
@@ -10,7 +10,7 @@
 from vllm import _custom_ops as ops
 from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
 from vllm.platforms import current_platform
-from vllm.utils import is_navi
+from vllm.utils import is_mi250, is_navi
 
 support_tuned_gemms = False
 if current_platform.is_rocm():
@@ -102,7 +102,8 @@ def scaled_mm(
         bias: Optional[torch.Tensor],
     ) -> torch.Tensor:
         n = inp.shape[0]
-        if n != 1:
+        if (not VLLM_USE_ROCM_SKINNY_GEMM or n != 1
+                or not current_platform.is_rocm() or is_mi250() or is_navi()):
             return torch._scaled_mm(inp,
                                     weight,
                                     out_dtype=out_dtype,
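To make the new control flow explicit, here is a small sketch restating the guard above in its positive form. The helper name is hypothetical (everything it references is imported per the hunk above): scaled_mm now takes the custom ROCm skinny-GEMM path only when all of these conditions hold, and otherwise falls back to torch._scaled_mm.

from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
from vllm.platforms import current_platform
from vllm.utils import is_mi250, is_navi


def use_rocm_skinny_gemm(n: int) -> bool:
    # Hypothetical helper: the logical inverse of the early-return guard in
    # scaled_mm above (De Morgan applied to the "or" chain).
    return (VLLM_USE_ROCM_SKINNY_GEMM       # opt-in via environment variable
            and n == 1                      # skinny case: single-row input
            and current_platform.is_rocm()  # ROCm platform required
            and not is_mi250()              # MI250 excluded: kernels are now MI300-only
            and not is_navi())              # NAVI unsupported (see the TODOs above)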