From 2c6fdc002bdfab7d2f0fd8f17e4c0461616fedfd Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 24 Mar 2025 23:36:36 +0000 Subject: [PATCH 01/17] add apply_linear_rocm Signed-off-by: charlifu --- vllm/model_executor/layers/linear.py | 5 +++++ vllm/model_executor/layers/utils.py | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1ae574072b8f..f5b9155c6153 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -17,6 +17,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.utils import apply_gemm_rocm # yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, BlockQuantScaleParameter, @@ -26,6 +27,7 @@ RowvLLMParameter) # yapf: enable from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -188,6 +190,9 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if current_platform.is_rocm(): + return apply_gemm_rocm(x, layer.weight, bias) + return F.linear(x, layer.weight, bias) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index a9ef973917e1..e0e0949317f8 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Utility methods for model layers.""" -from typing import Tuple +from typing import Optional, Tuple import torch @@ -56,3 +56,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits + + +def apply_gemm_rocm(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + pass From f6784a65c5fdac5081f08281ce222e9bacac49e3 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 26 Mar 2025 15:41:51 +0000 Subject: [PATCH 02/17] add skinny gemm for fp16 Signed-off-by: charlifu --- CMakeLists.txt | 1 + csrc/rocm/ops.h | 6 + csrc/rocm/skinny_gemms.cu | 1150 +++++++++++++++++++++++++++ csrc/rocm/torch_bindings.cpp | 12 + vllm/_custom_ops.py | 11 + vllm/model_executor/layers/utils.py | 27 +- vllm/platforms/interface.py | 7 + vllm/platforms/rocm.py | 8 +- 8 files changed, 1220 insertions(+), 2 deletions(-) create mode 100644 csrc/rocm/skinny_gemms.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 65d1ddbeee0b..6e110c55453a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -619,6 +619,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") # set(VLLM_ROCM_EXT_SRC "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/skinny_gemms.cu" "csrc/rocm/attention.cu") define_gpu_extension_target( diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index ba161951772a..cf71c4f3370b 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -2,6 +2,12 @@ #include +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); + +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount); + void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu new file mode 100644 index 
000000000000..a4d002921ec3 --- /dev/null +++ b/csrc/rocm/skinny_gemms.cu @@ -0,0 +1,1150 @@ +#include +#include +#include +#include +#include +#include +#include +#include "cuda_compat.h" + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(__HIPCC__) && defined(__gfx942__) + #define __HIP__MI300__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, + const int K) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK] = {0.0}; + __half2 acch2; + __half2 oval; + + // As we later use warp shuffle operations, we may have more threads in the + // block than the actual available data, hence the if guard here. + if (threadid * 8 < K) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // rowA_elem4[i] holds 8 * half numbers seen as a single float4. + rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); + } + } + + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + __half2 Af2; + __half2 Bf2; + float2 S; + + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; + +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // Multiply-add on 8 half. + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + + // See comment above concerning the if guard. + if (threadid * 8 < K) { + acc[i] = S.x + S.y; // accumulation on float + } + } + +// all reduce across warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + if (lane < NUM_A_ROWS_PER_BLOCK) { + red_smem[lane][warp] = acc[lane]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + if (qwarpid < NUM_A_ROWS_PER_BLOCK) { + acc[qwarpid] = qthreadid < num_warps ? 
red_smem[qwarpid][qthreadid] : 0.f; +#pragma unroll + for (int mask = 16 / 2; mask >= 1; mask /= 2) { + acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); + } + float oval2 = __shfl_xor(acc[qwarpid], 16); + + if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) { + oval = __float22half2_rn(make_float2(acc[qwarpid], oval2)); + c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; + } + } +} + +// define the kernel calling code: +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<__half2*>(out_c); + + // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle + // operations. + const int NUM_THREADS = + K * 2 / 16 % WARP_SIZE == 0 + ? K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); + + int NUM_BLOCKS = M / rows_per_block; + + if (rows_per_block == 2) { + LLGemm1_kernel<2><<>>(af4, bf4, c, K); + } else if (rows_per_block == 4) { + LLGemm1_kernel<4><<>>(af4, bf4, c, K); + } else if (rows_per_block == 8) { + LLGemm1_kernel<8><<>>(af4, bf4, c, K); + } else if (rows_per_block == 16) { + LLGemm1_kernel<16><<>>(af4, bf4, c, K); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel<4><<>>(af4, bf4, c, K); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + + // call the kernel function... + LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); +} + +#define DTYPE half + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! 
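+  // K * M never exceeds 32 * 1024 halves here: the host-side dispatch in
+  // wvSpltK_ only selects this kernel when K_in * N_in <= 32 * 1024, so the
+  // whole activation fits in the s[] buffer declared below.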
+ //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
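+      // Each lane issues a non-temporal load (loadnt) of A_CHUNK = 8
+      // consecutive fp16 values from column n of B, plus one more such load
+      // per extra column covered by YTILE; B is streamed exactly once, so
+      // the non-temporal hint keeps it from polluting the cache.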
+ #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + // if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + // else + // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 
bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
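+  // If the tile starting at n would run past column N, the wave rewinds to
+  // startColumn = N - YTILE and clears commitColumn[] for the leading
+  // columns that overlap the previous tile, so they are recomputed but not
+  // stored a second time.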
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + bigType bigB8[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
+ #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
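+        // Each v_dot2c_f32_f16 multiplies one packed pair of fp16 values
+        // from A with the matching pair from B and accumulates the result
+        // into the fp32 sum register, so A_CHUNK / 2 = 4 issues per lane
+        // consume the whole 8-element chunk.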
+ #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + #define PCML + #ifndef PCML + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (32 * 1024) / M; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); + while (n < Nrndp) { + #else + while (n < N) { + #endif + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + bigType bigB8[UNRL]; + bigType bigB9[UNRL]; + bigType bigB10[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! 
+ // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M; m++) { + uint32_t k_in = kBase + m * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (n >= N) continue; + #endif + + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + #ifdef PCML + bigA[m][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * m]))); + #else + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + #endif + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+ #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + #ifdef PCML + if (n >= N) { + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + +#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ + wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } \ + } + + switch (N_in) { + case 1: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + break; + case 2: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + break; + case 3: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + break; + case 4: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} + +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount) { + auto M = in_a.size(0); + auto K = in_a.size(1); + int N = N_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + at::cuda::getCurrentCUDAStream(), CuCount); +} \ No newline at end of file diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index a5d2e2f97a3e..cd0b17c5ac6f 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -14,6 +14,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // vLLM custom ops for rocm + // Custom gemm op for matrix-vector 
multiplication + rocm_ops.def( + "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> " + "()"); + rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1); + + // Custom gemm op for skinny matrix-matrix multiplication + rocm_ops.def( + "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," + " int CuCount) -> ()"); + rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); + // Custom attention op // Compute the attention between an input query and the cached // keys/values using PagedAttention. diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index dc07bad4680f..c39c8740cf3a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1127,6 +1127,17 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, ssm_states, pad_slot_id) +# ROCm skinny gemms +def LLMM1(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, + rows_per_block: int) -> None: + torch.ops._rocm_C.LLMM1(a, b, out, rows_per_block) + + +def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, + cu_count: int) -> None: + torch.ops._rocm_C.wvSpltK(a, b, out, N, cu_count) + + # moe def moe_sum(input: torch.Tensor, output: torch.Tensor): torch.ops._moe_C.moe_sum(input, output) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index e0e0949317f8..3444e3d3b93f 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -4,6 +4,9 @@ import torch +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + def get_token_bin_counts_and_mask( tokens: torch.Tensor, @@ -61,4 +64,26 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, def apply_gemm_rocm(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - pass + x_view = x.view(-1, x.size(-1)) + m = weight.shape[0] + k = weight.shape[1] + n = x_view.shape[0] + cu_count = current_platform.get_cu_count() + + if bias is not None or x.dtype != torch.float16 or k % 8 != 0: + return torch.nn.functional.linear(x, weight, bias) + if m > 8 and n <= 4: + out = torch.empty(x_view.shape[0], + weight.shape[0], + dtype=x.dtype, + device=x.device) + ops.wvSpltK(weight, x_view, out, n, cu_count) + return out.view(*x.shape[:-1], weight.shape[0]) + elif m % 4 == 0 and n == 1 and k <= 8192: + out = torch.empty(x_view.shape[0], + weight.shape[0], + dtype=x.dtype, + device=x.device) + ops.LLMM1(weight, x_view, out, 4) + return out.view(*x.shape[:-1], weight.shape[0]) + return torch.nn.functional.linear(x, weight, bias) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 9981deee39b7..5ffaed4eea6c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -371,6 +371,13 @@ def use_all_gather(cls) -> bool: or parallel_config.distributed_executor_backend == "external_launcher") + @classmethod + def get_cu_count(cls, device_id: int = 0) -> int: + """ + Returns the total number of compute units (CU) on single GPU. 
+ """ + raise NotImplementedError + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ee708f5961df..5cb06a028a89 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from functools import lru_cache, wraps +from functools import cache, lru_cache, wraps from typing import TYPE_CHECKING, Dict, List, Optional import torch @@ -249,3 +249,9 @@ def fp8_dtype(cls) -> torch.dtype: return torch.float8_e4m3fnuz else: return torch.float8_e4m3fn + + @classmethod + @cache + def get_cu_count(cls, device_id: int = 0) -> int: + return torch.cuda.get_device_properties( + device_id).multi_processor_count From 0993ea0bb546f33eddce5b5e2b1ba7e9616e6bc2 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 26 Mar 2025 16:26:21 +0000 Subject: [PATCH 03/17] use wvSplitK Signed-off-by: charlifu --- csrc/rocm/ops.h | 4 +- csrc/rocm/skinny_gemms.cu | 98 ++++++++++++++--------------- csrc/rocm/torch_bindings.cpp | 4 +- vllm/_custom_ops.py | 6 +- vllm/model_executor/layers/utils.py | 2 +- 5 files changed, 57 insertions(+), 57 deletions(-) diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index cf71c4f3370b..b248fd70e9b0 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -5,8 +5,8 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, const int64_t rows_per_block); -void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount); +void wvSplitK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount); void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index a4d002921ec3..206c24f1d227 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -182,9 +182,9 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, // This version targets cases where A[] fits LDS capacity template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, - const int CuCount) { + wvSplitK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { @@ -429,9 +429,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template -__global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int _WvPrGrp, const int CuCount) { +__global__ void wvSplitK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -440,9 +440,9 @@ __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, // This version targets cases where A[] marginally exceeds LDS capacity template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSpltK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, - const int CuCount) { + wvSplitK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, 
const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { @@ -712,9 +712,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template -__global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int _WvPrGrp, const int CuCount) { +__global__ void wvSplitK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -723,9 +723,9 @@ __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, // This version targets big A[] cases, where it is much larger than LDS capacity template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, - const int CuCount) { + wvSplitK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; @@ -1049,9 +1049,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template -__global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int _WvPrGrp, const int CuCount) { +__global__ void wvSplitK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -1085,48 +1085,48 @@ int mindiv(int N, int div1, int div2) { return rtn; } -void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, - const int K_in, const int N_in, cudaStream_t stream, - const int CuCount = 0) { +void wvSplitK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { dim3 grid(CuCount); half* af4 = reinterpret_cast(in_a); const half* bf4 = reinterpret_cast(in_b); auto* c = reinterpret_cast(out_c); -#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ - _N) \ - { \ - dim3 block(64, _WvPrGrp); \ - if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ - } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ - } else { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ - wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ - } \ +#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSplitK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + 
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSplitK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ + wvSplitK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } \ } switch (N_in) { case 1: - WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 break; case 2: - WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 break; case 3: - WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 break; case 4: - WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 break; default: throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + @@ -1140,11 +1140,11 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, } } -void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount) { +void wvSplitK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount) { auto M = in_a.size(0); auto K = in_a.size(1); int N = N_in; - wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, - at::cuda::getCurrentCUDAStream(), CuCount); + wvSplitK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + at::cuda::getCurrentCUDAStream(), CuCount); } \ No newline at end of file diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index cd0b17c5ac6f..0565c96801ce 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -22,9 +22,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // Custom gemm op for skinny matrix-matrix multiplication rocm_ops.def( - "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," + "wvSplitK(Tensor in_a, Tensor in_b, Tensor! 
out_c, int N_in," " int CuCount) -> ()"); - rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); + rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); // Custom attention op // Compute the attention between an input query and the cached diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c39c8740cf3a..4fb3c8aa3375 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1133,9 +1133,9 @@ def LLMM1(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, torch.ops._rocm_C.LLMM1(a, b, out, rows_per_block) -def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, - cu_count: int) -> None: - torch.ops._rocm_C.wvSpltK(a, b, out, N, cu_count) +def wvSplitK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, + cu_count: int) -> None: + torch.ops._rocm_C.wvSplitK(a, b, out, N, cu_count) # moe diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 3444e3d3b93f..eb5fe0031ae3 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -77,7 +77,7 @@ def apply_gemm_rocm(x: torch.Tensor, weight.shape[0], dtype=x.dtype, device=x.device) - ops.wvSpltK(weight, x_view, out, n, cu_count) + ops.wvSplitK(weight, x_view, out, n, cu_count) return out.view(*x.shape[:-1], weight.shape[0]) elif m % 4 == 0 and n == 1 and k <= 8192: out = torch.empty(x_view.shape[0], From 6dfdd5f2b0fecdb1ca14cdad84967cdcb4bfe061 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 26 Mar 2025 20:47:30 +0000 Subject: [PATCH 04/17] add env for skinny gemm Signed-off-by: charlifu --- vllm/envs.py | 6 ++++++ vllm/model_executor/layers/utils.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index f0fd20c70e3b..5b0e910b961b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -74,6 +74,7 @@ VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True @@ -517,6 +518,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")), + # use rocm skinny gemms + "VLLM_ROCM_USE_SKINNY_GEMM": + lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in + ("true", "1")), + # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index eb5fe0031ae3..c9fc01cf42bc 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.platforms import current_platform @@ -70,7 +71,11 @@ def apply_gemm_rocm(x: torch.Tensor, n = x_view.shape[0] cu_count = current_platform.get_cu_count() - if bias is not None or x.dtype != torch.float16 or k % 8 != 0: + use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM is True and \ + bias is None and \ + x.dtype is torch.float16 and k % 8 == 0) + + if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) if m > 8 and n <= 4: out = torch.empty(x_view.shape[0], From 9aa20594d3ec45aa57da30dda301a583df524881 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 27 Mar 2025 16:33:05 +0000 Subject: [PATCH 05/17] add bf16 support for llmm1 Signed-off-by: charlifu --- csrc/rocm/skinny_gemms.cu | 135 
++++++++++++++++++++++++++------------ 1 file changed, 93 insertions(+), 42 deletions(-) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 206c24f1d227..be9199854df3 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -1,10 +1,14 @@ #include #include +#include + #include #include #include + #include #include + #include "cuda_compat.h" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) @@ -24,6 +28,45 @@ #define UNREACHABLE_CODE assert(false); #endif +template +struct scalar2 {}; + +template +C10_DEVICE C10_ALWAYS_INLINE float2 __s22float2(T v); + +template +C10_DEVICE C10_ALWAYS_INLINE T __float22s2_rn(float2 v); + +template <> +struct scalar2 { + using type = __half2; +}; + +template <> +C10_DEVICE C10_ALWAYS_INLINE float2 __s22float2(__half2 v) { + return __half22float2(v); +} + +template <> +C10_DEVICE C10_ALWAYS_INLINE __half2 __float22s2_rn(float2 v) { + return __float22half2_rn(v); +} + +template <> +struct scalar2 { + using type = __hip_bfloat162; +}; + +template <> +C10_DEVICE C10_ALWAYS_INLINE float2 __s22float2(__hip_bfloat162 v) { + return __bfloat1622float2(v); +} + +template <> +C10_DEVICE C10_ALWAYS_INLINE __hip_bfloat162 __float22s2_rn(float2 v) { + return __float22bfloat162_rn(v); +} + template __device__ __forceinline__ T loadnt(T* addr) { return __builtin_nontemporal_load(addr); @@ -40,9 +83,13 @@ __device__ __forceinline__ float4 load_ntmprl(const float4* addr) { // TBlock fetches entire rows of A, and entire col of B (K dimension); assume // N=1 for time being grid is M/A_NUM_ROWS blocks -template -__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, - const int K) { +template +__global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, + scalar_t* out_c, const int K) { + using scalar2_t = typename scalar2::type; + auto af4 = reinterpret_cast(in_a); + auto bf4 = reinterpret_cast(in_b); + auto c = reinterpret_cast(out_c); __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; const int threadid = threadIdx.x; @@ -52,11 +99,11 @@ __global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, const int qwarpid = threadid / 16; const int qthreadid = threadid % 16; float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; - __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; float acc[NUM_A_ROWS_PER_BLOCK] = {0.0}; - __half2 acch2; - __half2 oval; + scalar2_t acch2; + scalar2_t oval; // As we later use warp shuffle operations, we may have more threads in the // block than the actual available data, hence the if guard here. @@ -73,12 +120,12 @@ __global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, colB_elem4z = bf4[threadid * 4 + 2]; colB_elem4w = bf4[threadid * 4 + 3]; - __half2 Af2; - __half2 Bf2; + scalar2_t Af2; + scalar2_t Bf2; float2 S; - auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); - __half2* ah2lptr; + auto Ah2ptr = reinterpret_cast(&rowA_elem4); + scalar2_t* ah2lptr; #pragma unroll for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { @@ -92,7 +139,7 @@ __global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, acch2 = __hfma2(Af2, colB_elem4z, acch2); Af2 = *(ah2lptr + 3); acch2 = __hfma2(Af2, colB_elem4w, acch2); - S = __half22float2(acch2); + S = __s22float2(acch2); // See comment above concerning the if guard. 
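// A minimal host-side sketch of the trait pattern introduced above: a template
// maps a scalar type to its packed 2-wide form, and per-type conversion helpers
// let one templated kernel body serve both fp16 and bf16. The tag and pair
// types below are illustrative stand-ins, not the HIP intrinsics in this patch.
#include <cstdio>

struct half_tag {};                  // stands in for half
struct bf16_tag {};                  // stands in for __hip_bfloat16
struct f16x2 { float x, y; };        // stands in for __half2
struct bf16x2 { float x, y; };       // stands in for __hip_bfloat162

template <typename T> struct packed2 {};
template <> struct packed2<half_tag> { using type = f16x2; };
template <> struct packed2<bf16_tag> { using type = bf16x2; };

// Both packed forms reduce to float the same way, mirroring __s22float2
// followed by a float accumulation.
template <typename P> float dot_pair(P a, P b) { return a.x * b.x + a.y * b.y; }

int main() {
  packed2<half_tag>::type a{1.f, 2.f}, b{3.f, 4.f};
  std::printf("%f\n", dot_pair(a, b));  // 1*3 + 2*4 = 11
  return 0;
}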
if (threadid * 8 < K) { @@ -126,18 +173,22 @@ __global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, float oval2 = __shfl_xor(acc[qwarpid], 16); if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) { - oval = __float22half2_rn(make_float2(acc[qwarpid], oval2)); + oval = __float22s2_rn(make_float2(acc[qwarpid], oval2)); c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; } } } -// define the kernel calling code: -void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, - cudaStream_t stream, const int rows_per_block = 4) { - float4* af4 = reinterpret_cast(in_a); - auto* bf4 = reinterpret_cast<__half2*>(in_b); - auto* c = reinterpret_cast<__half2*>(out_c); +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + auto N = in_b.size(0); + + TORCH_CHECK(N == 1, "Row number of activation tensor must be 1."); + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(in_b.dtype() == torch::kFloat16 || + in_b.dtype() == torch::kBFloat16); // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle // operations. @@ -148,32 +199,32 @@ void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, int NUM_BLOCKS = M / rows_per_block; - if (rows_per_block == 2) { - LLGemm1_kernel<2><<>>(af4, bf4, c, K); - } else if (rows_per_block == 4) { - LLGemm1_kernel<4><<>>(af4, bf4, c, K); - } else if (rows_per_block == 8) { - LLGemm1_kernel<8><<>>(af4, bf4, c, K); - } else if (rows_per_block == 16) { - LLGemm1_kernel<16><<>>(af4, bf4, c, K); - } else { - NUM_BLOCKS = M / 4; - LLGemm1_kernel<4><<>>(af4, bf4, c, K); - } - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); -} - -void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t rows_per_block) { - auto M = in_a.size(0); - auto K = in_a.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_b)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // call the kernel function... 
- LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, - at::cuda::getCurrentCUDAStream(), rows_per_block); + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "LLGemm1", [&] { + auto a_ptr = in_a.data_ptr(); + auto b_ptr = in_b.data_ptr(); + auto c_ptr = out_c.data_ptr(); + if (rows_per_block == 2) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 4) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 8) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 16) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } + }); } #define DTYPE half From 16fb48c25d64fd14210fa212e99621348cff19de Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 28 Mar 2025 16:04:06 +0000 Subject: [PATCH 06/17] update skinny gemms Signed-off-by: charlifu --- csrc/rocm/skinny_gemms.cu | 14 ++++++++------ vllm/model_executor/layers/utils.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index be9199854df3..d020de8c130c 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -32,38 +32,40 @@ template struct scalar2 {}; template -C10_DEVICE C10_ALWAYS_INLINE float2 __s22float2(T v); +__device__ __forceinline__ float2 __s22float2(T v); template -C10_DEVICE C10_ALWAYS_INLINE T __float22s2_rn(float2 v); +__device__ __forceinline__ T __float22s2_rn(float2 v); +// Vector (size 2) definition and cvt functions for fp16 template <> struct scalar2 { using type = __half2; }; template <> -C10_DEVICE C10_ALWAYS_INLINE float2 __s22float2(__half2 v) { +__device__ __forceinline__ float2 __s22float2(__half2 v) { return __half22float2(v); } template <> -C10_DEVICE C10_ALWAYS_INLINE __half2 __float22s2_rn(float2 v) { +__device__ __forceinline__ __half2 __float22s2_rn(float2 v) { return __float22half2_rn(v); } +// Vector (size 2) definition and cvt functions for bf16 template <> struct scalar2 { using type = __hip_bfloat162; }; template <> -C10_DEVICE C10_ALWAYS_INLINE float2 __s22float2(__hip_bfloat162 v) { +__device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) { return __bfloat1622float2(v); } template <> -C10_DEVICE C10_ALWAYS_INLINE __hip_bfloat162 __float22s2_rn(float2 v) { +__device__ __forceinline__ __hip_bfloat162 __float22s2_rn(float2 v) { return __float22bfloat162_rn(v); } diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index c9fc01cf42bc..77a9544c7b3c 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -77,7 +77,7 @@ def apply_gemm_rocm(x: torch.Tensor, if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) - if m > 8 and n <= 4: + if m > 8 and n <= 2: out = torch.empty(x_view.shape[0], weight.shape[0], dtype=x.dtype, From e06862eed03f9c429506307fd7e557d9fc3fd3a6 Mon Sep 17 00:00:00 2001 From: charlifu Date: Sat, 29 Mar 2025 21:56:53 +0000 Subject: [PATCH 07/17] add bf16 wvsplitK Signed-off-by: charlifu --- csrc/rocm/skinny_gemms.cu | 443 +++++++++++++--------------- vllm/model_executor/layers/utils.py | 4 +- 2 files changed, 214 insertions(+), 233 deletions(-) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index d020de8c130c..f22b1097d6ba 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -28,16 +28,28 @@ #define UNREACHABLE_CODE assert(false); #endif +template +struct scalar {}; + template 
struct scalar2 {}; template __device__ __forceinline__ float2 __s22float2(T v); +template +__device__ __forceinline__ T __float2s(float v); + template __device__ __forceinline__ T __float22s2_rn(float2 v); // Vector (size 2) definition and cvt functions for fp16 + +template <> +struct scalar { + using type = half; +}; + template <> struct scalar2 { using type = __half2; @@ -48,12 +60,22 @@ __device__ __forceinline__ float2 __s22float2(__half2 v) { return __half22float2(v); } +template <> +__device__ __forceinline__ half __float2s(float v) { + return __float2half(v); +} + template <> __device__ __forceinline__ __half2 __float22s2_rn(float2 v) { return __float22half2_rn(v); } // Vector (size 2) definition and cvt functions for bf16 +template <> +struct scalar { + using type = __hip_bfloat16; +}; + template <> struct scalar2 { using type = __hip_bfloat162; @@ -64,6 +86,11 @@ __device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) { return __bfloat1622float2(v); } +template <> +__device__ __forceinline__ __hip_bfloat16 __float2s(float v) { + return __float2bfloat16(v); +} + template <> __device__ __forceinline__ __hip_bfloat162 __float22s2_rn(float2 v) { return __float22bfloat162_rn(v); @@ -131,7 +158,7 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, #pragma unroll for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { - // Multiply-add on 8 half. + // Multiply-add on 8 scalar_t. ah2lptr = Ah2ptr + i * 4; Af2 = *(ah2lptr); acch2 = __hmul2(Af2, colB_elem4x); @@ -229,23 +256,31 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, }); } -#define DTYPE half +#define DOT2C(V0, V2, V3) \ + if (std::is_same_v) { \ + asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \ + } else if (std::is_same_v) { \ + float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ + __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ + V0 += (s.x + s.y); \ + } #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, - const int CuCount) { - using half8 = + wvSplitK_hf_sml_(const int K, const int N, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { - DTYPE h[A_CHUNK]; + scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - half8 h8; + scalar8 h8; }; //---------------------------------------------------- @@ -255,7 +290,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! 
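// Reference semantics of one DOT2C step in plain C++: each 32-bit lane holds a
// pair of 16-bit values, and the macro accumulates their 2-element dot product
// into a float accumulator (via v_dot2c_f32_f16 for fp16, via float2 math for
// bf16). This sketch models the pair as a small struct rather than bit-packed
// registers; it is an illustration, not the device code.
#include <cstdio>

struct pair16 { float x, y; };  // stand-in for one packed fp16/bf16 pair

inline void dot2c_ref(float& acc, pair16 a, pair16 b) {
  acc += a.x * b.x + a.y * b.y;  // same arithmetic as the bf16 branch of DOT2C
}

int main() {
  float acc = 0.f;
  pair16 a[2] = {{1, 2}, {3, 4}}, b[2] = {{5, 6}, {7, 8}};
  for (int i = 0; i < 2; ++i) dot2c_ref(acc, a[i], b[i]);
  std::printf("%f\n", acc);  // 1*5 + 2*6 + 3*7 + 4*8 = 70
  return 0;
}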
//---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ scalar_t s[1024 * 32]; //---------------------------------------------------- // Fetch the activation matrix to LDS @@ -270,19 +305,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - if (k_in >= min(K * M, 32 * 1024)) break; - //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); - // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; @@ -350,18 +378,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + const scalar_t* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -374,10 +402,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Fetch A activation matrix in interleaved fashion from LDS or memory for (int m = 0; m < M; m++) { - // if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); - // else - // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); } } @@ -393,41 +418,32 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t m = 0; m < M; m++) { #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + DOT2C(sum[m][0], bigA[m][k2].f[b], bigB0[k2].f[b]) //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 
%0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + if (YTILE >= 2) { + DOT2C(sum[m][1], bigA[m][k2].f[b], bigB1[k2].f[b]); + } + if (YTILE >= 3) { + DOT2C(sum[m][2], bigA[m][k2].f[b], bigB2[k2].f[b]); + } + if (YTILE >= 4) { + DOT2C(sum[m][3], bigA[m][k2].f[b], bigB3[k2].f[b]); + } + if (YTILE >= 5) { + DOT2C(sum[m][4], bigA[m][k2].f[b], bigB4[k2].f[b]); + } + if (YTILE >= 6) { + DOT2C(sum[m][5], bigA[m][k2].f[b], bigB5[k2].f[b]); + } + if (YTILE >= 7) { + DOT2C(sum[m][6], bigA[m][k2].f[b], bigB6[k2].f[b]); + } + if (YTILE >= 8) { + DOT2C(sum[m][7], bigA[m][k2].f[b], bigB7[k2].f[b]); + } } } } @@ -462,28 +478,19 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2s(sum[m][i]); } } } n += CuCount * _WvPrGrp * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - // if (n < N && (n + YTILE) >= N) { - // uint32_t startColumn = N - YTILE; - // for (uint32_t i = 0; i < (n - startColumn); i++) { - // commitColumn[i] = 0; - // } - // n = startColumn; - //} } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template -__global__ void wvSplitK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, +template +__global__ void wvSplitK_hf_sml_(const int K, const int N, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -491,19 +498,20 @@ __global__ void wvSplitK_hf_sml_(const int K, const int N, const DTYPE* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] marginally exceeds LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, - const int CuCount) { - using half8 = + wvSplitK_hf_(const int K, const int N, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { - DTYPE h[A_CHUNK]; + scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - half8 h8; + scalar8 h8; }; //---------------------------------------------------- @@ -513,7 +521,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ scalar_t s[1024 * 32]; //---------------------------------------------------- // Computation of columns that need to be committed to memory! 
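// Host-side sketch of how the three wvSplitK variants are chosen: the
// activation footprint (K * N elements, against the 32K-element LDS staging
// buffer s[1024 * 32]) decides between the "fits in LDS", "marginally exceeds
// LDS", and "big" kernels. Thresholds mirror the WVSPLITK macro later in this
// patch; the helper below is illustrative, not the dispatch code itself.
#include <cstdio>

enum class Variant { Small, Regular, Big };

Variant pick_wvsplitk_variant(long k, long n, long m, int ytile_small) {
  const long lds_elems = 32 * 1024;  // matches __shared__ scalar_t s[1024 * 32]
  if (k * n <= lds_elems && m % ytile_small == 0) return Variant::Small;
  if (k * n <= lds_elems * 1.2) return Variant::Regular;
  return Variant::Big;
}

int main() {
  // K=4096, N=4 activations fit in LDS, so the small kernel would be used.
  std::printf("%d\n", static_cast<int>(pick_wvsplitk_variant(4096, 4, 8192, 2)));
  return 0;
}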
@@ -554,15 +562,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - if (k_in >= min(K * M, 32 * 1024)) break; - //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); @@ -632,18 +634,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + const scalar_t* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -675,41 +677,32 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + DOT2C(sum[m][0], bigA[m][k2].f[b], bigB0[k2].f[b]); //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + if (YTILE >= 2) { + DOT2C(sum[m][1], bigA[m][k2].f[b], bigB1[k2].f[b]); + } + if (YTILE >= 3) { + DOT2C(sum[m][2], bigA[m][k2].f[b], bigB2[k2].f[b]); + } + if (YTILE >= 4) { + DOT2C(sum[m][3], bigA[m][k2].f[b], bigB3[k2].f[b]); + } + if (YTILE >= 5) { + DOT2C(sum[m][4], bigA[m][k2].f[b], bigB4[k2].f[b]); + } + if (YTILE >= 6) { + DOT2C(sum[m][5], bigA[m][k2].f[b], bigB5[k2].f[b]); + } + if (YTILE >= 7) { + DOT2C(sum[m][6], bigA[m][k2].f[b], bigB6[k2].f[b]); + } + if (YTILE >= 8) { + DOT2C(sum[m][7], bigA[m][k2].f[b], bigB7[k2].f[b]); + } } } } @@ -744,7 +737,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + if (commitColumn[i]) + C[n + i + m * N] = __float2s(sum[m][i]); } } } @@ -764,9 +758,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template -__global__ void wvSplitK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, +template +__global__ void wvSplitK_hf_(const int K, const int N, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -774,20 +769,21 @@ __global__ void wvSplitK_hf_(const int K, const int N, const DTYPE* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets big A[] cases, where it is much larger than LDS capacity -template +template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, - const int CuCount) { - using half8 = + wvSplitK_hf_big_(const int K, const int N, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { - DTYPE h[A_CHUNK]; + scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - half8 h8; + scalar8 h8; }; 
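// Sketch of the bigType idea: a single 16-byte load brings in eight 16-bit
// elements, and the same bytes are viewed as four 32-bit lanes so each lane can
// feed a packed 2-element dot product. A 16-bit integer stands in for the
// fp16/bf16 scalar here; this only illustrates the layout, not the kernel's
// non-temporal loads.
#include <cstdint>
#include <cstdio>

constexpr int A_CHUNK = 8;

union BigChunk {
  uint16_t h[A_CHUNK];      // 8 scalars, as laid out in memory
  float    f[A_CHUNK / 2];  // same bytes, 4 x 32-bit lanes for dot2 ops
  double   d[A_CHUNK / 4];  // same bytes, 2 x 64-bit lanes
};

static_assert(sizeof(BigChunk) == 16, "one chunk equals one 128-bit load");

int main() {
  BigChunk c{};
  c.h[0] = 1;
  std::printf("%zu bytes per chunk\n", sizeof(c));
  return 0;
}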
//---------------------------------------------------- @@ -797,7 +793,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: When activation matrix is larger than 64 KB // then this is not goint to work! //---------------------------------------------------- - __shared__ half s[1024 * 32]; + __shared__ scalar_t s[1024 * 32]; //---------------------------------------------------- // Computation of columns that need to be committed to memory! @@ -842,15 +838,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - if (k_in >= min(K * M, 32 * 1024)) break; - //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); #endif @@ -957,18 +947,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + const scalar_t* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -1004,41 +994,32 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + DOT2C(sum[m][0], bigA[m][k2].f[b], bigB0[k2].f[b]); //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + if (YTILE >= 2) { + DOT2C(sum[m][1], bigA[m][k2].f[b], bigB1[k2].f[b]); + } + if (YTILE >= 3) { + DOT2C(sum[m][2], bigA[m][k2].f[b], bigB2[k2].f[b]); + } + if (YTILE >= 4) { + DOT2C(sum[m][3], bigA[m][k2].f[b], bigB3[k2].f[b]); + } + if (YTILE >= 5) { + DOT2C(sum[m][4], bigA[m][k2].f[b], bigB4[k2].f[b]); + } + if (YTILE >= 6) { + DOT2C(sum[m][5], bigA[m][k2].f[b], bigB5[k2].f[b]); + } + if (YTILE >= 7) { + DOT2C(sum[m][6], bigA[m][k2].f[b], bigB6[k2].f[b]); + } + if (YTILE >= 8) { + DOT2C(sum[m][7], bigA[m][k2].f[b], bigB7[k2].f[b]); + } } } } @@ -1081,7 +1062,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + if (commitColumn[i]) + C[n + i + m * N] = __float2s(sum[m][i]); } } } @@ -1101,9 +1083,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template -__global__ void wvSplitK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, +template +__global__ void wvSplitK_hf_big_(const int K, const int N, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -1138,13 +1121,19 @@ int mindiv(int N, int div1, int div2) { return rtn; } -void wvSplitK_(void* in_a, void* in_b, void* out_c, const int M_in, - const int K_in, const int N_in, cudaStream_t stream, - const int CuCount = 0) { +void wvSplitK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount) { + auto M_in = in_a.size(0); + auto K_in = in_a.size(1); + + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(in_a.dtype() == torch::kFloat16 || + in_a.dtype() == torch::kBFloat16); + dim3 grid(CuCount); - half* af4 = reinterpret_cast(in_a); - const half* bf4 = reinterpret_cast(in_b); - auto* c = reinterpret_cast(out_c); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ _N) 
\ @@ -1152,52 +1141,44 @@ void wvSplitK_(void* in_a, void* in_b, void* out_c, const int M_in, dim3 block(64, _WvPrGrp); \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSplitK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + wvSplitK_hf_sml_ \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSplitK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + wvSplitK_hf_ \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ - wvSplitK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + wvSplitK_hf_big_ \ <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ CuCount); \ } \ } - switch (N_in) { - case 1: - WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 - break; - case 2: - WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 - break; - case 3: - WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 - break; - case 4: - WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 - break; - default: - throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + - "," + std::to_string(K_in) + "," + - std::to_string(N_in)); - } - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) { - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); - } -} - -void wvSplitK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount) { - auto M = in_a.size(0); - auto K = in_a.size(1); - int N = N_in; - wvSplitK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, - at::cuda::getCurrentCUDAStream(), CuCount); + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { + using fptype = typename scalar::type; + fptype* af4 = reinterpret_cast(in_a.data_ptr()); + const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); + fptype* c = reinterpret_cast(out_c.data_ptr()); + switch (N_in) { + case 1: + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + break; + case 2: + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + break; + case 3: + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + break; + case 4: + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); } \ No newline at end of file diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 77a9544c7b3c..44ef84aca9a9 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -72,8 +72,8 @@ def apply_gemm_rocm(x: torch.Tensor, cu_count = current_platform.get_cu_count() use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM is True and \ - bias is None and \ - x.dtype is torch.float16 and k % 8 == 0) + x.dtype in [torch.float16, torch.bfloat16] \ + and k % 8 == 0 and bias is None ) if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) From 5c60d0b5aa85e39e0d146a78d5248317a31bb79b Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 31 Mar 2025 16:23:54 +0000 Subject: [PATCH 08/17] clean up Signed-off-by: charlifu --- csrc/rocm/skinny_gemms.cu | 138 +++++++++++---------------- vllm/model_executor/layers/linear.py | 9 +- vllm/model_executor/layers/utils.py | 14 ++- 3 files changed, 67 insertions(+), 94 deletions(-) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index f22b1097d6ba..6c091de2ab5d 100644 --- 
a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -344,14 +344,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) sum[m][i] = 0; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; + bigType bigB[YTILE][UNRL]; //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -379,17 +372,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + if (YTILE >= 2) bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -418,31 +411,30 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t m = 0; m < M; m++) { #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[m][0], bigA[m][k2].f[b], bigB0[k2].f[b]) - + DOT2C(sum[m][0], bigA[m][k2].f[b], bigB[0][k2].f[b]) //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- if (YTILE >= 2) { - DOT2C(sum[m][1], bigA[m][k2].f[b], bigB1[k2].f[b]); + DOT2C(sum[m][1], bigA[m][k2].f[b], bigB[1][k2].f[b]); } if (YTILE >= 3) { - DOT2C(sum[m][2], bigA[m][k2].f[b], bigB2[k2].f[b]); + DOT2C(sum[m][2], bigA[m][k2].f[b], bigB[2][k2].f[b]); } if (YTILE >= 4) { - DOT2C(sum[m][3], bigA[m][k2].f[b], bigB3[k2].f[b]); + DOT2C(sum[m][3], bigA[m][k2].f[b], bigB[3][k2].f[b]); } if (YTILE >= 5) { - DOT2C(sum[m][4], bigA[m][k2].f[b], bigB4[k2].f[b]); + DOT2C(sum[m][4], bigA[m][k2].f[b], bigB[4][k2].f[b]); } if (YTILE >= 6) { - DOT2C(sum[m][5], bigA[m][k2].f[b], bigB5[k2].f[b]); + DOT2C(sum[m][5], bigA[m][k2].f[b], bigB[5][k2].f[b]); } if (YTILE >= 7) { - DOT2C(sum[m][6], bigA[m][k2].f[b], bigB6[k2].f[b]); + DOT2C(sum[m][6], bigA[m][k2].f[b], bigB[6][k2].f[b]); } if (YTILE >= 8) { - DOT2C(sum[m][7], bigA[m][k2].f[b], bigB7[k2].f[b]); + DOT2C(sum[m][7], bigA[m][k2].f[b], bigB[7][k2].f[b]); } } } @@ -600,15 +592,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) sum[m][i] = 0; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - 
bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - bigType bigB8[UNRL]; + bigType bigB[YTILE][UNRL]; //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -635,17 +619,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + if (YTILE >= 2) bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -677,31 +661,30 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[m][0], bigA[m][k2].f[b], bigB0[k2].f[b]); - + DOT2C(sum[m][0], bigA[m][k2].f[b], bigB[0][k2].f[b]); //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- if (YTILE >= 2) { - DOT2C(sum[m][1], bigA[m][k2].f[b], bigB1[k2].f[b]); + DOT2C(sum[m][1], bigA[m][k2].f[b], bigB[1][k2].f[b]); } if (YTILE >= 3) { - DOT2C(sum[m][2], bigA[m][k2].f[b], bigB2[k2].f[b]); + DOT2C(sum[m][2], bigA[m][k2].f[b], bigB[2][k2].f[b]); } if (YTILE >= 4) { - DOT2C(sum[m][3], bigA[m][k2].f[b], bigB3[k2].f[b]); + DOT2C(sum[m][3], bigA[m][k2].f[b], bigB[3][k2].f[b]); } if (YTILE >= 5) { - DOT2C(sum[m][4], bigA[m][k2].f[b], bigB4[k2].f[b]); + DOT2C(sum[m][4], bigA[m][k2].f[b], bigB[4][k2].f[b]); } if (YTILE >= 6) { - DOT2C(sum[m][5], bigA[m][k2].f[b], bigB5[k2].f[b]); + DOT2C(sum[m][5], bigA[m][k2].f[b], bigB[5][k2].f[b]); } if (YTILE >= 7) { - DOT2C(sum[m][6], bigA[m][k2].f[b], bigB6[k2].f[b]); + DOT2C(sum[m][6], bigA[m][k2].f[b], bigB[6][k2].f[b]); } if (YTILE >= 8) { - DOT2C(sum[m][7], bigA[m][k2].f[b], bigB7[k2].f[b]); + DOT2C(sum[m][7], bigA[m][k2].f[b], bigB[7][k2].f[b]); } } } @@ -892,17 +875,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int m = 0; m < M; m++) sum[m][i] = 0; bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - bigType bigB8[UNRL]; - bigType bigB9[UNRL]; - bigType bigB10[UNRL]; + bigType bigB[YTILE][UNRL]; //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -948,17 +921,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); + bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + if (YTILE >= 2) bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -994,31 +967,30 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[m][0], bigA[m][k2].f[b], bigB0[k2].f[b]); - + DOT2C(sum[m][0], bigA[m][k2].f[b], bigB[0][k2].f[b]); //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- if (YTILE >= 2) { - DOT2C(sum[m][1], bigA[m][k2].f[b], bigB1[k2].f[b]); + DOT2C(sum[m][1], bigA[m][k2].f[b], bigB[1][k2].f[b]); } if (YTILE >= 3) { - DOT2C(sum[m][2], bigA[m][k2].f[b], bigB2[k2].f[b]); + DOT2C(sum[m][2], bigA[m][k2].f[b], bigB[2][k2].f[b]); } if (YTILE >= 4) { - DOT2C(sum[m][3], bigA[m][k2].f[b], bigB3[k2].f[b]); + DOT2C(sum[m][3], bigA[m][k2].f[b], bigB[3][k2].f[b]); } if (YTILE >= 5) { - DOT2C(sum[m][4], bigA[m][k2].f[b], bigB4[k2].f[b]); + DOT2C(sum[m][4], bigA[m][k2].f[b], bigB[4][k2].f[b]); } if (YTILE >= 6) { - DOT2C(sum[m][5], bigA[m][k2].f[b], bigB5[k2].f[b]); + DOT2C(sum[m][5], bigA[m][k2].f[b], bigB[5][k2].f[b]); } if (YTILE >= 7) { - DOT2C(sum[m][6], bigA[m][k2].f[b], bigB6[k2].f[b]); + DOT2C(sum[m][6], bigA[m][k2].f[b], bigB[6][k2].f[b]); } if (YTILE >= 8) { - DOT2C(sum[m][7], bigA[m][k2].f[b], bigB7[k2].f[b]); + DOT2C(sum[m][7], bigA[m][k2].f[b], bigB[7][k2].f[b]); } } } @@ -1164,16 +1136,16 @@ void wvSplitK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, fptype* c = reinterpret_cast(out_c.data_ptr()); switch (N_in) { case 1: - WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) break; case 2: - WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) break; case 3: - WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) break; case 4: - WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) break; default: throw std::runtime_error( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f5b9155c6153..6f8e6ec0b213 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -17,7 +16,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.utils import apply_gemm_rocm +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm # yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, BlockQuantScaleParameter, @@ -27,7 +26,6 @@ RowvLLMParameter) # yapf: enable from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform logger = init_logger(__name__) @@ -190,10 +188,7 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if current_platform.is_rocm(): - return apply_gemm_rocm(x, layer.weight, bias) - - return F.linear(x, layer.weight, bias) + return dispatch_unquantized_gemm()(x, layer.weight, bias) class LinearBase(torch.nn.Module): diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 44ef84aca9a9..0e8a5e184a8f 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Utility methods for model layers.""" -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch @@ 
-62,9 +62,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, return logits -def apply_gemm_rocm(x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None): +def rocm_unquantized_gemm(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): x_view = x.view(-1, x.size(-1)) m = weight.shape[0] k = weight.shape[1] @@ -92,3 +92,9 @@ def apply_gemm_rocm(x: torch.Tensor, ops.LLMM1(weight, x_view, out, 4) return out.view(*x.shape[:-1], weight.shape[0]) return torch.nn.functional.linear(x, weight, bias) + + +def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: + if current_platform.is_rocm(): + return rocm_unquantized_gemm + return torch.nn.functional.linear From c017ce13a0099465e09cbdcb097ac01d795e9270 Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 31 Mar 2025 21:25:46 +0000 Subject: [PATCH 09/17] add n == 3 case Signed-off-by: charlifu --- vllm/model_executor/layers/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 0e8a5e184a8f..fc8a1549abbb 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -73,11 +73,11 @@ def rocm_unquantized_gemm(x: torch.Tensor, use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM is True and \ x.dtype in [torch.float16, torch.bfloat16] \ - and k % 8 == 0 and bias is None ) + and k % 8 == 0 and bias is None) if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) - if m > 8 and n <= 2: + if m > 8 and n < 4: out = torch.empty(x_view.shape[0], weight.shape[0], dtype=x.dtype, From 76f81726fb5801afa5eeca71f4e46f3d1dc24203 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 1 Apr 2025 14:50:55 +0000 Subject: [PATCH 10/17] disable fp8 gemm padding for rocm Signed-off-by: charlifu --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index b8e6384d7359..de1001ed2d33 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -156,7 +156,8 @@ def __init__(self, if pad_output is None: config = get_current_vllm_config().compilation_config pad_output = config.level < CompilationLevel.PIECEWISE - self.output_padding = 17 if pad_output else None + self.output_padding = 17 if ( + pad_output and not current_platform.is_rocm()) else None def apply( self, From 91205a4fe5bae6e5b12b4984acdb7f99344c2039 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 8 Apr 2025 02:19:25 +0000 Subject: [PATCH 11/17] add wvsplitK fp8 and unit tests Signed-off-by: charlifu --- csrc/rocm/ops.h | 11 +- csrc/rocm/skinny_gemms.cu | 880 +++++++++++++----- csrc/rocm/torch_bindings.cpp | 14 +- tests/kernels/test_rocm_skinny_gemms.py | 81 ++ vllm/_custom_ops.py | 21 +- .../layers/quantization/utils/w8a8_utils.py | 255 +++-- vllm/model_executor/layers/utils.py | 17 +- vllm/platforms/interface.py | 7 + vllm/platforms/rocm.py | 7 + 9 files changed, 957 insertions(+), 336 deletions(-) create mode 100644 tests/kernels/test_rocm_skinny_gemms.py diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index b248fd70e9b0..c435946bfa14 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -2,11 +2,14 @@ #include -void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t 
rows_per_block); +torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, + const int64_t rows_per_block); -void wvSplitK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount); +torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, + const int64_t CuCount); + +void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 6c091de2ab5d..29dbbe8e35e8 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -10,6 +10,8 @@ #include #include "cuda_compat.h" +#include "dispatch_utils.h" +#include "quantization/fp8/common.cuh" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) #define __HIP__MI300_MI250__ @@ -43,8 +45,7 @@ __device__ __forceinline__ T __float2s(float v); template __device__ __forceinline__ T __float22s2_rn(float2 v); -// Vector (size 2) definition and cvt functions for fp16 - +// Definitions and cvt functions for fp16 template <> struct scalar { using type = half; @@ -56,13 +57,13 @@ struct scalar2 { }; template <> -__device__ __forceinline__ float2 __s22float2(__half2 v) { - return __half22float2(v); +__device__ __forceinline__ half __float2s(float v) { + return __float2half(v); } template <> -__device__ __forceinline__ half __float2s(float v) { - return __float2half(v); +__device__ __forceinline__ float2 __s22float2(__half2 v) { + return __half22float2(v); } template <> @@ -70,7 +71,7 @@ __device__ __forceinline__ __half2 __float22s2_rn(float2 v) { return __float22half2_rn(v); } -// Vector (size 2) definition and cvt functions for bf16 +// Definitions and cvt functions for bf16 template <> struct scalar { using type = __hip_bfloat16; @@ -82,13 +83,13 @@ struct scalar2 { }; template <> -__device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) { - return __bfloat1622float2(v); +__device__ __forceinline__ __hip_bfloat16 __float2s(float v) { + return __float2bfloat16(v); } template <> -__device__ __forceinline__ __hip_bfloat16 __float2s(float v) { - return __float2bfloat16(v); +__device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) { + return __bfloat1622float2(v); } template <> @@ -125,12 +126,11 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid / 16; - const int qthreadid = threadid % 16; + const int qwarpid = threadid / num_warps; + const int qthreadid = threadid % num_warps; float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; - float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; - float acc[NUM_A_ROWS_PER_BLOCK] = {0.0}; + float acc[NUM_A_ROWS_PER_BLOCK]; scalar2_t acch2; scalar2_t oval; @@ -171,9 +171,7 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, S = __s22float2(acch2); // See comment above concerning the if guard. - if (threadid * 8 < K) { - acc[i] = S.x + S.y; // accumulation on float - } + acc[i] = (threadid * 8 < K ? S.x + S.y : 0.f); } // all reduce across warp. @@ -195,21 +193,20 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, if (qwarpid < NUM_A_ROWS_PER_BLOCK) { acc[qwarpid] = qthreadid < num_warps ? 
red_smem[qwarpid][qthreadid] : 0.f; -#pragma unroll - for (int mask = 16 / 2; mask >= 1; mask /= 2) { + for (int mask = num_warps / 2; mask >= 1; mask /= 2) { acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); } - float oval2 = __shfl_xor(acc[qwarpid], 16); + float oval2 = __shfl_xor(acc[qwarpid], num_warps); - if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) { + if (lane % (num_warps * 2) == 0) { oval = __float22s2_rn(make_float2(acc[qwarpid], oval2)); c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; } } } -void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t rows_per_block) { +torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, + const int64_t rows_per_block) { auto M = in_a.size(0); auto K = in_a.size(1); auto N = in_b.size(0); @@ -219,6 +216,9 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, TORCH_CHECK(in_b.dtype() == torch::kFloat16 || in_b.dtype() == torch::kBFloat16); + auto out_c = torch::empty( + {N, M}, torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle // operations. const int NUM_THREADS = @@ -254,12 +254,14 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, <<>>(a_ptr, b_ptr, c_ptr, K); } }); + + return out_c; } #define DOT2C(V0, V2, V3) \ - if (std::is_same_v) { \ + if constexpr (std::is_same_v) { \ asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \ - } else if (std::is_same_v) { \ + } else if constexpr (std::is_same_v) { \ float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ V0 += (s.x + s.y); \ @@ -268,9 +270,9 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity template + int UNRL, int N> __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_sml_(const int K, const int N, const scalar_t* B, + wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { using scalar8 = @@ -301,11 +303,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * M, 32 * 1024)) break; + if (k_in >= min(K * N, 32 * 1024)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -313,9 +315,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.y >= _WvPrGrp) return; - uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; - float sum[M][YTILE]; + float sum[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. 
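// Host-side reference for the index convention the reworked kernels use from
// here on: B holds M rows of K weights, A holds N rows of K activations, and
// the result is stored as C[m + n * M]. A float oracle like this is the kind
// of reference the new unit tests would compare the kernel against; it is a
// sketch, not the test code added by this patch.
#include <cstdio>
#include <vector>

void ref_skinny_gemm(int K, int M, int N, const std::vector<float>& B,
                     const std::vector<float>& A, std::vector<float>& C) {
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[n * K + k] * B[m * K + k];
      C[m + n * M] = acc;  // same layout as C[m + i + n * M] in the kernel
    }
}

int main() {
  const int K = 4, M = 3, N = 2;
  std::vector<float> B(M * K, 1.f), A(N * K, 2.f), C(M * N, 0.f);
  ref_skinny_gemm(K, M, N, B, A, C);
  std::printf("C[0] = %f\n", C[0]);  // each entry is K * 1 * 2 = 8
  return 0;
}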
@@ -332,7 +334,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - After completing first set of columns, WGs start // working on the next set of available columns //---------------------------------------------------- - while (n < N) { + while (m < M) { //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation // split across 64 lanes. @@ -341,9 +343,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; + for (int n = 0; n < N; n++) sum[n][i] = 0; - bigType bigA[M][UNRL]; + bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! @@ -371,18 +373,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const scalar_t* B_ = &B[(n + 0) * K + k_]; + const scalar_t* B_ = &B[(m + 0) * K + k_]; bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -394,8 +403,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int m = 0; m < M; m++) { - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + for (int n = 0; n < N; n++) { + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); } } @@ -408,33 +417,33 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
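        // DOT2C consumes one packed pair per call: for fp16 it maps
        // to the v_dot2c_f32_f16 instruction (a 2-way half dot product
        // accumulated into an fp32 VGPR); for bf16 it unpacks both
        // operands to float2 and adds the two products, as defined by
        // the DOT2C macro above.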
#pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[m][0], bigA[m][k2].f[b], bigB[0][k2].f[b]) + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]) //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - if (YTILE >= 2) { - DOT2C(sum[m][1], bigA[m][k2].f[b], bigB[1][k2].f[b]); + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); } - if (YTILE >= 3) { - DOT2C(sum[m][2], bigA[m][k2].f[b], bigB[2][k2].f[b]); + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); } - if (YTILE >= 4) { - DOT2C(sum[m][3], bigA[m][k2].f[b], bigB[3][k2].f[b]); + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); } - if (YTILE >= 5) { - DOT2C(sum[m][4], bigA[m][k2].f[b], bigB[4][k2].f[b]); + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); } - if (YTILE >= 6) { - DOT2C(sum[m][5], bigA[m][k2].f[b], bigB[5][k2].f[b]); + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); } - if (YTILE >= 7) { - DOT2C(sum[m][6], bigA[m][k2].f[b], bigB[6][k2].f[b]); + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); } - if (YTILE >= 8) { - DOT2C(sum[m][7], bigA[m][k2].f[b], bigB[7][k2].f[b]); + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); } } } @@ -444,44 +453,44 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); } } if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2s(sum[m][i]); + // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2s(sum[n][i]); } } } - n += CuCount * _WvPrGrp * YTILE; + m += CuCount * _WvPrGrp * 
YTILE; } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_sml_(const int K, const int N, const scalar_t* B, + int UNRL, int N> +__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE @@ -491,9 +500,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int N, const scalar_t* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] marginally exceeds LDS capacity template + int UNRL, int N> __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_(const int K, const int N, const scalar_t* B, + wvSplitK_hf_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { using scalar8 = @@ -529,16 +538,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // WG ID and Thread ID to find the index. //---------------------------------------------------- // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); - uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { commitColumn[i] = 0; } - n = startColumn; + m = startColumn; } //---------------------------------------------------- @@ -550,11 +559,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * M, 32 * 1024)) break; + if (k_in >= min(K * N, 32 * 1024)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -563,7 +572,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.y >= _WvPrGrp) return; - float sum[M][YTILE]; + float sum[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -580,7 +589,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - After completing first set of columns, WGs start // working on the next set of available columns //---------------------------------------------------- - while (n < N) { + while (m < M) { //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation // split across 64 lanes. @@ -589,9 +598,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; + for (int n = 0; n < N; n++) sum[n][i] = 0; - bigType bigA[M][UNRL]; + bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! 
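  //----------------------------------------------------
  // All three fp16/bf16 variants stage the activation A through the
  // LDS buffer s[] (bounded to 32 * 1024 elements): the _sml_ kernel
  // assumes K * N fits entirely in LDS, this kernel falls back to
  // global-memory reads for anything beyond that bound, and the
  // _big_ kernel further below pages A through LDS in kFit-sized
  // chunks (the PCML path).
  //----------------------------------------------------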
@@ -618,18 +627,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const scalar_t* B_ = &B[(n + 0) * K + k_]; + const scalar_t* B_ = &B[(m + 0) * K + k_]; bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -641,17 +657,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int m = 0; m < M; m++) { - if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + for (int n = 0; n < N; n++) { + if (k_ + K * n < 32 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); } } // Do the matrix multiplication in interleaved manner #pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; @@ -661,30 +677,30 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[m][0], bigA[m][k2].f[b], bigB[0][k2].f[b]); + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - if (YTILE >= 2) { - DOT2C(sum[m][1], bigA[m][k2].f[b], bigB[1][k2].f[b]); + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); } - if (YTILE >= 3) { - DOT2C(sum[m][2], bigA[m][k2].f[b], bigB[2][k2].f[b]); + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); } - if (YTILE >= 4) { - DOT2C(sum[m][3], bigA[m][k2].f[b], bigB[3][k2].f[b]); + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); } - if (YTILE >= 5) { - DOT2C(sum[m][4], bigA[m][k2].f[b], bigB[4][k2].f[b]); + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); } - if (YTILE >= 6) { - DOT2C(sum[m][5], bigA[m][k2].f[b], bigB[5][k2].f[b]); + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); } - if (YTILE >= 7) { - DOT2C(sum[m][6], bigA[m][k2].f[b], bigB[6][k2].f[b]); + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); } - if (YTILE >= 8) { - DOT2C(sum[m][7], bigA[m][k2].f[b], bigB[7][k2].f[b]); + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); } } } @@ -694,56 +710,56 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); } } if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { if (commitColumn[i]) - C[n + i + m * N] = __float2s(sum[m][i]); + C[m + i + n * M] = __float2s(sum[n][i]); } } } - n += CuCount * _WvPrGrp * YTILE; + m += CuCount * _WvPrGrp * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! 
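    // When the next tile would run past the last weight row, the wave
    // is shifted back to start at M - YTILE and commitColumn[] zeroes
    // the overlapping leading entries, so the guarded store above only
    // commits the non-overlapping part of the tile to C.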
- if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { commitColumn[i] = 0; } - n = startColumn; + m = startColumn; } } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_(const int K, const int N, const scalar_t* B, + int UNRL, int N> +__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE @@ -753,9 +769,9 @@ __global__ void wvSplitK_hf_(const int K, const int N, const scalar_t* B, #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets big A[] cases, where it is much larger than LDS capacity template + int UNRL, int N> __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_big_(const int K, const int N, const scalar_t* B, + wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { using scalar8 = @@ -794,16 +810,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Algorithm does 64 lane k-splitting / wave and uses // WG ID and Thread ID to find the index. //---------------------------------------------------- - uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { commitColumn[i] = 0; } - n = startColumn; + m = startColumn; } //---------------------------------------------------- @@ -817,11 +833,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- #define PCML #ifndef PCML - for (uint32_t k = 0; k < min(K * M, 32 * 1024); + for (uint32_t k = 0; k < min(K * N, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * M, 32 * 1024)) break; + if (k_in >= min(K * N, 32 * 1024)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -831,7 +847,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #define TUC (THRDS * UNRL * A_CHUNK) uint32_t kBase = 0; // find biggest k size that fits in LDS - uint32_t kFit = (32 * 1024) / M; + uint32_t kFit = (32 * 1024) / N; // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple // of TUC kFit = (kFit % TUC == 0) @@ -840,7 +856,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // if (kFit == 0) kFit = TUC; kFit = min(kFit, K); - float sum[M][YTILE]; + float sum[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -859,10 +875,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- #ifdef PCML int YW = (YTILE * _WvPrGrp); - uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); - while (n < Nrndp) { + uint32_t Mrndp = (M % YW == 0) ? 
M : (M - M % YW + YW); + while (m < Mrndp) { #else - while (n < N) { + while (m < M) { #endif //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation @@ -872,9 +888,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; + for (int n = 0; n < N; n++) sum[n][i] = 0; - bigType bigA[M][UNRL]; + bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! @@ -902,15 +918,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); if (kBase + kOff >= K) break; if (kOff >= kFit) break; - for (uint32_t m = 0; m < M; m++) { - uint32_t k_in = kBase + m * K + kOff; - uint32_t k_ot = m * kFit + kOff; + for (uint32_t n = 0; n < N; n++) { + uint32_t k_in = kBase + n * K + kOff; + uint32_t k_ot = n * kFit + kOff; *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); } } __syncthreads(); } - if (n >= N) continue; + if (m >= M) continue; #endif // Fetch the weight matrix from memory! @@ -920,18 +936,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - const scalar_t* B_ = &B[(n + 0) * K + k_]; + const scalar_t* B_ = &B[(m + 0) * K + k_]; bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - if (YTILE >= 2) bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + if constexpr (YTILE >= 2) + bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); + if constexpr (YTILE >= 3) + bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); + if constexpr (YTILE >= 4) + bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); + if constexpr (YTILE >= 5) + bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); + if constexpr (YTILE >= 6) + bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); + if constexpr (YTILE >= 7) + bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); + if constexpr (YTILE >= 8) + bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -943,14 +966,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { #ifdef PCML - bigA[m][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * m]))); + bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n]))); #else - if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + if (k_ + K * n < 32 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); #endif } } @@ -962,35 +985,35 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ 
= k + threadIdx.x * A_CHUNK; if (k_ >= K) break; #pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t n = 0; n < N; n++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[m][0], bigA[m][k2].f[b], bigB[0][k2].f[b]); + DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - if (YTILE >= 2) { - DOT2C(sum[m][1], bigA[m][k2].f[b], bigB[1][k2].f[b]); + if constexpr (YTILE >= 2) { + DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); } - if (YTILE >= 3) { - DOT2C(sum[m][2], bigA[m][k2].f[b], bigB[2][k2].f[b]); + if constexpr (YTILE >= 3) { + DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); } - if (YTILE >= 4) { - DOT2C(sum[m][3], bigA[m][k2].f[b], bigB[3][k2].f[b]); + if constexpr (YTILE >= 4) { + DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); } - if (YTILE >= 5) { - DOT2C(sum[m][4], bigA[m][k2].f[b], bigB[4][k2].f[b]); + if constexpr (YTILE >= 5) { + DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); } - if (YTILE >= 6) { - DOT2C(sum[m][5], bigA[m][k2].f[b], bigB[5][k2].f[b]); + if constexpr (YTILE >= 6) { + DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); } - if (YTILE >= 7) { - DOT2C(sum[m][6], bigA[m][k2].f[b], bigB[6][k2].f[b]); + if constexpr (YTILE >= 7) { + DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); } - if (YTILE >= 8) { - DOT2C(sum[m][7], bigA[m][k2].f[b], bigB[7][k2].f[b]); + if constexpr (YTILE >= 8) { + DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); } } } @@ -998,8 +1021,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #ifdef PCML - if (n >= N) { - n += CuCount * _WvPrGrp * YTILE; + if (m >= M) { + m += CuCount * _WvPrGrp * YTILE; kBase = 0; continue; } @@ -1008,56 +1031,56 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); } } if (threadIdx.x == 63) { - for 
(int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { if (commitColumn[i]) - C[n + i + m * N] = __float2s(sum[m][i]); + C[m + i + n * M] = __float2s(sum[n][i]); } } } - n += CuCount * _WvPrGrp * YTILE; + m += CuCount * _WvPrGrp * YTILE; kBase = 0; // Check whether there will be fragmenation! // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { commitColumn[i] = 0; } - n = startColumn; + m = startColumn; } } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_big_(const int K, const int N, const scalar_t* B, + int UNRL, int N> +__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE @@ -1093,15 +1116,21 @@ int mindiv(int N, int div1, int div2) { return rtn; } -void wvSplitK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount) { +torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, + const int64_t CuCount) { auto M_in = in_a.size(0); auto K_in = in_a.size(1); + auto N_in = in_b.size(0); TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); TORCH_CHECK(in_a.dtype() == torch::kFloat16 || in_a.dtype() == torch::kBFloat16); + auto out_c = torch::empty( + {N_in, M_in}, + torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + dim3 grid(CuCount); const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); @@ -1153,4 +1182,419 @@ void wvSplitK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, std::to_string(K_in) + "," + std::to_string(N_in)); } }); + return out_c; +} + +#if defined(__HIP__MI300__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B, + const fp8_t* __restrict__ A, scalar_t* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK]; + char2 c2[A_CHUNK / 2]; + scalar_t h[A_CHUNK / 2]; + float f[A_CHUNK / 4]; + int i[A_CHUNK / 4]; + long l[A_CHUNK / 8]; + intx4 l2[A_CHUNK / 16]; + scalar8 h8; + }; + + __shared__ fp8_t s[1024 * 64]; + + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[N][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (m < M) { + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = {0.f}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + #pragma unroll + for (uint32_t n = 0; n 
< N; ++n) bigA[n][k2].h8 = {0.f}; + #pragma unroll + for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f}; + } + + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const fp8_t* B_ = &B[(m + 0) * Kp + k_]; + #pragma unroll + for (uint32_t y = 0; y < YTILE; ++y) { + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); + } + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + for (int n = 0; n < N; n++) { + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + if (k >= K) break; + + for (uint32_t n = 0; n < N; n++) { + for (int i = 0; i < A_CHUNK; i += 8) { + for (int y = 0; y < YTILE; ++y) { + sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, + 0); + } + } + } + } + } + + // Final reduction + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[n][y][0]; + float accm16 = sum[n][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[n][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300__) TODO: Add NAVI support +template +__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, 
const int M, + const fp8_t* B, const fp8_t* __restrict__ A, + scalar_t* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + +#if defined(__HIP__MI300__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B, + const fp8_t* __restrict__ A, scalar_t* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK]; + char2 c2[A_CHUNK / 2]; + scalar_t h[A_CHUNK / 2]; + float f[A_CHUNK / 4]; + int i[A_CHUNK / 4]; + long l[A_CHUNK / 8]; + intx4 l2[A_CHUNK / 16]; + scalar8 h8; + }; + + __shared__ fp8_t s[1024 * 64]; + + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min(K * N, 64 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[N][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (m < M) { + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = {0}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const fp8_t* B_ = &B[(m + 0) * Kp + k_]; + for (int y = 0; y < YTILE; ++y) { + if (y + m >= M) break; // To avoid mem access fault. 
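            // Unlike the _sml_ variant (which the host only selects when
            // M is a multiple of YTILE), M need not divide evenly here,
            // so each of the YTILE row loads is bounds-checked before the
            // non-temporal fetch below.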
+ bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); + } + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + for (int n = 0; n < N; n++) { + if (k_ + K * n < 64 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + for (uint32_t n = 0; n < N; n++) { + for (int i = 0; i < A_CHUNK; i += 8) { + for (int y = 0; y < YTILE; ++y) { + sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, + 0); + } + } + } + } + } + + // Final reduction + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[n][y][0]; + float accm16 = sum[n][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[n][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + if (y + m >= M) break; // To avoid mem access fault. 
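          // Lane 0 holds the fully reduced fp32 accumulator after the
          // cross-lane reduction above; it is dequantized with the
          // per-tensor scales sA * sB and converted to the fp16/bf16
          // output type on the store below.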
+ C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300__) TODO: Add NAVI support +template +__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, + const fp8_t* B, const fp8_t* __restrict__ A, + scalar_t* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + +void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, + const int64_t CuCount) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; + auto M_in = in_a.size(0); + auto K_in = in_a.size(1); + auto N_in = in_b.size(0); + auto Kp_in = in_a.stride(0); + TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0"); + TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type); + TORCH_CHECK(out_c.dtype() == torch::kFloat16 || + out_c.dtype() == torch::kBFloat16); + + dim3 grid(CuCount); + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 64 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSplitKQ_hf_sml_ \ + <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSplitKQ_hf_ \ + <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } \ + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] { + using fptype = typename scalar::type; + auto c_ptr = reinterpret_cast(out_c.data_ptr()); + auto s_a = scale_a.data_ptr(); + auto s_b = scale_b.data_ptr(); + VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] { + auto a_ptr = in_a.data_ptr(); + auto b_ptr = in_b.data_ptr(); + switch (N_in) { + case 1: + WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); + }); } \ No newline at end of file diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 0565c96801ce..5cbfda7c52f6 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -16,16 +16,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // Custom gemm op for matrix-vector multiplication rocm_ops.def( - "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> " - "()"); + "LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> " + "Tensor"); rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1); // Custom gemm op for skinny matrix-matrix multiplication rocm_ops.def( - "wvSplitK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," - " int CuCount) -> ()"); + "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> " + "Tensor"); rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); + // wvSplitK for fp8 + rocm_ops.def( + "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! 
out_c, Tensor scale_a, " + " Tensor scale_b, int CuCount) -> ()"); + rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); + // Custom attention op // Compute the attention between an input query and the cached // keys/values using PagedAttention. diff --git a/tests/kernels/test_rocm_skinny_gemms.py b/tests/kernels/test_rocm_skinny_gemms.py new file mode 100644 index 000000000000..8ce2e553d231 --- /dev/null +++ b/tests/kernels/test_rocm_skinny_gemms.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +import vllm._custom_ops as ops +from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant +from vllm.platforms import current_platform + +DTYPES = [torch.bfloat16, torch.float16] +M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] +K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0 +N = [1, 2, 3, 4] +SEEDS = [0] + + +@pytest.mark.parametrize("n", [1]) # only test for batch size 1 +@pytest.mark.parametrize("k", K) +@pytest.mark.parametrize("m", M) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +@torch.inference_mode() +def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): + torch.manual_seed(seed) + A = torch.rand(n, k, dtype=dtype, device="cuda") + B = torch.rand(m, k, dtype=dtype, device="cuda") + + ref_out = torch.matmul(A, B.t()) + out = ops.LLMM1(B, A, rows_per_block) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n", N) # only test for batch size <= 4 +@pytest.mark.parametrize("k", K + [9216, 10240, 16384]) +@pytest.mark.parametrize("m", [8] + M) # m >= 8 +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + A = torch.rand(n, k, dtype=dtype, device="cuda") + B = torch.rand(m, k, dtype=dtype, device="cuda") + + ref_out = torch.matmul(A, B.t()) + out = ops.wvSplitK(B, A, cu_count) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n", N) # only test for batch size <= 4 +@pytest.mark.parametrize("k", + K[1:] + [8192 * 2, 8192 * 3, 8192 * 4]) # k % 16 == 0 +@pytest.mark.parametrize("m", M) # m >= 16 +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + + A = torch.rand(n, k, device="cuda") + B = torch.rand(m, k, device="cuda") + + A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) + B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) + + ref_out = torch._scaled_mm(A, + B.t(), + out_dtype=dtype, + scale_a=scale_a, + scale_b=scale_b) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, + current_platform.get_cu_count()) + + assert torch.allclose(out, ref_out, rtol=0.01) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index af14c7a27c07..6eb6b7069886 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1181,14 +1181,23 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, # ROCm skinny gemms -def LLMM1(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, - rows_per_block: int) -> None: - 
torch.ops._rocm_C.LLMM1(a, b, out, rows_per_block) +def LLMM1(a: torch.Tensor, b: torch.Tensor, + rows_per_block: int) -> torch.Tensor: + return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) -def wvSplitK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, - cu_count: int) -> None: - torch.ops._rocm_C.wvSplitK(a, b, out, N, cu_count) +def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, cu_count) + + +def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, + cu_count: int) -> torch.Tensor: + out = torch.empty((b.shape[0], a.shape[0]), + dtype=out_dtype, + device=b.device) + torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) + return out # moe diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index de1001ed2d33..03bfb9d874b4 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch @@ -131,6 +131,154 @@ def maybe_create_device_identity(): TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: List, **kwargs) -> torch.Tensor: + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + return output.view(*output_shape) + + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> torch.Tensor: + if current_platform.is_rocm_skinny_gemm_enabled() and qinput.shape[0] == 1: + output = ops.wvSplitKQ(weight.t(), qinput, scale_a, scale_b, + current_platform.get_cu_count()) + else: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> torch.Tensor: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List) -> torch.Tensor: + # For now validated on ROCm platform + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. 
+ # For CUDA platform please validate if the + # torch._scaled_mm support rowwise scaled GEMM + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b.t(), + bias=bias) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + + +def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: List, + **kwargs) -> torch.Tensor: + # Use unfused DQ due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # GEMM + # This computes C = (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * scale_b.t() + if bias is not None: + output = output + bias + return output.to(out_dtype).view(*output_shape) + + +def dispatch_w8a8_scaled_mm( + cutlass_fp8_supported: bool, per_tensor_weights: bool, + per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool] +) -> Callable[..., torch.Tensor]: + + if cutlass_fp8_supported: + return cutlass_w8a8_scaled_mm + if per_tensor_weights and per_tensor_activations: + if current_platform.is_rocm(): + return rocm_per_tensor_w8a8_scaled_mm + return torch_per_tensor_w8a8_scaled_mm + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + if (use_per_token_if_dynamic and not per_tensor_weights + and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + return torch_per_token_w8a8_scaled_mm + return torch_channelwise_w8a8_scaled_mm + + # TODO(luka): follow similar pattern for marlin and block-fp8-linear # https://github.com/vllm-project/vllm/issues/14397 class Fp8LinearOp: @@ -196,18 +344,6 @@ def apply( input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=use_per_token_if_dynamic) - - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - return output.view(*output_shape) - - # torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token else: if input.dtype != current_platform.fp8_dtype(): # Maybe apply padding to output, see comment in __init__ @@ -219,84 +355,21 @@ def apply( else: qinput, x_scale = input_2d, input_scale - per_tensor_weights = (weight_scale.numel() == 1) - per_tensor_activations = (x_scale.numel() == 1) - - if per_tensor_weights and 
per_tensor_activations: - # Fused GEMM_DQ - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - - return torch.narrow(output, 0, 0, - input_2d.shape[0]).view(*output_shape) - - elif (use_per_token_if_dynamic and not per_tensor_weights - and not per_tensor_activations - and USE_ROWWISE_TORCH_SCALED_MM): - # For now validated on ROCm platform - # fp8 rowwise scaling in torch._scaled_mm is introduced in - # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt - # and ROCm 6.3, which only exists in torch 2.7 and above. - # For CUDA platform please validate if the - # torch._scaled_mm support rowwise scaled GEMM - # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale.t(), - bias=bias) - - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - output = output.view(*output_shape) - return output - - else: - # Fallback for channelwise case, where we use unfused DQ - # due to limitations with scaled_mm - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # GEMM - # This computes C = (X * W). - # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm(qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * weight_scale.t() - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) + + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( + self.cutlass_fp8_supported, per_tensor_weights, + per_tensor_activations, use_per_token_if_dynamic) + + return w8a8_scaled_mm_func(qinput=qinput, + weight=weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + input_2d=input_2d, + output_shape=output_shape) def normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index fc8a1549abbb..1611254512c2 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -5,7 +5,6 @@ import torch from vllm import _custom_ops as ops -from vllm import envs from vllm.platforms import current_platform @@ -71,25 +70,17 @@ def rocm_unquantized_gemm(x: torch.Tensor, n = x_view.shape[0] cu_count = current_platform.get_cu_count() - use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM is True and \ + use_skinny = 
(current_platform.is_rocm_skinny_gemm_enabled() and \ x.dtype in [torch.float16, torch.bfloat16] \ and k % 8 == 0 and bias is None) if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) - if m > 8 and n < 4: - out = torch.empty(x_view.shape[0], - weight.shape[0], - dtype=x.dtype, - device=x.device) - ops.wvSplitK(weight, x_view, out, n, cu_count) + if m > 8 and n <= 4: + out = ops.wvSplitK(weight, x_view, cu_count) return out.view(*x.shape[:-1], weight.shape[0]) elif m % 4 == 0 and n == 1 and k <= 8192: - out = torch.empty(x_view.shape[0], - weight.shape[0], - dtype=x.dtype, - device=x.device) - ops.LLMM1(weight, x_view, out, 4) + out = ops.LLMM1(weight, x_view, out, 4) return out.view(*x.shape[:-1], weight.shape[0]) return torch.nn.functional.linear(x, weight, bias) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 45aab3bf9297..01a525aa344b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -379,6 +379,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool: """ return False + @classmethod + def is_rocm_skinny_gemm_enabled(cls) -> bool: + """ + Return if skinny gemms enabled on rocm + """ + raise NotImplementedError + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 9f4f98fceb6b..89a8dc84b4c8 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -288,3 +288,10 @@ def supports_v1(cls, model_config: ModelConfig) -> bool: def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( device_id).multi_processor_count + + @classmethod + def is_rocm_skinny_gemm_enabled(cls) -> bool: + """ + Return if skinny gemms enabled on rocm + """ + return envs.VLLM_ROCM_USE_SKINNY_GEMM From 63efd7f4eae84738a2bb11058088fc54bc3658b8 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 8 Apr 2025 15:24:07 +0000 Subject: [PATCH 12/17] fix fp8 skinny gemm call Signed-off-by: charlifu --- tests/kernels/test_rocm_skinny_gemms.py | 5 ++--- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_rocm_skinny_gemms.py b/tests/kernels/test_rocm_skinny_gemms.py index 8ce2e553d231..622079c39445 100644 --- a/tests/kernels/test_rocm_skinny_gemms.py +++ b/tests/kernels/test_rocm_skinny_gemms.py @@ -54,9 +54,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("n", N) # only test for batch size <= 4 -@pytest.mark.parametrize("k", - K[1:] + [8192 * 2, 8192 * 3, 8192 * 4]) # k % 16 == 0 -@pytest.mark.parametrize("m", M) # m >= 16 +@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768]) # k % 16 == 0 +@pytest.mark.parametrize("m", M + [28672]) # m >= 16 @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not current_platform.is_rocm(), diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 03bfb9d874b4..1710bcb50189 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -153,8 +153,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, output_shape: List) -> torch.Tensor: - if current_platform.is_rocm_skinny_gemm_enabled() and qinput.shape[0] == 1: - output = ops.wvSplitKQ(weight.t(), qinput, 
scale_a, scale_b, + if current_platform.is_rocm_skinny_gemm_enabled( + ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: + output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, current_platform.get_cu_count()) else: output = torch._scaled_mm(qinput, From 5a095062459533c4dda2a2136693f309d5379924 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 8 Apr 2025 19:22:14 +0000 Subject: [PATCH 13/17] fix engine test Signed-off-by: charlifu --- vllm/model_executor/layers/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 1611254512c2..359d2d97ad38 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -64,19 +64,20 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, def rocm_unquantized_gemm(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - x_view = x.view(-1, x.size(-1)) - m = weight.shape[0] k = weight.shape[1] - n = x_view.shape[0] - cu_count = current_platform.get_cu_count() - use_skinny = (current_platform.is_rocm_skinny_gemm_enabled() and \ x.dtype in [torch.float16, torch.bfloat16] \ and k % 8 == 0 and bias is None) if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) - if m > 8 and n <= 4: + + x_view = x.view(-1, x.size(-1)) + n = x_view.shape[0] + m = weight.shape[0] + cu_count = current_platform.get_cu_count() + + if m > 8 and n < 4: out = ops.wvSplitK(weight, x_view, cu_count) return out.view(*x.shape[:-1], weight.shape[0]) elif m % 4 == 0 and n == 1 and k <= 8192: From 660fefb86ee80dce123872f2802b825cf05508cc Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 9 Apr 2025 14:59:05 +0000 Subject: [PATCH 14/17] remove env check out of platform class Signed-off-by: charlifu --- .../layers/quantization/utils/w8a8_utils.py | 5 +++-- vllm/model_executor/layers/utils.py | 3 ++- vllm/platforms/interface.py | 7 ------- vllm/platforms/rocm.py | 9 +-------- 4 files changed, 6 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 1710bcb50189..d68350c4b304 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.config import CompilationLevel, get_current_vllm_config from vllm.platforms import current_platform @@ -153,8 +154,8 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, output_shape: List) -> torch.Tensor: - if current_platform.is_rocm_skinny_gemm_enabled( - ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: + if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[ + 0] == 1 and qinput.shape[1] % 16 == 0: output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, current_platform.get_cu_count()) else: diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 359d2d97ad38..6f71be6fba9d 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.platforms import current_platform @@ -65,7 +66,7 @@ def rocm_unquantized_gemm(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] 
= None): k = weight.shape[1] - use_skinny = (current_platform.is_rocm_skinny_gemm_enabled() and \ + use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \ x.dtype in [torch.float16, torch.bfloat16] \ and k % 8 == 0 and bias is None) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d79fdb35116e..195269886f6d 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -393,13 +393,6 @@ def use_custom_allreduce(cls) -> bool: """ return False - @classmethod - def is_rocm_skinny_gemm_enabled(cls) -> bool: - """ - Return if skinny gemms enabled on rocm - """ - raise NotImplementedError - @classmethod def get_cu_count(cls, device_id: int = 0) -> int: """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 5d61d01e0219..85a1bc4ac477 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -318,11 +318,4 @@ def use_custom_allreduce(cls) -> bool: @cache def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( - device_id).multi_processor_count - - @classmethod - def is_rocm_skinny_gemm_enabled(cls) -> bool: - """ - Return if skinny gemms enabled on rocm - """ - return envs.VLLM_ROCM_USE_SKINNY_GEMM + device_id).multi_processor_count \ No newline at end of file From 9674634ebadd9aab1b553ca76dea40b245c21eff Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 9 Apr 2025 16:55:22 +0000 Subject: [PATCH 15/17] add torch version check for row_wise scaled_mm Signed-off-by: charlifu --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index d68350c4b304..7ed23b121c50 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -18,6 +18,7 @@ # The condition is determined once as the operations # are time consuming. 
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() + and torch.__version__[0:3] >= "2.7" and current_platform.has_device_capability(94)) From c8c248b67d37ef395d0107b5a8852c8039ef5bdc Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 14 Apr 2025 21:56:29 +0000 Subject: [PATCH 16/17] remove cache decorator to fix V1 error Signed-off-by: charlifu --- vllm/platforms/rocm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index a6daef6dbb59..77fcd2d7c0f9 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -311,7 +311,6 @@ def use_custom_allreduce(cls) -> bool: return any(gfx in gcn_arch for gfx in supported_archs) @classmethod - @cache def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( device_id).multi_processor_count \ No newline at end of file From 6535863976fe0e8d50ce24e1ac74f2230039b947 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Mon, 21 Apr 2025 10:31:10 -0500 Subject: [PATCH 17/17] Update vllm/model_executor/layers/quantization/utils/w8a8_utils.py Co-authored-by: Tyler Michael Smith Signed-off-by: charlifu --- .../layers/quantization/utils/w8a8_utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 7ed23b121c50..d279ffe45d6d 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -198,12 +198,16 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, input_2d: torch.Tensor, output_shape: List) -> torch.Tensor: - # For now validated on ROCm platform - # fp8 rowwise scaling in torch._scaled_mm is introduced in - # https://github.com/pytorch/pytorch/pull/144432 using - # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. - # For CUDA platform please validate if the - # torch._scaled_mm support rowwise scaled GEMM + # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM + # when using it. + # For now it has only been validated on ROCm platform. + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. + # + # For CUDA platform please validate if the torch._scaled_mm supports + # rowwise scaled GEMM before using it + # Fused GEMM_DQ Rowwise GEMM output = torch._scaled_mm(qinput, weight,