diff --git a/.gitignore b/.gitignore
index da5a337c4683..b531b7918c30 100644
--- a/.gitignore
+++ b/.gitignore
@@ -173,3 +173,7 @@ cython_debug/
 
 # Sphinx documentation
 _build/
+
+# vim swap files
+*.swo
+*.swp
diff --git a/awq_ext/awq_kernels/dequantize.cuh b/awq_ext/awq_kernels/dequantize.cuh
new file mode 100644
index 000000000000..5d333b35c148
--- /dev/null
+++ b/awq_ext/awq_kernels/dequantize.cuh
@@ -0,0 +1,79 @@
+/*
+Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+
+@article{lin2023awq,
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
+  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
+  journal={arXiv},
+  year={2023}
+}
+*/
+
+#pragma once
+
+
+__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
+{
+    uint4 result;
+
+    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
+    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
+
+    // First, we extract the i4s and construct an intermediate fp16 number.
+    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
+    static constexpr uint32_t TOP_MASK = 0x00f000f0;
+    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
+    // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
+    // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
+    // elt_67 to fp16 without having to shift them to the bottom bits beforehand.
+
+    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
+    // immediately before required.
+    const uint32_t top_i4s = i4s >> 8;
+    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[0])
+                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[1])
+                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[2])
+                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[3])
+                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+
+    // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
+    // half2 ctor. In this case, I chose performance reliability over code readability.
+
+    // This is the half2 {1032, 1032} represented as an integer.
+    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
+    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
+    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
+    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
+    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
+    // This is the half2 {-72, -72} represented as an integer.
+    // static constexpr uint32_t NEG_72 = 0xd480d480;
+    // Haotian: Let's use {-64, -64}.
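+    // (After the lop3 above, each "top" lane holds 1024 + 16 * q as an fp16 value, so the fma by
+    // 1/16 below yields 64 + q; adding {-64, -64} therefore leaves exactly the unsigned 4-bit value q.)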
+    static constexpr uint32_t NEG_64 = 0xd400d400;
+
+    // Finally, we construct the output numbers.
+    // Convert elt_01
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_23
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    // Convert elt_45
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_67
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+
+    return result;
+}
+
diff --git a/awq_ext/awq_kernels/gemm_cuda.h b/awq_ext/awq_kernels/gemm_cuda.h
new file mode 100644
index 000000000000..11d5820cf99d
--- /dev/null
+++ b/awq_ext/awq_kernels/gemm_cuda.h
@@ -0,0 +1,4 @@
+#include <torch/extension.h>
+
+torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
+                                torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
diff --git a/awq_ext/awq_kernels/gemm_cuda_gen.cu b/awq_ext/awq_kernels/gemm_cuda_gen.cu
new file mode 100644
index 000000000000..1632d8be3eb2
--- /dev/null
+++ b/awq_ext/awq_kernels/gemm_cuda_gen.cu
@@ -0,0 +1,478 @@
+/*
+
+@article{lin2023awq,
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
+  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
+  journal={arXiv},
+  year={2023}
+}
+
+ */
+
+
+#include <torch/extension.h>
+#include "gemm_cuda.h"
+#include "dequantize.cuh"
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda_fp16.h>
+
+
+// Pack two half values.
+static inline __device__ __host__ unsigned
+__pack_half2(const half x, const half y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v1 << 16) | v0;
+}
+
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+{
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (128 + 8)];
+
+  __shared__ half scaling_factors_shared[128];
+  __shared__ half zeros_shared[128];
+
+  int j_factors1 = ((OC + 128 - 1) / 128);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+  half A_shared_warp[8];
+  half B_shared_warp[32];
+  for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / 128;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;  // threadIdx.y is warp_id
+  // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+  half* A_ptr = A
+                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+                + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+  int* B_ptr = B
+               + ((int)threadIdx.y) * (OC / 8) * 2
+               + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+               + (((int)blockIdx_y) % j_factors1) * (128 / 8)
+               + (((int)threadIdx.x) % (128 / 8)) * 1;
+// Why * 1 in the above line?
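+// (* 1 because B is stored as [IC, OC / 8] int32 words -- eight 4-bit weights per int --
+//  so adjacent packed columns are one int apart; see the layout comment above gemm_forward_cuda.)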
+ + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) + + (((int)threadIdx.x) / (128 / 8)) * (128 + 8) + + (((int)threadIdx.x) % (128 / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (128 / 8) + + ((int)threadIdx.x) % (128 / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * (128) + + (((int)threadIdx.x) % (128 / 8)) * 8; + + half* C_ptr = C + + blockIdx_z * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * 128 + + ((int)threadIdx.y) * 64 + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + /* + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ + printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + } + */ + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) { + + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // - zero and * scale + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
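+      // Each sub.f16x2 / fma.rn.f16x2 pair below dequantizes two fp16 lanes at a time as
+      // (q - zero) * scale, covering all 8 weights held in B_loaded_fp16.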
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + /* + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ + printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + } + */ + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + + + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + 
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + } + } + } + +// TODO: Shang: Hoist loop invariance. + for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) + { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + } + } + } +} + + +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) +{ + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (64 + 8)]; + + __shared__ half scaling_factors_shared[64]; + __shared__ half zeros_shared[64]; + + int j_factors1 = ((OC + 64 - 1) / 64); + + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); + + half A_shared_warp[8]; + half B_shared_warp[16]; + for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / 64; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + // bool wb_C_flag = (threadIdx.x / 4) < M; + + half* A_ptr = A + + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * 4 + + (((int)threadIdx.x) / (64 / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (64 / 8) + + (((int)threadIdx.x) % (64 / 8)) * 1; +// Why * 1 in the above line? 
+ + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8) + + (((int)threadIdx.x) / (64 / 8)) * (64 + 8) + + (((int)threadIdx.x) % (64 / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (64 / 8) + + ((int)threadIdx.x) % (64 / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * (64) + + (((int)threadIdx.x) % (64 / 8)) * 8; + + half* C_ptr = C + + blockIdx_z * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * 64 + + ((int)threadIdx.y) * 32 + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + /* + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ + printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + } + */ + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { + + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // - zero and * scale + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + /* + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ + printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + } + */ + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) + { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) + { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + + for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) + { + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + 
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + } + } + } + +// TODO: Shang: Hoist loop invariance. + for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) + { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + } + } + } +} + +// in_feats: M, IC [float16] +// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] +// scaling_factors: IC // G, OC [float16] +// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] +// assume that batch_size < 16 for now + +torch::Tensor gemm_forward_cuda( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + if (num_out_channels % 128 == 0) + { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + gemm_forward_4bit_cuda_m16n128k32<<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: 
i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + gemm_forward_4bit_cuda_m16n64k32<<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); +} diff --git a/awq_ext/awq_kernels/pybind.cpp b/awq_ext/awq_kernels/pybind.cpp new file mode 100644 index 000000000000..540d4f8cdc5e --- /dev/null +++ b/awq_ext/awq_kernels/pybind.cpp @@ -0,0 +1,10 @@ +// adapted from llm-awq: https://github.com/mit-han-lab/llm-awq + +#include +#include +#include "gemm_cuda.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel."); +} diff --git a/awq_ext/setup.py b/awq_ext/setup.py new file mode 100644 index 000000000000..57cf54b611b4 --- /dev/null +++ b/awq_ext/setup.py @@ -0,0 +1,26 @@ +# adapted from llm-awq: https://github.com/mit-han-lab/llm-awq + +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension + +extra_compile_args = { + "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"], + "nvcc": ["-O3", "-std=c++17"], +} + +setup( + name="awq_inference_engine", + packages=find_packages(), + ext_modules=[ + CUDAExtension( + name="awq_inference_engine", + sources=[ + "awq_kernels/pybind.cpp", + "awq_kernels/gemm_cuda_gen.cu", + ], + extra_compile_args=extra_compile_args, + ), + ], + cmdclass={"build_ext": BuildExtension}, + install_requires=["torch"], +) diff --git a/vllm/config.py b/vllm/config.py index 2e8d58411181..d9e857185f8a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -12,6 +12,34 @@ _GB = 1 << 30 +class QuantizationConfig: + """Quantization settings + + Args: + method: The quantization method to apply + bits: How many bits the linear layers are quantized to + group_size: What size the weights were quantized in groups of + """ + + def __init__( + self, + method: str, + bits: Optional[int] = 4, + group_size: Optional[int] = 128 + ) -> None: + self.method = method + self.bits = bits + self.group_size = group_size + self._verify() + + def _verify(self) -> None: + allowed_methods = ["awq"] + if self.method not in allowed_methods: + raise ValueError( + f"Unknown quantization method ({self.method})" + f" must be from choice of {allowed_methods}") + + class ModelConfig: """Configuration for the model. @@ -31,6 +59,7 @@ class ModelConfig: will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. seed: Random seed for reproducibility. 
+ quantization_config: Optional quantization settings """ def __init__( @@ -44,6 +73,7 @@ def __init__( use_dummy_weights: bool, dtype: str, seed: int, + quantization_config: Optional[QuantizationConfig] = None ) -> None: self.model = model self.tokenizer = tokenizer @@ -53,6 +83,7 @@ def __init__( self.use_np_weights = use_np_weights self.use_dummy_weights = use_dummy_weights self.seed = seed + self.quantization_config = quantization_config self.hf_config = get_config(model, trust_remote_code) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) @@ -86,6 +117,10 @@ def verify_with_parallel_config( "must be divisible by pipeline parallel size " f"({pipeline_parallel_size}).") + if self.quantization_config and tensor_parallel_size > 1: + raise NotImplementedError( + "Quantization does not currently support tensor parallelism") + def get_hidden_size(self) -> int: return self.hf_config.hidden_size @@ -140,6 +175,13 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size + def get_quantization_method(self): + if self.quantization_config is None: + method = None + else: + method = self.quantization_config.method + return method + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 99fe593b4cb0..3d7d20ef99fe 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, QuantizationConfig) @dataclass @@ -28,6 +28,7 @@ class EngineArgs: max_num_batched_tokens: int = 2560 max_num_seqs: int = 256 disable_log_stats: bool = False + quantization: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -130,6 +131,13 @@ def add_cli_args( parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') + + # quantization settings + parser.add_argument('--quantization', + type=str, + default=None, + choices=['awq', None], + help='Method used to quantize the weights') return parser @classmethod @@ -144,11 +152,16 @@ def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: # Initialize the configs. + if self.quantization is not None: + quantization_config = QuantizationConfig(self.quantization) + else: + quantization_config = None + model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.use_np_weights, self.use_dummy_weights, self.dtype, - self.seed) + self.seed, quantization_config) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 908d01d959fd..32f4af24ba9e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -78,6 +78,7 @@ def __init__( f"download_dir={model_config.download_dir!r}, " f"use_np_weights={model_config.use_np_weights}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " + f"quantization_method={model_config.get_quantization_method()}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. 
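A quick usage sketch of the configuration plumbing added above (illustrative only, not part of the patch; "gptq" is just an example of a rejected method):

from vllm.config import QuantizationConfig

# Defaults mirror what the AWQ kernels expect: 4-bit weights, group size 128.
quant_config = QuantizationConfig(method="awq")
print(quant_config.bits, quant_config.group_size)   # 4 128

# _verify() rejects anything other than "awq" for now.
try:
    QuantizationConfig(method="gptq")
except ValueError as err:
    print(err)

# From the command line, EngineArgs.create_engine_configs() builds the same object
# whenever `--quantization awq` is passed to an entrypoint that uses
# EngineArgs.add_cli_args(), and ModelConfig.get_quantization_method() then reports "awq".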
diff --git a/vllm/model_executor/layers/quant.py b/vllm/model_executor/layers/quant.py
new file mode 100644
index 000000000000..f1c14eb27de3
--- /dev/null
+++ b/vllm/model_executor/layers/quant.py
@@ -0,0 +1,111 @@
+# adapted from llm-awq: https://github.com/mit-han-lab/llm-awq
+
+import torch
+import torch.nn as nn
+
+
+try:
+    import awq_inference_engine
+    KERNELS_INSTALLED = True
+except ImportError as ex:
+    KERNELS_INSTALLED = False
+
+
+class ScaledActivation(nn.Module):
+    def __init__(self, module, scales):
+        super().__init__()
+        self.act = module
+        self.scales = nn.Parameter(scales.data)
+
+    def forward(self, x):
+        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+
+
+class AWQLinear(nn.Module):
+    def __init__(
+        self,
+        w_bit,
+        group_size,
+        in_features,
+        out_features,
+        bias,
+        dev
+    ):
+        super().__init__()
+
+        if not KERNELS_INSTALLED:
+            raise ImportError(
+                "Unable to import awq_inference_engine: run awq_ext/setup.py"
+                " to install the AWQ CUDA kernels")
+
+        if w_bit not in [4]:
+            raise NotImplementedError("Only 4-bit is supported for now.")
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else in_features
+
+        # quick sanity check (make sure alignment)
+        assert self.in_features % self.group_size == 0
+        assert out_features % (32 // self.w_bit) == 0
+
+        qweight_buffer = torch.empty(
+            (in_features, out_features // (32 // self.w_bit)),
+            dtype=torch.int32,
+            device=dev
+        )
+        self.register_buffer("qweight", qweight_buffer)
+
+        qzeros_buffer = torch.empty(
+            (
+                in_features // self.group_size,
+                out_features // (32 // self.w_bit)
+            ),
+            dtype=torch.int32,
+            device=dev
+        )
+        self.register_buffer("qzeros", qzeros_buffer)
+
+        scales_buffer = torch.empty(
+            (in_features // self.group_size, out_features),
+            dtype=torch.float16,
+            device=dev
+        )
+        self.register_buffer("scales", scales_buffer)
+
+        if bias:
+            bias_buffer = torch.empty(
+                (out_features),
+                dtype=torch.float16,
+                device=dev
+            )
+            self.register_buffer("bias", bias_buffer)
+        else:
+            self.bias = None
+
+    @torch.no_grad()
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.out_features, )
+
+        out = awq_inference_engine.gemm_forward_cuda(
+            x.reshape(-1, x.shape[-1]),
+            self.qweight,
+            self.scales,
+            self.qzeros,
+            8
+        )
+
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
+
+    def extra_repr(self) -> str:
+        str_repr = "in_features={}, out_features={}, " \
+                   "bias={}, w_bit={}, group_size={}"
+        return str_repr.format(
+            self.in_features,
+            self.out_features,
+            self.bias is not None,
+            self.w_bit,
+            self.group_size
+        )
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 85d917e6d3b5..aa897acf3bec 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -39,13 +39,25 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
         f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
 
+
+def _supports_quantization(model_class):
+    return model_class is LlamaForCausalLM
+
 
 def get_model(model_config: ModelConfig) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
     torch.set_default_dtype(model_config.dtype)
 
     # Create a model instance.
     # The weights will be initialized as empty tensors.
- model = model_class(model_config.hf_config) + + if _supports_quantization(model_class): + model = model_class( + model_config.hf_config, + model_config.quantization_config + ) + else: + model = model_class(model_config.hf_config) + if model_config.use_dummy_weights: model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign @@ -56,4 +68,5 @@ def get_model(model_config: ModelConfig) -> nn.Module: model.load_weights(model_config.model, model_config.download_dir, model_config.use_np_weights) model = model.cuda() + return model.eval() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 93ab499e64a2..4ec7a6b1a2e4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -31,11 +31,13 @@ from torch import nn from transformers import LlamaConfig +from vllm.config import QuantizationConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers import quant from vllm.model_executor.weight_utils import (hf_model_weights_iterator, load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -138,21 +140,143 @@ def forward( return output -class LlamaDecoderLayer(nn.Module): +def get_quantized_layer(in_features, out_features, quant_config): + layer = quant.AWQLinear( + w_bit=quant_config.bits, + group_size=quant_config.group_size, + in_features=in_features, + out_features=out_features, + bias=None, + dev=torch.cuda.current_device() + ) + return layer - def __init__(self, config: LlamaConfig): + +class QuantLlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + quant_config: QuantizationConfig + ): super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + assert tp_size == 1, "quantization does not support TP" + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = get_quantized_layer( + hidden_size, + self.q_size + 2 * self.kv_size, + quant_config ) - self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, + + self.o_proj = get_quantized_layer( + self.total_num_heads * self.head_dim, + hidden_size, + quant_config ) + + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.scaling, + rotary_dim=self.head_dim, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) 
-> torch.Tensor: + qkv = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(positions, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + return self.o_proj(attn_output) + + +class QuantLlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig + ): + super().__init__() + + self.gate_up_proj = get_quantized_layer( + hidden_size, + 2 * intermediate_size, quant_config + ) + + self.down_proj = get_quantized_layer( + intermediate_size, + hidden_size, + quant_config + ) + + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x = self.down_proj(x) + return x + + +class LlamaDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig, quant_config: QuantizationConfig): + super().__init__() + self.hidden_size = config.hidden_size + + use_quantized_layers = quant_config and quant_config.method is not None + + if use_quantized_layers: + self.self_attn = QuantLlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + quant_config=quant_config + ) + self.mlp = QuantLlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config + ) + else: + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -188,18 +312,22 @@ def forward( class LlamaModel(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, quant_config: QuantizationConfig): super().__init__() self.config = config + self.quant_config = quant_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( vocab_size, config.hidden_size, perform_initialization=False) + self.layers = nn.ModuleList([ - LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config, quant_config) + for _ in range(config.num_hidden_layers) ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -230,10 +358,11 @@ def forward( class LlamaForCausalLM(nn.Module): - def __init__(self, config): + def __init__(self, config, quant_config): super().__init__() self.config = config - self.model = LlamaModel(config) + self.quant_config = quant_config + self.model = LlamaModel(config, quant_config) vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ColumnParallelLinear(config.hidden_size, vocab_size, @@ -272,6 +401,7 @@ def load_weights(self, kv_proj_shard_size = (self.config.hidden_size // self.config.num_attention_heads * self.config.num_key_value_heads // tp_size) + attention_weight_specs = [ # (weight_name, shard_size, offset) ("q_proj", q_proj_shard_size, 0), @@ -296,16 +426,27 @@ def load_weights(self, extra_rows = 
extra_rows.to(loaded_weight) loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + is_quantized = self.quant_config is not None + is_attention_weight = False for weight_name, shard_size, offset in attention_weight_specs: if weight_name not in name: continue param = state_dict[name.replace(weight_name, "qkv_proj")] - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] + if not is_quantized: + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[offset:offset + shard_size] + else: + # TODO: this is specific to AWQ + if "qweight" in name or "qzeros" in name: + adjustment = 32 / self.quant_config.bits + shard_size = int(shard_size // adjustment) + offset = int(offset // adjustment) + param_slice = param.data[:, offset:offset + shard_size] + assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) @@ -319,12 +460,20 @@ def load_weights(self, if weight_name not in name: continue param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] + + if not is_quantized: + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank:shard_size * + (tensor_model_parallel_rank + 1)] + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + else: + shard_size = param.shape[1] // 2 + start = shard_size * stride_id + end = shard_size * (stride_id + 1) + param_slice = param.data[:, start:end] + assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_gate_up_weight = True
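A note on the packed-shard arithmetic above: AWQ stores eight 4-bit values per int32, so shard sizes and offsets expressed in output channels are divided by 32 // bits before indexing qweight/qzeros columns. A small illustrative sketch (the sizes are made up, not taken from any particular checkpoint):

# Packing-factor sketch: why quantized shard offsets shrink by 32 // w_bit.
w_bit = 4
pack_factor = 32 // w_bit          # 8 int4 values per int32 word

hidden_size = 4096                 # illustrative
q_proj_size = hidden_size          # q_proj output channels
kv_proj_size = hidden_size         # k_proj / v_proj output channels (no GQA here)

# fp16 qkv_proj: columns [0, q), [q, q + kv), [q + kv, q + 2 * kv) hold q / k / v.
offsets_fp16 = [0, q_proj_size, q_proj_size + kv_proj_size]

# qweight / qzeros: the same shards, but counted in packed int32 columns.
offsets_packed = [off // pack_factor for off in offsets_fp16]
shard_cols_packed = [q_proj_size // pack_factor, kv_proj_size // pack_factor]

print(offsets_packed)       # [0, 512, 1024]
print(shard_cols_packed)    # [512, 512]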