30 commits
fed3d61
add quantisation config
robirv938 Aug 14, 2023
0f936d0
pass down quantisation settings
robirv938 Aug 14, 2023
520394a
american english
robirv938 Aug 14, 2023
640cedf
llama: add the code for quantization
robirv938 Aug 14, 2023
0437ffa
update
robirv938 Aug 14, 2023
861d3d7
merge in the AWQ code with a note stating its source
robirv938 Aug 14, 2023
d659e95
update
robirv938 Aug 14, 2023
c0e4862
update
ri938 Aug 14, 2023
fed311e
update
ri938 Aug 14, 2023
0937043
fix loading of layers
ri938 Aug 14, 2023
7109bd3
update
ri938 Aug 14, 2023
5bd5ed6
update
ri938 Aug 14, 2023
c3cc5ed
quantization config is part of the model config
ri938 Aug 15, 2023
02bdfed
function
ri938 Aug 15, 2023
2f97151
update
ri938 Aug 15, 2023
c39ec2a
Merge pull request #2 from ri938/add_awq_improvements
ri938 Aug 15, 2023
e5434ef
working prototype
ri938 Aug 15, 2023
ff4d693
merge linear layers
ri938 Aug 15, 2023
033e8c1
update
ri938 Aug 15, 2023
a3ac858
Merge pull request #3 from ri938/merge_linear_layers
ri938 Aug 15, 2023
fbaf889
fix pylint errors
ri938 Aug 16, 2023
db4db0c
improve the quant weight loading code
ri938 Aug 16, 2023
73db30f
Merge pull request #5 from ri938/more_improvements_awq
ri938 Aug 16, 2023
eed1888
update
ri938 Aug 24, 2023
aaea899
Merge pull request #7 from ri938/remove_fies
ri938 Aug 24, 2023
67b614b
update
ri938 Aug 24, 2023
878a370
update
ri938 Aug 24, 2023
2617c55
Merge pull request #8 from ri938/organise
ri938 Aug 24, 2023
5fcc1c4
update
ri938 Aug 24, 2023
010b5bc
Merge pull request #9 from ri938/organise
ri938 Aug 24, 2023
4 changes: 4 additions & 0 deletions .gitignore
@@ -173,3 +173,7 @@ cython_debug/

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp
79 changes: 79 additions & 0 deletions awq_ext/awq_kernels/dequantize.cuh
@@ -0,0 +1,79 @@
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h

@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/

#pragma once


__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
    uint4 result;

    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
    static constexpr uint32_t TOP_MASK = 0x00f000f0;
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
    // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
    // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
    // elt_67 to fp16 without having to shift them to the bottom bits before hand.

    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
    // immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[0])
                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[1])
                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[2])
                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[3])
                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

    // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
    // half2 ctor. In this case, I chose performance reliability over code readability.

    // This is the half2 {1032, 1032} represented as an integer.
    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // This is the half2 {-72, -72} represented as an integer.
    // static constexpr uint32_t NEG_72 = 0xd480d480;
    // Haotian: Let's use {-64, -64}.
    static constexpr uint32_t NEG_64 = 0xd400d400;

    // Finally, we construct the output numbers.
    // Convert elt_01
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_23
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // Convert elt_45
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_67
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

    return result;
}
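For reference, the fast path above can be checked against a plain scalar unpacking. The sketch below is an illustration added for clarity, not part of the PR: it assumes the layout the kernel's comments describe, i.e. eight unsigned 4-bit values n0..n7 packed into one uint32_t (n0 in the lowest nibble), with the four half2 outputs written back as {n0, n4}, {n1, n5}, {n2, n6}, {n3, n7}. The lop3 magic-number additions and the single fma reproduce exactly these values without shifting every nibble to the bottom bits.

// Host-side reference for dequantize_s4_to_fp16x2 (illustrative sketch only, not from the PR).
// Unpacks eight unsigned 4-bit values and returns them in the same order the
// kernel writes its half2 pairs: {n0, n4}, {n1, n5}, {n2, n6}, {n3, n7}.
#include <array>
#include <cstdint>

std::array<float, 8> dequantize_s4_reference(uint32_t packed)
{
    auto nibble = [&](int i) {
        // Each element is an unsigned 4-bit value in [0, 15].
        return static_cast<float>((packed >> (4 * i)) & 0xF);
    };
    return {nibble(0), nibble(4), nibble(1), nibble(5),
            nibble(2), nibble(6), nibble(3), nibble(7)};
}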

4 changes: 4 additions & 0 deletions awq_ext/awq_kernels/gemm_cuda.h
@@ -0,0 +1,4 @@
#include <torch/extension.h>

torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
                                torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
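The header only declares the CUDA entry point; to call it from Python it still has to be registered through the torch extension machinery. Below is a minimal binding sketch; the module name and the file this binding would live in are assumptions for illustration, since the actual pybind source is not part of this excerpt.

// Illustrative binding sketch (module name assumed, not taken from the PR).
#include <torch/extension.h>
#include "gemm_cuda.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Expose the quantized GEMM so Python code can call it as a regular op.
    m.def("gemm_forward_cuda", &gemm_forward_cuda,
          "AWQ 4-bit quantized GEMM (CUDA)");
}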