diff --git a/benchmarks/benchmark_evaluation.py b/benchmarks/benchmark_evaluation.py new file mode 100644 index 000000000000..00958cbb14c7 --- /dev/null +++ b/benchmarks/benchmark_evaluation.py @@ -0,0 +1,192 @@ +import argparse +# import asyncio +# import json +import os +# import random +# import time +from typing import List, Tuple, Dict + +# import aiohttp +import numpy as np +import pandas as pd +# from transformers import PreTrainedTokenizerBase +# from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm import LLM, SamplingParams, RequestOutput +from mmlu_template import MMLUTemplate + +TEMPLATE_REGITRY = { + "mmlu": MMLUTemplate, +} + + +def sample_requests( + # dataset_path: str, + # num_requests: int, + # tokenizer: PreTrainedTokenizerBase, + dev_data_path: str, + test_data_path: str, + subjects: List[str], + dataset_template: str = "mmlu", + is_analyse: bool = False, +) -> Tuple[List[str], List[str], List[int]]: + # Load the dataset. + nums_questions = [] + dataset = [] + labels = [] + template_class = TEMPLATE_REGITRY[dataset_template] + for subject in subjects: + test_dataset = pd.read_csv(os.path.join(test_data_path, subject + "_test.csv"), header=None) + nums_questions.append(len(test_dataset)) + template = template_class(subject, os.path.join(dev_data_path, subject + "_dev.csv"), is_analyse) + for idx in range(len(test_dataset)): + prompt = template.getTemplate(test_dataset, idx) + dataset.append(prompt) + labels.append(test_dataset.iloc[idx, -1]) + return dataset, labels, nums_questions + + +def run_vllm( + requests: List[str], + output_len: int, + model: str, + tokenizer: str, + kv_cache_dtype: str = "int8", + kv_quant_params_path: str = None, + tensor_parallel_size: int = 1, + seed: int = 0, + n: int = 1, + use_beam_search: bool = False, + trust_remote_code: bool = False, + quantmethod: str = None, +) -> List[RequestOutput]: + llm = LLM( + model=model, + tokenizer=tokenizer, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + kv_cache_dtype=kv_cache_dtype, + kv_quant_params_path=kv_quant_params_path, + quantization = quantmethod + ) + for prompt in requests: + sampling_params = SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + ) + # FIXME(woosuk): Do not use internal method. + llm._add_request( + prompt=prompt, + prompt_token_ids=None, + sampling_params=sampling_params, + ) + + # FIXME(woosuk): Do use internal method. 
+ return llm._run_engine(use_tqdm=True) + + +def evalute( + request_outputs: List[RequestOutput], + labels: List[str], + nums_questions: List[int], + subjects: List[str], + dataset_template: str = "mmlu", +) -> Dict[str, float]: + template_class = TEMPLATE_REGITRY[dataset_template] + pred = [template_class.findAnswer(r.outputs[0].text) for r in request_outputs] + ids = np.cumsum(nums_questions) + lhs = 0 + accs: List[float] = [] + for rhs in ids: + pred_paritition = np.array(pred[lhs: rhs]) + labels_partition = np.array(labels[lhs: rhs]) + acc = np.mean(pred_paritition == labels_partition) + accs.append(acc) + sub2acc = {sub: acc for sub, acc in zip(subjects, accs)} + return sub2acc + + +def main(args: argparse.Namespace): + subjects = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + ] + dataset, labels, nums_questions = sample_requests( + args.dev_data_path, + args.test_data_path, + subjects, + is_analyse=args.is_analyse + ) + request_outputs = run_vllm( + dataset, + args.output_len, + args.model, + args.tokenizer, + args.kv_cache_dtype, + args.kv_quant_params_path, + args.tensor_parallel_size, + args.seed, args.n, + args.use_beam_search, + args.trust_remote_code, + args.quantization + ) + sub2acc = evalute( + request_outputs, + labels, + nums_questions, + subjects, + ) + print(sub2acc) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="evaluation for quantization.") + + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument("--dev-data-path", + type=str, + default=None, + help="path to few-shot dataset") + parser.add_argument("--test-data-path", + type=str, + default=None, + help="path to test dataset") + parser.add_argument("--is-analyse", + action="store_true") + parser.add_argument("--output-len", + type=int, + default=100, + help="nums of max token for evaluation outputs") + parser.add_argument("--kv-cache-dtype", + type=str, + default="float16") + parser.add_argument("--kv-quant-params-path", + type=str, + default=None) + parser.add_argument("--quantization", + type=str, + default=None) + args = parser.parse_args() + main(args) diff --git a/benchmarks/mmlu_template.py b/benchmarks/mmlu_template.py new file mode 100644 index 000000000000..81a7f8bc6128 --- /dev/null +++ b/benchmarks/mmlu_template.py @@ -0,0 +1,119 @@ +import pandas as pd +import json +from langchain.prompts import PromptTemplate + +template = PromptTemplate( + input_variables=["question", "A", "B", "C", "D", "Answer"], + template= + """ +USER: {question} +A. {A} +B. {B} +C. {C} +D. {D} ASSISTANT: Answer: {Answer} +""", +) + +template_with_analyse = PromptTemplate( + input_variables=["question", "A", "B", "C", "D"], + template= + """ +Q:{question} +(A) {A} (B) {B} (C) {C} (D) {D} +A: Let's think step by step. 
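For reference, a minimal NumPy sketch (not part of the diff) of the per-subject bookkeeping in `evalute`: the flat prediction list is split back into subjects with `np.cumsum(nums_questions)`. Note that in the hunk above `lhs` is never advanced inside the loop, so every partition starts at index 0; the sketch below advances it so each subject only scores its own questions. All data here is made up.

```python
import numpy as np

subjects = ["abstract_algebra", "anatomy"]
nums_questions = [3, 2]                      # questions per subject, in order
pred   = ["A", "B", "C", "D", "A"]           # model answers, flattened
labels = ["A", "C", "C", "D", "B"]           # gold answers, flattened

sub2acc = {}
lhs = 0
for subject, rhs in zip(subjects, np.cumsum(nums_questions)):
    pred_part = np.array(pred[lhs:rhs])
    label_part = np.array(labels[lhs:rhs])
    sub2acc[subject] = float(np.mean(pred_part == label_part))
    lhs = rhs                                # advance to the next subject's slice
print(sub2acc)   # {'abstract_algebra': 0.666..., 'anatomy': 0.5}
```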
+""", +) + + +def gen_prompt(train_df, subject, k=1): + prompt = "SYSTEM: The following are multiple choice questions (with answers) about {}," \ + "Please select the correct answer from the options.".format(subject.replace('_', ' ')) + + for i in range(k): + prompt += template.format(question=train_df.iloc[i, 0], + A=train_df.iloc[i, 1], + B=train_df.iloc[i, 2], + C=train_df.iloc[i, 3], + D=train_df.iloc[i, 4], + Answer=train_df.iloc[i, 5] + )[1:-1] + return prompt + + +## add an abstract base class or common base class for generality +class MMLUTemplate(): + + def __init__(self, subject, file_path, is_analyse): + self.fiveShotTemplate = "" + self.file_path = file_path + self.subject = subject + self.choices = ["A", "B", "C", "D"] + self.is_analyse = is_analyse + self.few_shot_template = "" + if not is_analyse: + self.getFewShotBaseTemplates() + else: + self.getFewShotBaseTemplateAnalyse() + + def getFewShotBaseTemplates(self, k=5): + """few_shot模板不带分析""" + dev_df = pd.read_csv(self.file_path, header=None) + + self.few_shot_template = gen_prompt(dev_df, self.subject, k) + return self.few_shot_template + + def getFewShotBaseTemplateAnalyse(self): + """few_shot模板带分析,更改json文件就行""" + mmlu_prompt = json.load(open('templates/lib_prompt/mmlu-cot.json')) + self.few_shot_template = mmlu_prompt[self.subject] + return self.few_shot_template + + def getTemplate(self, test_df, i): + """获得模板""" + if self.is_analyse: + templ = template_with_analyse.format( + question=test_df.iloc[i, 0], + A=test_df.iloc[i, 1], + B=test_df.iloc[i, 2], + C=test_df.iloc[i, 3], + D=test_df.iloc[i, 4] + ) + + return self.few_shot_template + "\n" + templ + + else: + prompt_end = template.format( + question=test_df.iloc[i, 0], + A=test_df.iloc[i, 1], + B=test_df.iloc[i, 2], + C=test_df.iloc[i, 3], + D=test_df.iloc[i, 4], + Answer='')[1:-5] + return self.few_shot_template + prompt_end + @staticmethod + def findAnswer(res): + """解析函数""" + # print("模型输出为:", res) + d = "NO" + for d_ in res: + if 65 <= ord(d_) <= 68: + d = d_ + break + # print("答案解析为:", d) + return d + + @staticmethod + def findAnwerUsingRule(res): + # print("模型输出为:", res) + result = "NO" + pattern = 'the answer is (' + try: + pred = res.lower().split(pattern)[1][0] + + if 65 <= ord(pred.upper()) <= 68: + result = pred.upper() + except: + pass + + # print("答案解析为:",result) + return result diff --git a/csrc/activation.cpp b/csrc/activation.cpp index c100f89ac737..76dda452e822 100644 --- a/csrc/activation.cpp +++ b/csrc/activation.cpp @@ -1,28 +1,19 @@ #include -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); +void silu_and_mul(torch::Tensor &out, torch::Tensor &input); -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); +void gelu_new(torch::Tensor &out, torch::Tensor &input); -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); +void gelu_fast(torch::Tensor &out, torch::Tensor &input); + +void invoke_dequant_silu_and_mul_quant(torch::Tensor &out, torch::Tensor &input, + const float scale_gate, + const float scale_up, + const float scale_out); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - m.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - m.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + m.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + m.def("gelu_fast", &gelu_fast, "Approximate GELU 
implementation."); + m.def("invoke_dequant_silu_and_mul_quant", &invoke_dequant_silu_and_mul_quant, "Dequant input, apply silu act and quant output"); } diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index c6ae5db8f9c4..23969b3a56bb 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,21 +1,21 @@ -#include #include +#include #include "dispatch_utils.h" +#include "quant_utils.cuh" namespace vllm { -template -__device__ __forceinline__ T silu(const T& x) { +template __device__ __forceinline__ T silu(const T &x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } -template -__global__ void silu_and_mul_kernel( - scalar_t* __restrict__ out, // [num_tokens, d] - const scalar_t* __restrict__ input, // [num_tokens, 2, d] - const int d) { +template +__global__ void +silu_and_mul_kernel(scalar_t *__restrict__ out, // [num_tokens, d] + const scalar_t *__restrict__ input, // [num_tokens, 2, d] + const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); @@ -24,11 +24,24 @@ __global__ void silu_and_mul_kernel( } } +__global__ void dequant_silu_and_mul_quant_kernel( + int8_t *__restrict__ out, // [num_tokens, d] + const int32_t *__restrict__ input, // [num_tokens, 2, d] + const int d, const float scale_gate, const float scale_up, + const float scale_out) { + const int token_idx = blockIdx.x; + for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { + const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate; + const float y = + (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up; + out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out); + } +} + } // namespace vllm -void silu_and_mul( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, 2 * d] +void silu_and_mul(torch::Tensor &out, // [num_tokens, d] + torch::Tensor &input) // [num_tokens, 2 * d] { int num_tokens = input.size(0); int d = input.size(1) / 2; @@ -36,25 +49,35 @@ void silu_and_mul( dim3 grid(num_tokens); dim3 block(std::min(d, 1024)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "silu_and_mul_kernel", - [&] { - vllm::silu_and_mul_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - d); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] { + vllm::silu_and_mul_kernel<<>>( + out.data_ptr(), input.data_ptr(), d); + }); +} + +void invoke_dequant_silu_and_mul_quant(torch::Tensor &out, torch::Tensor &input, + const float scale_gate, + const float scale_up, + const float scale_out) { + int num_tokens = input.size(0); + int d = input.size(1) / 2; + + dim3 grid(num_tokens); + dim3 block(std::min(d, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::dequant_silu_and_mul_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), d, scale_gate, + scale_up, scale_out); } namespace vllm { // Element-wise activation kernel template. 
-template -__global__ void activation_kernel( - scalar_t* __restrict__ out, // [num_tokens, d] - const scalar_t* __restrict__ input, // [num_tokens, d] - const int d) { +template +__global__ void +activation_kernel(scalar_t *__restrict__ out, // [num_tokens, d] + const scalar_t *__restrict__ input, // [num_tokens, d] + const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * d + idx]); @@ -65,50 +88,44 @@ __global__ void activation_kernel( } // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int num_tokens = input.size(0); \ - int d = input.size(1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "activation_kernel", \ - [&] { \ - vllm::activation_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int num_tokens = input.size(0); \ + int d = input.size(1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); namespace vllm { -template -__device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float) (x * x * x); - const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); - return ((T) 0.5) * x * (((T) 1.0) + t); +template __device__ __forceinline__ T gelu_new_kernel(const T &x) { + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); } -template -__device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float) x; - const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); - return ((T) 0.5) * x * (((T) 1.0) + t); +template +__device__ __forceinline__ T gelu_fast_kernel(const T &x) { + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); } } // namespace vllm -void gelu_new( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, d] +void gelu_new(torch::Tensor &out, // [num_tokens, d] + torch::Tensor &input) // [num_tokens, d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, d] +void gelu_fast(torch::Tensor &out, // [num_tokens, d] + torch::Tensor &input) // [num_tokens, d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } diff --git a/csrc/attention.cpp b/csrc/attention.cpp index 6be8a6d25ae4..e1b8159feb79 100644 --- a/csrc/attention.cpp +++ b/csrc/attention.cpp @@ -14,9 +14,31 @@ void single_query_cached_kv_attention( int max_context_len, const c10::optional& alibi_slopes); +void single_query_cached_kv_quantized_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& 
block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "single_query_cached_kv_attention", &single_query_cached_kv_attention, "Compute the attention between an input query and the cached key/value tensors"); + m.def( + "single_query_cached_kv_quantized_attention", + &single_query_cached_kv_quantized_attention, + "Compute the attention between an input query and the cached & quantized key/value tensors" + ); } diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index 88b4eddec7fc..ce1a03375233 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -4,3 +4,4 @@ #include "dtype_float16.cuh" #include "dtype_float32.cuh" #include "dtype_bfloat16.cuh" +#include "dtype_int8.cuh" diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 3fc5860bf147..ddb2ad22b535 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -17,7 +17,7 @@ */ #include #include - +#include "../quant_utils.cuh" #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -338,6 +338,282 @@ __global__ void single_query_cached_kv_attention_kernel( } } +template< + typename scalar_t, + typename cache_type, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void single_query_cached_kv_attention_quantized_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_type* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_type* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) { + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int kv_head_idx = head_mapping[head_idx]; + const int seq_idx = blockIdx.y; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. 
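A hypothetical Python-side call sketch for the new binding, following the argument order and shapes declared above. It assumes the extension in csrc/attention.cpp is built as `vllm.attention_ops` (as in upstream vLLM at the time) and that the int8 caches were filled by `reshape_and_cache_quantized`; sizes and scale values are illustrative only.

```python
import torch
from vllm import attention_ops  # assumption: extension module name

num_seqs, num_heads, head_size = 2, 8, 128
num_blocks, block_size = 16, 16
x = 16  # packed key layout: 16 bytes / sizeof(int8)

query = torch.randn(num_seqs, num_heads, head_size, dtype=torch.float16, device="cuda")
out = torch.empty_like(query)
key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x,
                        dtype=torch.int8, device="cuda")
value_cache = torch.empty(num_blocks, num_heads, head_size, block_size,
                          dtype=torch.int8, device="cuda")
head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
block_tables = torch.zeros(num_seqs, 4, dtype=torch.int32, device="cuda")
context_lens = torch.full((num_seqs,), 10, dtype=torch.int32, device="cuda")
scale = head_size ** -0.5
k_scale, k_zp, v_scale, v_zp = 0.05, 0.0, 0.05, 0.0  # e.g. from kv_quant_params_path

attention_ops.single_query_cached_kv_quantized_attention(
    out, query, key_cache, value_cache, head_mapping, scale,
    block_tables, context_lens, block_size, 64, None,
    k_scale, k_zp, v_scale, v_zp)
```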
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Vec_quant = typename Vec::Type; + using Vec_dequant = typename FloatVec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_type); + float qk_max = -FLT_MAX; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. 
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_type* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + // dequant and conversion + Vec_quant k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + Vec_dequant k_vec_dequant = dequant(k_vec_quant, k_scale, k_zp); + k_vecs[j] = vec_conversion(k_vec_dequant); + // k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using V_vec_quant = typename Vec::Type; + using V_vec_dequant = typename FloatVec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. 
+ float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + + const cache_type* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + // dequant and conversion + V_vec_quant v_vec_quant = *reinterpret_cast(v_ptr + offset); + V_vec_dequant v_vec_dequant = dequant(v_vec_quant, v_scale, v_zp); + V_vec v_vec = vec_conversion(v_vec_dequant); + // V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. 
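For reference, a NumPy sketch of what the quantized kernel computes for a single sequence and head, with the paged block indirection and ALiBi bias omitted. The quant/dequant convention assumed here is quant(x) = clip(round((x - zp) / scale)) and dequant(q) = q * scale + zp; the authoritative definitions live in csrc/quant_utils.cuh, which is not shown in this hunk.

```python
import numpy as np

def dequant(q, scale, zp):
    return q.astype(np.float32) * scale + zp

def quantized_single_query_attention(q, k_cache_i8, v_cache_i8, context_len,
                                     k_scale, k_zp, v_scale, v_zp):
    # q: [head_size]; k_cache_i8 / v_cache_i8: [context_len, head_size] int8
    scale = 1.0 / np.sqrt(q.shape[0])
    k = dequant(k_cache_i8[:context_len], k_scale, k_zp)
    v = dequant(v_cache_i8[:context_len], v_scale, v_zp)
    logits = scale * (k @ q)                  # [context_len]
    logits = np.exp(logits - logits.max())
    probs = logits / (logits.sum() + 1e-6)    # kernel adds 1e-6 to the exp sum
    return probs @ v                          # [head_size]

head_size, context_len = 4, 3
rng = np.random.default_rng(0)
q = rng.standard_normal(head_size).astype(np.float32)
k_i8 = rng.integers(-128, 128, size=(context_len, head_size), dtype=np.int8)
v_i8 = rng.integers(-128, 128, size=(context_len, head_size), dtype=np.int8)
print(quantized_single_query_attention(q, k_i8, v_i8, context_len,
                                        0.05, 0.0, 0.05, 0.0))
```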
+ if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} } // namespace vllm #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ @@ -357,6 +633,28 @@ __global__ void single_query_cached_kv_attention_kernel( kv_block_stride, \ kv_head_stride); +// specifying cache type to int8 manually +#define LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + vllm::single_query_cached_kv_attention_quantized_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, \ + k_scale, \ + k_zp, \ + v_scale, \ + v_zp); + // TODO(woosuk): Tune NUM_THREADS. template< typename T, @@ -442,6 +740,94 @@ void single_query_cached_kv_attention_launcher( } } +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void single_query_cached_kv_attention_quantized_launcher( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, + const c10::optional& alibi_slopes, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + int8_t* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); // TODO: support other types + int8_t* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); // TODO: support other types + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192. 
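A small worked example of the dynamic shared-memory sizing in `single_query_cached_kv_attention_quantized_launcher`, assuming the default NUM_THREADS = 128 and the usual WARP_SIZE of 32: the space is the larger of the padded float logits and the cross-warp float partial outputs.

```python
def shared_mem_size(max_context_len, block_size=16, head_size=128,
                    num_threads=128, warp_size=32):
    num_warps = num_threads // warp_size                          # 4
    padded_len = -(-max_context_len // block_size) * block_size   # round up to a block
    logits_size = padded_len * 4                                  # float logits
    outputs_size = (num_warps // 2) * head_size * 4               # float partial outputs
    return max(logits_size, outputs_size)

print(shared_mem_size(max_context_len=1000))  # 1008 * 4 = 4032 bytes (logits dominate)
print(shared_mem_size(max_context_len=100))   # max(112 * 4, 2 * 128 * 4) = 1024 bytes
```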
+ // case 32: + // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // break; + case 64: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 112: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + // case 160: + // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // break; + // case 192: + // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // break; + case 256: + LAUNCH_ATTENTION_QUANTIZED_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + #define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ single_query_cached_kv_attention_launcher( \ out, \ @@ -455,6 +841,24 @@ void single_query_cached_kv_attention_launcher( max_context_len, \ alibi_slopes); +#define CALL_QUANTIZED_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_attention_quantized_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes, \ + k_scale, \ + k_zp, \ + v_scale, \ + v_zp); + + // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ @@ -491,6 +895,40 @@ void single_query_cached_kv_attention_launcher( break; \ } +#define CALL_QUANTIZED_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + /* case 1: */ \ + /* CALL_KERNEL_LAUNCHER(T, 1); */ \ + /* break; */ \ + /* case 2: */ \ + /* CALL_KERNEL_LAUNCHER(T, 2); */ \ + /* break; */ \ + /* case 4: */ \ + /* CALL_KERNEL_LAUNCHER(T, 4); */ \ + /* break; */ \ + case 8: \ + CALL_QUANTIZED_KERNEL_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_QUANTIZED_KERNEL_LAUNCHER(T, 16); \ + break; \ + /*case 32: \ + CALL_QUANTIZED_KERNEL_LAUNCHER(T, 32); \ + break;*/ \ + /* case 64: */ \ + /* CALL_KERNEL_LAUNCHER(T, 64); */ \ + /* break; */ \ + /* case 128: */ \ + /* CALL_KERNEL_LAUNCHER(T, 128); */ \ + /* break; */ \ + /* case 256: */ \ + /* CALL_KERNEL_LAUNCHER(T, 256); */ \ + /* break; */ \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + void single_query_cached_kv_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] @@ -514,6 +952,32 @@ void single_query_cached_kv_attention( } } +void single_query_cached_kv_quantized_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) { + if (query.dtype() == at::ScalarType::Float) { + CALL_QUANTIZED_KERNEL_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + 
CALL_QUANTIZED_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_QUANTIZED_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} #undef WARP_SIZE #undef MAX #undef MIN diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb..51407f35e2d0 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -86,6 +86,14 @@ inline __device__ float4 add(float4 a, float4 b) { return c; } +// for compiling, the above function seems to be useless +inline __device__ Float4_ add(Float4_ a, Float4_ b) { + Float4_ c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + // Vector multiplication. template<> inline __device__ float mul(float a, float b) { diff --git a/csrc/attention/dtype_int8.cuh b/csrc/attention/dtype_int8.cuh new file mode 100644 index 000000000000..91e6ec40b038 --- /dev/null +++ b/csrc/attention/dtype_int8.cuh @@ -0,0 +1,49 @@ +#pragma once + +#include +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +namespace vllm { +// define int8 vector types for quantization of kv cache + +template<> +struct Vec { + using Type = int8_t; +}; + +template<> +struct Vec { + using Type = int16_t; +}; + +template<> +struct Vec { + using Type = int32_t; +}; + +template<> +struct Vec { + using Type = int64_t; +}; + +template<> +struct FloatVec { + using Type = float; +}; + +template<> +struct FloatVec { + using Type = float2; +}; + +template<> +struct FloatVec { + using Type = Float4_; +}; + +template<> +struct FloatVec { + using Type = Float8_; +}; +} diff --git a/csrc/cache.cpp b/csrc/cache.cpp index 9ae17bb2985c..5ada275ad472 100644 --- a/csrc/cache.cpp +++ b/csrc/cache.cpp @@ -27,6 +27,17 @@ void gather_cached_kv( torch::Tensor& value_cache, torch::Tensor& slot_mapping); +void reshape_and_cache_quantized( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "swap_blocks", @@ -44,4 +55,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "gather_cached_kv", &gather_cached_kv, "Gather key and value from the cache into contiguous QKV tensors"); + m.def( + "reshape_and_cache_quantized", + &reshape_and_cache_quantized, + "Reshape and quantized key and value tensors and cache them"); } diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ddad2b5a29b9..948193278d29 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -7,6 +7,7 @@ #include #include #include +#include "quant_utils.cuh" void swap_blocks( torch::Tensor& src, @@ -127,7 +128,7 @@ void copy_blocks( dim3 grid(num_layers, num_pairs); dim3 block(std::min(1024, numel_per_block)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_DISPATCH_QUANT_TYPES( key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), @@ -137,6 +138,7 @@ void copy_blocks( })); } + namespace vllm { template @@ -181,6 +183,54 @@ __global__ void reshape_and_cache_kernel( } } +template // cache_dtype can only be int8_t for now 
+__global__ void reshape_and_cache_quantized_kernel( + const attn_dtype* __restrict__ key, // [num_tokens, num_heads, head_size] + const attn_dtype* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_dtype* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + cache_dtype* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x, + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) { + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int src_key_idx = token_idx * key_stride + i; + const int src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int tgt_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; + // TODO (Lin Pengyun): use vector reading and quantization to improve IO ultilization + attn_dtype tgt_key = __ldg(&key[src_key_idx]); + key_cache[tgt_key_idx] = quant(tgt_key, k_scale, k_zp); + attn_dtype tgt_value = __ldg(&value[src_value_idx]); + value_cache[tgt_value_idx] = quant(tgt_value, v_scale, v_zp); + } +} } // namespace vllm void reshape_and_cache( @@ -221,6 +271,52 @@ void reshape_and_cache( }); } +void reshape_and_cache_quantized( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const float k_scale, + const float k_zp, + const float v_scale, + const float v_zp) +{ + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), + "reshape_and_cache_quantized_kernel", + [&] { + vllm::reshape_and_cache_quantized_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size, + x, + k_scale, + k_zp, + v_scale, + v_zp); + }); +} + namespace vllm { // Grid: (num_blocks, block_size). diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 7c0c49d392a9..921d453b703c 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -7,8 +7,17 @@ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) 
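For reference, a Python sketch of the flat-index arithmetic used by `reshape_and_cache_quantized_kernel` when scattering a token into the packed int8 key cache [num_blocks, num_heads, head_size/x, block_size, x] and the value cache [num_blocks, num_heads, head_size, block_size]. x = 16 is assumed here, matching 16 bytes / sizeof(int8); the concrete sizes are illustrative.

```python
def key_cache_index(block_idx, head_idx, head_offset, block_offset,
                    num_heads=8, head_size=128, block_size=16, x=16):
    x_idx, x_offset = divmod(head_offset, x)
    return (block_idx * num_heads * (head_size // x) * block_size * x
            + head_idx * (head_size // x) * block_size * x
            + x_idx * block_size * x
            + block_offset * x
            + x_offset)

def value_cache_index(block_idx, head_idx, head_offset, block_offset,
                      num_heads=8, head_size=128, block_size=16):
    return (block_idx * num_heads * head_size * block_size
            + head_idx * head_size * block_size
            + head_offset * block_size
            + block_offset)

# slot_idx -> (block_idx, block_offset), exactly as in the kernel
slot_idx, block_size = 37, 16
block_idx, block_offset = divmod(slot_idx, block_size)   # (2, 5)
print(key_cache_index(block_idx, 3, 42, block_offset))   # flat offset into key_cache
print(value_cache_index(block_idx, 3, 42, block_offset)) # flat offset into value_cache
```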
\ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + // AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) + +#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) diff --git a/csrc/fused.cpp b/csrc/fused.cpp new file mode 100644 index 000000000000..b3cd340047ca --- /dev/null +++ b/csrc/fused.cpp @@ -0,0 +1,26 @@ +#include + +void invoke_dequant_add_residual( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &residual, // [num_tokens, hidden_size] + float scale); + +void invoke_dequant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + float scale); + +void invoke_quant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + float scale); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("invoke_dequant_add_residual", &invoke_dequant_add_residual, + "Add the dequanted result and residual."); + m.def("invoke_dequant", &invoke_dequant, + "Dequant."); + m.def("invoke_quant", &invoke_quant, + "Quant."); +} diff --git a/csrc/fused_kernels.cu b/csrc/fused_kernels.cu new file mode 100644 index 000000000000..83f357142aa9 --- /dev/null +++ b/csrc/fused_kernels.cu @@ -0,0 +1,100 @@ +#include +#include + +#include "dispatch_utils.h" +#include "quant_utils.cuh" +#include + +namespace vllm { +template +__global__ void dequant_add_residual_kernel(const int32_t *__restrict__ input, + const T *__restrict__ residual, + T *__restrict__ output, + const float scale, int m, int n) { + const int tid = threadIdx.x; + for (int i = tid; i < n; i += blockDim.x) { + output[blockIdx.x * n + i] = + (T)((((float)input[blockIdx.x * n + i]) * scale) + + (float)residual[blockIdx.x * n + i]); + } +} + +template +__global__ void dequant_kernel(const int32_t *__restrict__ input, + T *__restrict__ output, + const float scale, int m, int n, int input_stride, int out_stride) { + const int tid = threadIdx.x; + for (int i = tid; i < n; i += blockDim.x) { + output[blockIdx.x * out_stride + i] = + (T)(((float)input[blockIdx.x * input_stride + i]) * scale); + } +} + +template +__global__ void quant_kernel(const T *__restrict__ input, + int8_t *__restrict__ output, + const float scale, int m, int n) { + const int tid = threadIdx.x; + for (int i = tid; i < n; i += blockDim.x) { + output[blockIdx.x * n + i] = + float_to_int8_rn(((float)input[blockIdx.x * n + i]) / scale); + } +} +} // namespace vllm + +void invoke_dequant_add_residual( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &residual, // [num_tokens, hidden_size] + float scale) { + int m = input.size(0); + int n = input.size(1); + dim3 grid(m); + dim3 block(min(n, 1024)); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + residual.scalar_type(), "dequant_add_residual_kernel", [&] { + vllm::dequant_add_residual_kernel<<>>( + 
input.data_ptr(), residual.data_ptr(), + out.data_ptr(), scale, m, n); + }); +} + +void invoke_dequant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + float scale) { + int m = input.size(0); + int n = input.size(1); + int input_stride = input.stride(0); + int out_stride = out.stride(0); + dim3 grid(m); + dim3 block(min(n, 1024)); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + out.scalar_type(), "dequant_kernel", [&] { + vllm::dequant_kernel<<>>( + input.data_ptr(), out.data_ptr(), scale, m, n, input_stride, out_stride); + }); +} + +void invoke_quant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + float scale) { + assert(input.is_contiguous()); + assert(out.is_contiguous()); + int m = input.size(0); + int n = input.size(1); + dim3 grid(m); + dim3 block(min(n, 1024)); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "quant_kernel", [&] { + vllm::quant_kernel<<>>( + input.data_ptr(), out.data_ptr(), scale, m, n); + }); +} \ No newline at end of file diff --git a/csrc/int8gemm/cublas/allocator.h b/csrc/int8gemm/cublas/allocator.h new file mode 100644 index 000000000000..7ac1345441f1 --- /dev/null +++ b/csrc/int8gemm/cublas/allocator.h @@ -0,0 +1,424 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
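For reference, a NumPy sketch of the three element-wise helpers in csrc/fused_kernels.cu: `invoke_dequant` scales int32 back to float, `invoke_dequant_add_residual` additionally adds the float residual, and `invoke_quant` divides by the scale and rounds/saturates to int8 (the assumed behavior of `float_to_int8_rn`). Values are illustrative.

```python
import numpy as np

def invoke_dequant(x_i32, scale):
    return x_i32.astype(np.float32) * scale

def invoke_dequant_add_residual(x_i32, residual, scale):
    return x_i32.astype(np.float32) * scale + residual

def invoke_quant(x, scale):
    return np.clip(np.rint(x / scale), -128, 127).astype(np.int8)

x_i32 = np.array([[12000, -34000]], dtype=np.int32)
residual = np.array([[0.5, -0.5]], dtype=np.float32)
print(invoke_dequant(x_i32, 1e-4))                         # ~[[ 1.2, -3.4]]
print(invoke_dequant_add_residual(x_i32, residual, 1e-4))  # ~[[ 1.7, -3.9]]
print(invoke_quant(np.array([[1.7, -3.9]]), 0.05))         # [[ 34, -78]]
```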
+ */ +/** + * Memory Allocator + **/ + +#pragma once + +#include "cuda_utils.h" +#include +#include +#include + +#ifdef GOOGLE_CUDA +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#endif + +#ifdef TORCH_CUDA +#include "torch/extension.h" +#include +#endif + +#if defined(CUDART_VERSION) && CUDART_VERSION < 11020 +#define CUDA_MEMORY_POOL_DISABLED +#endif + +enum class AllocatorType { CUDA, TF, TH }; + +enum class ReallocType { + INCREASE, + REUSE, + DECREASE, +}; + +class IAllocator { +public: + virtual ~IAllocator(){}; + + virtual void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) = 0; + virtual void free(void **ptr, bool is_host = false) const = 0; + virtual void setStream(cudaStream_t stream) = 0; + virtual cudaStream_t returnStream() = 0; + virtual void memSet(void *ptr, const int val, const size_t size) = 0; + + template + void *reMalloc(T *ptr, size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + size = ((size + 31) / 32) * 32; // make the buffer align with 32 bytes + void *void_ptr = (void *)ptr; + void *ptr_address = getAddress(void_ptr); + if (isExist(ptr_address)) { + ReallocType realloc_type = isReMalloc(ptr_address, size); + if (realloc_type == ReallocType::INCREASE) { + // FT_LOG_DEBUG("ReMalloc the buffer %p since it is too small.", + // void_ptr); + free((void **)(&void_ptr), is_host); + return malloc(size, is_set_zero, is_host); + } +#if !defined(CUDA_MEMORY_POOL_DISABLED) + else if (realloc_type == ReallocType::DECREASE) { + // FT_LOG_DEBUG("ReMalloc the buffer %p to release unused memory to + // memory pools.", void_ptr); + free((void **)(&void_ptr), is_host); + return malloc(size, is_set_zero, is_host); + } +#endif + else { + // FT_LOG_DEBUG("Reuse original buffer %p with size %d and do nothing + // for reMalloc.", void_ptr, size); + if (is_set_zero) { + memSet(void_ptr, 0, size); + } + return void_ptr; + } + } else { + // FT_LOG_DEBUG("Cannot find buffer %p, mallocing new one.", void_ptr); + return malloc(size, is_set_zero, is_host); + } + } + +protected: + virtual bool isExist(void *address) const = 0; + virtual ReallocType isReMalloc(void *address, size_t size) const = 0; + + void *getAddress(void *ptr) const { return ptr; } +}; + +template class Allocator; + +template <> class Allocator : public IAllocator { +private: + const int device_id_; + cudaStream_t stream_ = 0; // initialize as default stream + std::unordered_map *pointer_mapping_; + + bool isExist(void *address) const { + return pointer_mapping_->count(address) > 0; + } + ReallocType isReMalloc(void *address, size_t size) const { + FT_CHECK(isExist(address)); + if (pointer_mapping_->at(address) < size) { + return ReallocType::INCREASE; + } else if (pointer_mapping_->at(address) == size) { + return ReallocType::REUSE; + } else { + return ReallocType::DECREASE; + } + } + +public: + Allocator(int device_id) : device_id_(device_id) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + pointer_mapping_ = new std::unordered_map(); +#if defined(CUDA_MEMORY_POOL_DISABLED) + // FT_LOG_WARNING( + // "Async cudaMalloc/Free is not supported 
before CUDA 11.2. Using Sync + // cudaMalloc/Free." "Note this may lead to hang with NCCL kernels + // launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP"); +#else + int device_count = 1; + check_cuda_error(cudaGetDeviceCount(&device_count)); + cudaMemPool_t mempool; + check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); + cudaMemAccessDesc desc = {}; + int peer_access_available = 0; + for (int i = 0; i < device_count; i++) { + if (i == device_id) { + continue; + } + check_cuda_error( + cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); + if (!peer_access_available) { + // FT_LOG_WARNING("Device " + std::to_string(device_id) + " peer access + // Device " + std::to_string(i) + // + " is not available."); + continue; + } + desc.location.type = cudaMemLocationTypeDevice; + desc.location.id = i; + desc.flags = cudaMemAccessFlagsProtReadWrite; + check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); + } + // set memory pool threshold to avoid shrinking the pool + uint64_t setVal = UINT64_MAX; + check_cuda_error(cudaMemPoolSetAttribute( + mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); +#endif + } + + virtual ~Allocator() { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + while (!pointer_mapping_->empty()) { + free((void **)(&pointer_mapping_->begin()->first)); + } + delete pointer_mapping_; + } + + void setStream(cudaStream_t stream) { stream_ = stream; } + + cudaStream_t returnStream() { return stream_; }; + + void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (size == 0) { + return nullptr; + } + void *ptr = nullptr; + int o_device = 0; + + check_cuda_error(getSetDevice(device_id_, &o_device)); + if (is_host) { + check_cuda_error(cudaMallocHost(&ptr, (size_t)(ceil(size / 32.)) * 32)); + } else { +#if defined(CUDA_MEMORY_POOL_DISABLED) + check_cuda_error(cudaMalloc(&ptr, (size_t)(ceil(size / 32.)) * 32)); +#else + check_cuda_error( + cudaMallocAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, stream_)); +#endif + } + if (is_set_zero) { + check_cuda_error( + cudaMemsetAsync(ptr, 0, (size_t)(ceil(size / 32.)) * 32, stream_)); + } + check_cuda_error(getSetDevice(o_device)); + // FT_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size); + + pointer_mapping_->insert({getAddress(ptr), size}); + + return ptr; + } + + void free(void **ptr, bool is_host = false) const { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + void *address = getAddress(*ptr); + if (*ptr != nullptr) { + int o_device = 0; + if (pointer_mapping_->count(address)) { + // FT_LOG_DEBUG("Free buffer %p", address); + check_cuda_error(getSetDevice(device_id_, &o_device)); + if (is_host) { + check_cuda_error(cudaFreeHost(*ptr)); + } else { +#if defined(CUDA_MEMORY_POOL_DISABLED) + check_cuda_error(cudaFree(*ptr)); +#else + check_cuda_error(cudaFreeAsync(*ptr, stream_)); + cudaStreamSynchronize(stream_); +#endif + } + check_cuda_error(getSetDevice(o_device)); + pointer_mapping_->erase(address); + } else { + // FT_LOG_WARNING("pointer_mapping_ does not have information of ptr at + // %p.", address); + } + } + *ptr = nullptr; + return; + } + + void memSet(void *ptr, const int val, const size_t size) { + check_cuda_error(cudaMemsetAsync(ptr, val, size, stream_)); + } +}; + +#ifdef GOOGLE_CUDA +using namespace tensorflow; +template <> class Allocator : public IAllocator { + OpKernelContext *context_; + std::unordered_map *pointer_mapping_; + cudaStream_t stream_; + + bool isExist(void *address) const { + return pointer_mapping_->count(address) 
> 0; + } + ReallocType isReMalloc(void *address, size_t size) const { + FT_CHECK(isExist(address)); + size_t current_buffer_size = 1; + for (int i = 0; i < pointer_mapping_->at(address).dims(); i++) { + current_buffer_size *= pointer_mapping_->at(address).dim_size(i); + } + // FT_LOG_DEBUG("current_buffer_size: %d, new buffer: %d", + // current_buffer_size, size); + if (current_buffer_size < size) { + return ReallocType::INCREASE; + } else if (current_buffer_size == size) { + return ReallocType::REUSE; + } else { + return ReallocType::DECREASE; + } + } + +public: + Allocator(OpKernelContext *context, cudaStream_t stream) + : context_(context), stream_(stream) { + pointer_mapping_ = new std::unordered_map(); + } + + void setStream(cudaStream_t stream) { stream_ = stream; } + + cudaStream_t returnStream() { return stream_; }; + + void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + tensorflow::Tensor buf; + long long int buf_size = ((long long int)ceil(size / 32.) * 32); + tensorflow::Status status; + if (is_host) { + tensorflow::AllocatorAttributes pinned_allocator; + pinned_allocator.set_on_host(true); + pinned_allocator.set_gpu_compatible(true); + status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf, + pinned_allocator); + } else { + status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf); + } + + if (status != tensorflow::Status::OK()) { + throw std::runtime_error("TF error: context->allocate_temp failed"); + } + + auto flat = buf.flat(); + void *ptr = (void *)flat.data(); + if (is_set_zero) { + cudaMemsetAsync(ptr, 0, buf_size, stream_); + } + pointer_mapping_->insert({getAddress(ptr), buf}); + + return ptr; + } + + void free(void **ptr, bool is_host = false) const { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + void *address = getAddress(*ptr); + pointer_mapping_->erase(address); + *ptr = nullptr; + return; + } + + virtual ~Allocator() { + while (!pointer_mapping_->empty()) { + void *ptr = pointer_mapping_->begin()->second.flat().data(); + free((void **)(&ptr)); + } + pointer_mapping_->clear(); + delete pointer_mapping_; + } + + void memSet(void *ptr, const int val, const size_t size) { + check_cuda_error(cudaMemsetAsync(ptr, val, size, stream_)); + } +}; +#endif + +#ifdef TORCH_CUDA +template <> class Allocator : public IAllocator { + std::unordered_map *pointer_mapping_; + + bool isExist(void *address) const { + return pointer_mapping_->count(address) > 0; + } + ReallocType isReMalloc(void *address, size_t size) const { + FT_CHECK(isExist(address)); + size_t current_buffer_size = 1; + for (int i = 0; i < pointer_mapping_->at(address).dim(); i++) { + current_buffer_size *= pointer_mapping_->at(address).size(i); + } + // FT_LOG_DEBUG( + // "current_buffer_size: %d, original buffer: %p, new buffer: %d", + // current_buffer_size, address, size); + if (current_buffer_size < size) { + return ReallocType::INCREASE; + } else if (current_buffer_size == size) { + return ReallocType::REUSE; + } else { + return ReallocType::DECREASE; + } + } + +public: + Allocator() { + pointer_mapping_ = new std::unordered_map(); + } + + void setStream(cudaStream_t stream) { + // nothing to do here; + } + + cudaStream_t returnStream() { + // nothing to do here; + return 0; + }; + + void *malloc(size_t size, const bool is_set_zero = true, + bool is_host = false) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + int64_t buf_size = static_cast(ceil(size / 32.)) * 32; + torch::Tensor buf; + if (is_host) { + buf = 
torch::empty( + {buf_size}, + torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true)); + } else { + buf = torch::empty({buf_size}, + torch::dtype(torch::kUInt8).device(torch::kCUDA)); + } + void *ptr = buf.data_ptr(); + if (is_set_zero) { + cudaMemset(ptr, 0, buf_size); + } + // FT_LOG_DEBUG("malloc buffer %p with size %ld", ptr, buf_size); + pointer_mapping_->insert({getAddress(ptr), buf}); + return ptr; + } + + void free(void **ptr, bool is_host = false) const { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + void *address = getAddress(*ptr); + pointer_mapping_->erase(address); + *ptr = nullptr; + return; + } + + virtual ~Allocator() { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + while (!pointer_mapping_->empty()) { + void *ptr = pointer_mapping_->begin()->second.data_ptr(); + free((void **)(&ptr)); + } + pointer_mapping_->clear(); + delete pointer_mapping_; + } + + void memSet(void *ptr, const int val, const size_t size) { + check_cuda_error(cudaMemset(ptr, val, size)); + } +}; +#endif diff --git a/csrc/int8gemm/cublas/bindings.cpp b/csrc/int8gemm/cublas/bindings.cpp new file mode 100644 index 000000000000..fa35fe686dcc --- /dev/null +++ b/csrc/int8gemm/cublas/bindings.cpp @@ -0,0 +1,215 @@ +/* + gemm methods are adapted from ft +*/ +#include +#include +#include "cublasAlgoMap.h" +#include "cublasINT8MMWrapper.h" +#include "transform_layout.h" + +class I8CUGEMM { +private: + cublasINT8MMWrapper *int8_gemm_wrapper = nullptr; + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +public: + I8CUGEMM(); + ~I8CUGEMM(); + + void linear_a8_w8_o32(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output); + void linear_a8_w8_o32_(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output); + void linear_a8_w8_o8(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output, float alpha); + void linear_a8_w8_o8_(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output, float alpha); + void linear_a8_w8_ofp32(torch::Tensor &input, torch::Tensor &weight, + torch::Tensor &output, float alpha); + void transform_row_to_col32(torch::Tensor &input, torch::Tensor &out); + void transform_col32_to_row(torch::Tensor &input, torch::Tensor &out); + void transform_row_to_ampere(torch::Tensor &input, torch::Tensor &out); + void transform_row_to_turing(torch::Tensor &input, torch::Tensor &out); + +}; +I8CUGEMM::I8CUGEMM() { + // cublasAlgoMap *cublas_algo_map = new cublasAlgoMap("igemm_config.in"); + cublasAlgoMap *cublas_algo_map = new cublasAlgoMap(); + std::mutex *cublas_wrapper_mutex = new std::mutex(); + bool use_ORDER_COL32_2R_4R4 = true; + + // const cudaStream_t stream; + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cublasLtHandle_t cublaslt_handle; +// cudaStreamCreate(&stream); + cublasLtCreate(&cublaslt_handle); + + int8_gemm_wrapper = + new cublasINT8MMWrapper(cublaslt_handle, this->stream, cublas_algo_map, + cublas_wrapper_mutex, use_ORDER_COL32_2R_4R4); +} + +I8CUGEMM::~I8CUGEMM() {} + +void I8CUGEMM::linear_a8_w8_o32(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out // INT32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + int32_t *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr, + weight_ptr); +} + +void 
I8CUGEMM::linear_a8_w8_o32_(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out // INT32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + int32_t *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, input_ptr, + weight_ptr); +} + +void I8CUGEMM::linear_a8_w8_o8(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out, // INT8 + float alpha // FP32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + int8_t *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, + weight_ptr); +} + +void I8CUGEMM::linear_a8_w8_o8_(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out, // INT8 + float alpha // FP32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + int8_t *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm_(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, + weight_ptr); +} + +void I8CUGEMM::linear_a8_w8_ofp32(torch::Tensor &input, // INT8 + torch::Tensor &weight, // INT8 + torch::Tensor &out, // INT8 + float alpha // FP32 +) { + int m = input.size(0); + int n = weight.size(0); + int k = input.size(1); + + // Set data types + int8_t *input_ptr = input.data_ptr(); + int8_t *weight_ptr = weight.data_ptr(); + float *output_ptr = out.data_ptr(); + + int8_gemm_wrapper->Gemm_f(output_ptr, 1, m, n, k, 0, 0, 0, alpha, input_ptr, + weight_ptr); +} + +void I8CUGEMM::transform_row_to_col32(torch::Tensor &input, torch::Tensor &out) { + int m = input.size(0); + int n = input.size(1); + int m_ = out.size(0); + int n_ = out.size(1); + + assert(m == m_); + assert(n == n_); + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int8_t *input_ptr = input.data_ptr(); + int8_t *out_ptr = out.data_ptr(); + invokeRowMajorToCOL32(out_ptr, input_ptr, m, n, this->stream); + // invokeRowMajorToCOL32(out_ptr, input_ptr, m, n, stream); +} + +void I8CUGEMM::transform_col32_to_row(torch::Tensor &input, torch::Tensor &out) { + int m = input.size(0); + int n = input.size(1); + int m_ = out.size(0); + int n_ = out.size(1); + + assert(m == m_); + assert(n == n_); + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int8_t *input_ptr = input.data_ptr(); + int8_t *out_ptr = out.data_ptr(); + invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, this->stream); + // invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, stream); +} + +void I8CUGEMM::transform_row_to_ampere(torch::Tensor &input, torch::Tensor &out) { + int m = input.size(0); + int n = input.size(1); + int m_ = out.size(0); + int n_ = out.size(1); + + assert(m == m_); + assert(n == n_); + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int8_t *input_ptr = input.data_ptr(); + int8_t *out_ptr = out.data_ptr(); + invokeRowMajorToAmpere(out_ptr, input_ptr, m, n, this->stream); + // invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, stream); +} + +void I8CUGEMM::transform_row_to_turing(torch::Tensor &input, torch::Tensor &out) { + int m = input.size(0); + int n = input.size(1); + int m_ = out.size(0); + int n_ = out.size(1); + + assert(m == m_); + 
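+  // (added note, not part of the original patch) like the other layout
+  // transforms above, this conversion only re-orders elements, so the
+  // caller is expected to pre-allocate `out` with the same (m, n) shape
+  // as the row-major input; the surrounding asserts enforce that.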
assert(n == n_); + // const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int8_t *input_ptr = input.data_ptr(); + int8_t *out_ptr = out.data_ptr(); + invokeRowMajorToTuring(out_ptr, input_ptr, m, n, this->stream); + // invokeCOL32ToRowMajor(out_ptr, input_ptr, m, n, stream); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + pybind11::class_(m, "I8CUGEMM") + .def(pybind11::init<>()) + .def("linear_a8_w8_o32", &I8CUGEMM::linear_a8_w8_o32) + .def("linear_a8_w8_o8", &I8CUGEMM::linear_a8_w8_o8) + .def("linear_a8_w8_o8_", &I8CUGEMM::linear_a8_w8_o8_) + .def("linear_a8_w8_o32_", &I8CUGEMM::linear_a8_w8_o32_) + .def("linear_a8_w8_ofp32", &I8CUGEMM::linear_a8_w8_ofp32) + .def("transform_row_to_col32", &I8CUGEMM::transform_row_to_col32) + .def("transform_col32_to_row", &I8CUGEMM::transform_col32_to_row) + .def("transform_row_to_ampere", &I8CUGEMM::transform_row_to_ampere) + .def("transform_row_to_turing", &I8CUGEMM::transform_row_to_turing); +} diff --git a/csrc/int8gemm/cublas/cublasAlgoMap.cc b/csrc/int8gemm/cublas/cublasAlgoMap.cc new file mode 100644 index 000000000000..61e41438c6a8 --- /dev/null +++ b/csrc/int8gemm/cublas/cublasAlgoMap.cc @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cublasAlgoMap.h" + +cublasAlgoMap::cublasAlgoMap(const std::string filename, + const std::string sp_config_filename) + : config_filename_(filename), sp_config_filename_(sp_config_filename) { + loadGemmConfig(); + loadSpGemmConfig(); +} + +cublasAlgoMap::cublasAlgoMap(const cublasAlgoMap &algo_map) + : config_filename_(algo_map.config_filename_), + sp_config_filename_(algo_map.sp_config_filename_), + algo_map_(algo_map.algo_map_), sp_algo_map_(algo_map.sp_algo_map_) {} + +cublasAlgoMap::~cublasAlgoMap() { algo_map_.clear(); } + +void cublasAlgoMap::loadGemmConfig() { + FILE *fd; + fd = fopen(config_filename_.c_str(), "r"); + if (fd == NULL) { + std::cout << "[WARNING] " << config_filename_ + << " is not found; using default GEMM algo" << std::endl; + return; + } + + int batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val; + int batch_size, seq_len, head_num, size_per_head, dataType; + int swizzle, reductionScheme, workspaceSize, stages; + int inner_shapeId, cluster_shapeId, mma_shapeId, cga_shapeId, sche_mode; + float exec_time; + char tmp[1024]; + if (!fgets(tmp, 1024, fd)) { + printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); + exit(-1); + } + while (fscanf(fd, + "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d " +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + "%d %d " +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + "%d %d %d " +#endif + "%f\n", + &batch_size, &seq_len, &head_num, &size_per_head, &dataType, + &batchCount2, &n2, &m2, &k2, &algoId, &customOption, &tile, + &splitK_val, &swizzle, &reductionScheme, &workspaceSize, + &stages, +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + &inner_shapeId, &cluster_shapeId, +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + &mma_shapeId, &cga_shapeId, &sche_mode, +#endif + &exec_time) != EOF) { + if (dataType != FLOAT_DATATYPE && dataType != HALF_DATATYPE && + dataType != BFLOAT16_DATATYPE && dataType != INT8_DATATYPE && + dataType != FP8_DATATYPE) { + printf("[WARNING][readAlgoFromConfig] wrong dataType %d!\n", dataType); + continue; + } + cublasAlgoConfig_t markStr{batchCount2, m2, n2, k2, + static_cast(dataType)}; + // workspaceSize should be zero + if (algo_map_.find(markStr) == algo_map_.end()) { + algo_map_[markStr].algoId = algoId; + algo_map_[markStr].customOption = customOption; + algo_map_[markStr].tile = tile; + algo_map_[markStr].splitK_val = splitK_val; + algo_map_[markStr].swizzle = swizzle; + algo_map_[markStr].reductionScheme = reductionScheme; + algo_map_[markStr].workspaceSize = workspaceSize; + algo_map_[markStr].stages = stages; +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + algo_map_[markStr].inner_shapeId = (uint16_t)inner_shapeId; + algo_map_[markStr].cluster_shapeId = (uint16_t)cluster_shapeId; +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + algo_map_[markStr].mma_shapeId = (uint16_t)mma_shapeId; + algo_map_[markStr].cga_shapeId = (uint16_t)cga_shapeId; + algo_map_[markStr].sche_mode = (uint16_t)sche_mode; +#endif + algo_map_[markStr].exec_time = exec_time; + } + } + fclose(fd); +} + +bool cublasAlgoMap::isExist(const int batch_count, const int m, const int n, + const int k, const CublasDataType data_type) { + cublasAlgoConfig_t mark{batch_count, n, m, k, data_type}; + return algo_map_.find(mark) != algo_map_.end(); +} + +cublasLtMatmulAlgo_info 
cublasAlgoMap::getAlgo(const int batch_count, + const int m, const int n, + const int k, + const CublasDataType data_type) { + cublasAlgoConfig_t mark{batch_count, n, m, k, data_type}; + if (algo_map_.find(mark) != algo_map_.end()) { + return algo_map_[mark]; + } else { + cublasLtMatmulAlgo_info tmp_algo; + tmp_algo.algoId = static_cast(data_type == FLOAT_DATATYPE + ? CUBLAS_GEMM_DEFAULT + : CUBLAS_GEMM_DEFAULT_TENSOR_OP); + tmp_algo.customOption = -1; + tmp_algo.tile = -1; + tmp_algo.splitK_val = -1; + tmp_algo.swizzle = -1; + tmp_algo.reductionScheme = -1; + tmp_algo.workspaceSize = -1; + tmp_algo.stages = -1; + tmp_algo.exec_time = -1.0f; + return tmp_algo; + } +} + +void cublasAlgoMap::loadSpGemmConfig() { + if (sp_config_filename_.empty()) { + return; + } + FILE *fd = fopen(sp_config_filename_.c_str(), "r"); + if (fd == NULL) { + printf("[WARNING] %s is not found; using SPGEMM algo id 0\n", + sp_config_filename_.c_str()); + return; + } + sp_algo_map_.clear(); + int batch_size, seq_len, head_num, size_per_head, data_type; + int batchCount, m, n, k, algoId; + float exec_time; + char tmp[1024]; + if (!fgets(tmp, 1024, fd)) { + printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__); + exit(-1); + } + while (fscanf(fd, "%d %d %d %d %d ### %d %d %d %d %d %f\n", &batch_size, + &seq_len, &head_num, &size_per_head, &data_type, &batchCount, + &m, &n, &k, &algoId, &exec_time) != EOF) { + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batchCount, m, n, k); + std::string markStr(mark); + sp_algo_map_[markStr] = algoId; + } + fclose(fd); +} + +int cublasAlgoMap::getSpAlgo(const int batch_count, const int m, const int n, + const int k) { + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k); + if (sp_algo_map_.find(mark) != sp_algo_map_.end()) { + return sp_algo_map_[mark]; + } else { + // for remove padding, select algo 1 for simplicity + return 0; + } +} + +bool cublasAlgoMap::isUseSparse(const int batch_count, const int m, const int n, + const int k) { + // not available to use cusparselt. + if (m % 8 != 0 || n % 8 != 0 || k % 8 != 0) { + return false; + } + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", batch_count, m, n, k); + if (sp_algo_map_.find(mark) != sp_algo_map_.end()) { + return sp_algo_map_[mark] != -1; + } else { + // no gemm test case, choose sparse according to sparse flag + return true; + } +} diff --git a/csrc/int8gemm/cublas/cublasAlgoMap.h b/csrc/int8gemm/cublas/cublasAlgoMap.h new file mode 100644 index 000000000000..beb9d3a23d90 --- /dev/null +++ b/csrc/int8gemm/cublas/cublasAlgoMap.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cuda_utils.h" +#include +#include +#include +#include +#include +#include +#include + +#pragma once + +#define GEMM_NUM 6 +#define GEMM_CONFIG "gemm_config.in" +#define IGEMM_CONFIG "igemm_config.in" +#define SPGEMM_CONFIG "spgemm_config.in" +#define SPIGEMM_CONFIG "spigemm_config.in" + +typedef struct { + int algoId, customOption, tile, splitK_val; + int swizzle, reductionScheme, workspaceSize; + // only used in cublasLt >= 11.0 + int stages; +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + uint16_t inner_shapeId, cluster_shapeId; +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + uint16_t mma_shapeId, cga_shapeId, sche_mode; +#endif + float exec_time; +} cublasLtMatmulAlgo_info; + +/* Structure to store information about different run trials */ +typedef struct { + cublasLtMatmulAlgo_t algo; + cublasStatus_t status; + float time; + size_t workspaceSize; // actual memory workspace needed + cublasMath_t mathMode; + cublasLtReductionScheme_t reductionScheme; + int customOption; + float wavesCount; +} customMatmulPerf_t; + +struct cublasAlgoConfig_t { + int batch_count; + int m; + int n; + int k; + CublasDataType data_type; + bool operator==(cublasAlgoConfig_t const &config) const { + return (batch_count == config.batch_count) && (m == config.m) && + (n == config.n) && (k == config.k) && + (data_type == config.data_type); + } +}; + +class cublasAlgoConfig_hasher { +public: + std::size_t operator()(cublasAlgoConfig_t const &config) const { + return config.batch_count * 98317ull ^ config.m * 49157ull ^ + config.n * 24593ull ^ config.k * 196613ull ^ + static_cast(config.data_type) * 6151ull; + } +}; + +class cublasAlgoMap { +private: + std::unordered_map + algo_map_; + std::string config_filename_; + std::string sp_config_filename_; + std::map sp_algo_map_; + +public: + cublasAlgoMap(){}; + explicit cublasAlgoMap(const std::string filename, + const std::string sp_config_filename = ""); + cublasAlgoMap(const cublasAlgoMap &map); + ~cublasAlgoMap(); + void loadGemmConfig(); + void loadSpGemmConfig(); + int getSpAlgo(const int batch_count, const int m, const int n, const int k); + bool isUseSparse(const int batch_count, const int m, const int n, + const int k); + + bool isExist(const int batch_count, const int m, const int n, const int k, + const CublasDataType data_type); + + cublasLtMatmulAlgo_info getAlgo(const int batch_count, const int m, + const int n, const int k, + const CublasDataType data_type); +}; diff --git a/csrc/int8gemm/cublas/cublasINT8MMWrapper.cc b/csrc/int8gemm/cublas/cublasINT8MMWrapper.cc new file mode 100644 index 000000000000..23f68c971414 --- /dev/null +++ b/csrc/int8gemm/cublas/cublasINT8MMWrapper.cc @@ -0,0 +1,840 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cublasINT8MMWrapper.h" + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! 
+#endif + +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, + std::mutex *mu, + bool use_ORDER_COL32_2R_4R4) + : cublasMMWrapper(nullptr, cublaslt_handle, stream, cublas_algo_map, mu, + nullptr), + use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} + +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, + std::mutex *mu, + bool use_ORDER_COL32_2R_4R4) + : cublasMMWrapper(cublas_handle, cublaslt_handle, stream, cublas_algo_map, + mu, nullptr), + use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} + +#ifdef SPARSITY_ENABLED +cublasINT8MMWrapper::cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle, + cusparseLtHandle_t cusparselt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, + std::mutex *mu, + bool use_ORDER_COL32_2R_4R4) + : cublasMMWrapper(nullptr, cublaslt_handle, cusparselt_handle, stream, + cublas_algo_map, mu, nullptr), + use_ORDER_COL32_2R_4R4_(use_ORDER_COL32_2R_4R4) {} +#endif + +cublasINT8MMWrapper::~cublasINT8MMWrapper() { mu_ = nullptr; } + +cublasINT8MMWrapper::cublasINT8MMWrapper(const cublasINT8MMWrapper &wrapper) + : +#ifdef SPARSITY_ENABLED + cublasMMWrapper(nullptr, wrapper.cublaslt_handle_, + wrapper.cusparselt_handle_, wrapper.stream_, + wrapper.cublas_algo_map_, wrapper.mu_, + wrapper.allocator_), +#else + cublasMMWrapper(nullptr, wrapper.cublaslt_handle_, wrapper.stream_, + wrapper.cublas_algo_map_, wrapper.mu_, + wrapper.allocator_), +#endif + use_ORDER_COL32_2R_4R4_(wrapper.use_ORDER_COL32_2R_4R4_) { +} + +// for int8 cublasLtMM with algo +// ATransform should be m*n, CUBLASLT_ORDER_COL32 +// kernel should be n*k, CUBLASLT_ORDER_COL4_4R2_8C or +// CUBLASLT_ORDER_COL32_2R_4R4 res is m*n, CUBLASLT_ORDER_COL32 +void cublasINT8MMWrapper::Gemm(int *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + + cublasLtOrder_t order_matrixB; +#if (CUDART_VERSION >= 11000) + if (use_ORDER_COL32_2R_4R4_) { + order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; + } else { + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + } +#else + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; +#endif + + int ldaTransform = 32 * m; + int ldbTransform; + if (use_ORDER_COL32_2R_4R4_) { + ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; + } else { + ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; + } + int ldcTransform = 32 * m; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); + cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, 
ldbTransform); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_matrixB, sizeof(order_matrixB)); + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, m, n, ldcTransform); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + int alphaI = 1; + int betaI = 0; + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + // printf("find algo %s\n", markStr.c_str()); + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + if (use_ORDER_COL32_2R_4R4_) { + algoId = 7; + } else { + algoId = 6; + } + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + if (use_ORDER_COL32_2R_4R4_) { + stages = 15; + } else { + stages = 13; + } + 
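+    // (added note, not part of the original patch) the algoId/tile/stages
+    // values in this fallback branch appear to be the hand-picked COL32
+    // int8 defaults adapted from FasterTransformer; they are used only
+    // when the algo map has no tuned entry for this (batch, m, n, k),
+    // so treat them as a safe starting point rather than a tuned choice.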
cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alphaI, ATransform, + AtransformDesc, kernel, BtransformDesc, &betaI, res, + CtransformDesc, res, CtransformDesc, + (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// Atransform: mxk CUDA_R_8I +// kernel: nxk CUDA_R_8I +// res: mxn CUDA_R_32I +// alpha: CUDA_R_32I should be 1 +// beta: CUDA_R_32I should be 0 +// computeType: CUBLAS_COMPUTE_32I +void cublasINT8MMWrapper::Gemm_(int *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, CUDA_R_32I); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTranspose, sizeof(cublasOperation_t)); + + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); + + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); + + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32I, n, m, n); + + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + int alphaI = 1; + int betaI = 0; + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + // printf("find algo %s\n", markStr.c_str()); + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + 
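+    // (added note, not part of the original patch) the remaining knobs set
+    // below (CTA swizzling, reduction scheme, stages) are copied verbatim
+    // from the tuned igemm_config.in entry; each ConfigSetAttribute call
+    // only records the value, and the resulting combination is validated
+    // when cublasLtMatmul is actually invoked.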
cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32I, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32I, CUDA_R_32I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + stages = 17; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alphaI, kernel, AtransformDesc, + ATransform, BtransformDesc, &betaI, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// for int8 IO cublasLtMM with algo +// ATransform should be m*k CUBLASLT_ORDER_COL32 +// kernel should be n*k CUBLASLT_ORDER_COL4_4R2_8C +// res is m*n CUBLASLT_ORDER_COL32 +void cublasINT8MMWrapper::Gemm(int8_t *res, int batchCount, int m, int n, int k, + int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; + // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE + // cublasLtPointerMode_t pointerMode = + // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cudaDataType_t scaleType = CUDA_R_32F; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + + cublasLtOrder_t order_matrixB; +#if (CUDART_VERSION >= 11000) + if (use_ORDER_COL32_2R_4R4_) { + order_matrixB = CUBLASLT_ORDER_COL32_2R_4R4; + } else { + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; + } +#else + order_matrixB = CUBLASLT_ORDER_COL4_4R2_8C; +#endif + + int ldaTransform = 32 * m; + + int ldbTransform; + if (use_ORDER_COL32_2R_4R4_) { + ldbTransform = 32 * ((n + 32 - 1) / 32) * 32; + } else { + ldbTransform = 32 * ((n + 8 - 1) / 8) * 8; + } + + int ldcTransform = 32 * m; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + 
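+  // (added note, not part of the original patch) with CUDA 11+ the FP32
+  // scale type is declared separately from the INT32 compute type, so the
+  // float `alpha` argument rescales the INT32 accumulator before it is
+  // written back as the INT8 result.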
cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, + &scaleType, sizeof(scaleType)); + // cublasLtMatmulDescSetAttribute(matmulDesc, + // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, + // sizeof(cublasLtPointerMode_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, m, k, ldaTransform); + cublasLtMatrixLayoutSetAttribute(AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, n, k, ldbTransform); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_matrixB, sizeof(order_matrixB)); + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, m, n, ldcTransform); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32)); + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, m, n, k, INT8_DATATYPE)) { + findAlgo = 1; + + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, m, n, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + if (use_ORDER_COL32_2R_4R4_) { + algoId = 7; + } else { + algoId = 6; + } + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + 
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + if (use_ORDER_COL32_2R_4R4_) { + stages = 15; + } else { + stages = 13; + } + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + float beta = 0.0f; + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, + ATransform, BtransformDesc, &beta, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// Atransform: mxk CUDA_R_8I +// kernel: nxk CUDA_R_8I +// res: mxn CUDA_R_8I +// alpha: CUDA_R_32F +// beta: CUDA_R_32F +// computeType: CUBLAS_COMPUTE_32I +void cublasINT8MMWrapper::Gemm_(int8_t *res, int batchCount, int m, int n, + int k, int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, + const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; + // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE + // cublasLtPointerMode_t pointerMode = + // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cudaDataType_t scaleType = CUDA_R_32F; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t computeType = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTranspose, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, + &scaleType, sizeof(scaleType)); + // cublasLtMatmulDescSetAttribute(matmulDesc, + // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, + // sizeof(cublasLtPointerMode_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); + + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); + + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_8I, n, m, n); + + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + 
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, n, m, k, INT8_DATATYPE)) { + findAlgo = 1; + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, n, m, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_8I, CUDA_R_8I, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + stages = 17; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + float beta = 0.0f; + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, + ATransform, BtransformDesc, &beta, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? 
(&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +// Atransform: mxk CUDA_R_8I +// kernel: nxk CUDA_R_8I +// res: mxn CUDA_R_32F +// alpha: CUDA_R_32F +// beta: CUDA_R_32F +// computeType: CUBLAS_COMPUTE_32F +void cublasINT8MMWrapper::Gemm_f(float *res, int batchCount, int m, int n, + int k, int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, + const int8_t *ATransform, + const int8_t *kernel) { + mu_->lock(); + cublasOperation_t opTranspose = CUBLAS_OP_T; + // int8 gemm does not support CUBLAS_POINTER_MODE_DEVICE + // cublasLtPointerMode_t pointerMode = + // CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cudaDataType_t scaleType = CUDA_R_32F; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; +#else + cudaDataType_t computeType = CUDA_R_32F; +#endif + cublasLtMatmulDesc_t matmulDesc; + cublasLtMatrixLayout_t AtransformDesc = NULL; + cublasLtMatrixLayout_t BtransformDesc = NULL; + cublasLtMatrixLayout_t CtransformDesc = NULL; + // cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32; + + // create matmulDesc +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&matmulDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&matmulDesc, computeType); +#endif + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTranspose, sizeof(cublasOperation_t)); + + cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, + &scaleType, sizeof(scaleType)); + // cublasLtMatmulDescSetAttribute(matmulDesc, + // CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, + // sizeof(cublasLtPointerMode_t)); + cublasLtMatrixLayoutCreate(&AtransformDesc, CUDA_R_8I, k, n, k); + cublasLtMatrixLayoutCreate(&BtransformDesc, CUDA_R_8I, k, m, k); + cublasLtMatrixLayoutCreate(&CtransformDesc, CUDA_R_32F, n, m, n); + + if (batchCount > 1) { + cublasLtMatrixLayoutSetAttribute(AtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea)); + cublasLtMatrixLayoutSetAttribute(BtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb)); + cublasLtMatrixLayoutSetAttribute(CtransformDesc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batchCount, sizeof(batchCount)); + cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec)); + } + + // get algo + cublasLtMatmulAlgo_t algo; + int findAlgo = 0; + if (cublas_algo_map_->isExist(batchCount, n, m, k, INT8_DATATYPE)) { + findAlgo = 1; + cublasLtMatmulAlgo_info tmp_info = + cublas_algo_map_->getAlgo(batchCount, n, m, k, INT8_DATATYPE); + + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32F, CUDA_R_32F, tmp_info.algoId, + &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(tmp_info.customOption), + sizeof(tmp_info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tmp_info.tile), + sizeof(tmp_info.tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, 
CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(tmp_info.splitK_val), + sizeof(tmp_info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(tmp_info.swizzle), + sizeof(tmp_info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(tmp_info.reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(tmp_info.stages), + sizeof(tmp_info.stages)); +#endif + } else { + findAlgo = 1; + int algoId; + algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 20; + int splitK_val = 0; + int reductionScheme = 0; + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, CUDA_R_32F, CUDA_R_8I, + CUDA_R_8I, CUDA_R_32F, CUDA_R_32F, algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), sizeof(customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_TILE_ID, + &(tile), sizeof(tile)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), sizeof(splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), sizeof(int)); +#if (CUDART_VERSION >= 11000) + int stages; + stages = 17; + cublasLtMatmulAlgoConfigSetAttribute(&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, + &(stages), sizeof(stages)); +#endif + } + + float beta = 0.0f; + findAlgo = 0; + cublasLtMatmul(cublaslt_handle_, matmulDesc, &alpha, kernel, AtransformDesc, + ATransform, BtransformDesc, &beta, res, CtransformDesc, res, + CtransformDesc, (findAlgo == 1 ? (&algo) : NULL), NULL, 0, + stream_); + + cublasLtMatmulDescDestroy(matmulDesc); + cublasLtMatrixLayoutDestroy(AtransformDesc); + cublasLtMatrixLayoutDestroy(BtransformDesc); + cublasLtMatrixLayoutDestroy(CtransformDesc); + sync_check_cuda_error(); + mu_->unlock(); +} + +bool cublasINT8MMWrapper::getUseOrderCol322R4R4() { + return use_ORDER_COL32_2R_4R4_; +} diff --git a/csrc/int8gemm/cublas/cublasINT8MMWrapper.h b/csrc/int8gemm/cublas/cublasINT8MMWrapper.h new file mode 100644 index 000000000000..ab3de04692af --- /dev/null +++ b/csrc/int8gemm/cublas/cublasINT8MMWrapper.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cublasAlgoMap.h" +#include "cublasMMWrapper.h" +#include "cuda_utils.h" +#include +#include +#include +#include +#include +#include + +#pragma once + +class cublasINT8MMWrapper : public cublasMMWrapper { +private: + bool use_ORDER_COL32_2R_4R4_; + +public: + cublasINT8MMWrapper(cublasLtHandle_t cublaslt_handle_, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, + bool use_ORDER_COL32_2R_4R4); + + cublasINT8MMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, + bool use_ORDER_COL32_2R_4R4); + + ~cublasINT8MMWrapper(); + + cublasINT8MMWrapper(const cublasINT8MMWrapper &wrapper); + + void Gemm(int *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const int8_t *ATransform, + const int8_t *kernel); + + void Gemm_(int *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const int8_t *ATransform, + const int8_t *kernel); + + void Gemm(int8_t *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel); + + void Gemm_(int8_t *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel); + + // w8a8sfp32ofp32 + void Gemm_f(float *res, int batchCount, int m, int n, int k, int64_t stridea, + int64_t strideb, int64_t stridec, const float alpha, + const int8_t *ATransform, const int8_t *kernel); + + bool getUseOrderCol322R4R4(); +}; \ No newline at end of file diff --git a/csrc/int8gemm/cublas/cublasMMWrapper.cc b/csrc/int8gemm/cublas/cublasMMWrapper.cc new file mode 100644 index 000000000000..184304af74b7 --- /dev/null +++ b/csrc/int8gemm/cublas/cublasMMWrapper.cc @@ -0,0 +1,851 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cublasMMWrapper.h" +#include "cuda_utils.h" +#include + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! 
+#endif + +cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, std::mutex *mu, + IAllocator *allocator) + : cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle), + stream_(stream), cublas_algo_map_(cublas_algo_map), mu_(mu), + allocator_(allocator) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (allocator_ != nullptr) { + cublas_workspace_ = + allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false); + } +} + +#ifdef SPARSITY_ENABLED +cublasMMWrapper::cublasMMWrapper(cublasHandle_t cublas_handle, + cublasLtHandle_t cublaslt_handle, + cusparseLtHandle_t cusparselt_handle, + cudaStream_t stream, + cublasAlgoMap *cublas_algo_map, std::mutex *mu, + IAllocator *allocator) + : cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle), + cusparselt_handle_(cusparselt_handle), stream_(stream), + cublas_algo_map_(cublas_algo_map), mu_(mu), allocator_(allocator) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (allocator_ != nullptr) { + cublas_workspace_ = + allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false); + } +} +#endif + +cublasMMWrapper::~cublasMMWrapper() { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + mu_ = nullptr; + if (allocator_ != nullptr) { + allocator_->free((void **)(&cublas_workspace_)); + allocator_ = nullptr; + } +} + +cublasMMWrapper::cublasMMWrapper(const cublasMMWrapper &wrapper) + : cublas_handle_(wrapper.cublas_handle_), + cublaslt_handle_(wrapper.cublaslt_handle_), +#ifdef SPARSITY_ENABLED + cusparselt_handle_(wrapper.cusparselt_handle_), +#endif + stream_(wrapper.stream_), cublas_algo_map_(wrapper.cublas_algo_map_), + mu_(wrapper.mu_), allocator_(wrapper.allocator_) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + if (allocator_ != nullptr) { + cublas_workspace_ = + allocator_->reMalloc(cublas_workspace_, CUBLAS_WORKSPACE_SIZE, false); + } +} + +void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, + const void *alpha, const void *A, + cudaDataType_t Atype, int lda, const void *B, + cudaDataType_t Btype, int ldb, const void *beta, + void *C, cudaDataType_t Ctype, int ldc, + cudaDataType_t computeType, cublasGemmAlgo_t algo) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + mu_->lock(); + check_cuda_error(cublasGemmEx(cublas_handle_, transa, transb, m, n, k, alpha, + A, Atype, lda, B, Btype, ldb, beta, C, Ctype, + ldc, computeType, algo)); + sync_check_cuda_error(); + mu_->unlock(); +} + +void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *A, + const int lda, const void *B, const int ldb, void *C, + const int ldc) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); +} + +void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *A, + const int lda, const void *B, const int ldb, void *C, + const int ldc, float f_alpha, float f_beta) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + half h_alpha = (half)(f_alpha); + half h_beta = (half)(f_beta); + + mu_->lock(); + // TODO: default cublas libs + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + bool using_cublasLt = (Atype_ == CUDA_R_16F) ? true : false; + int batch_count = 1; + // fp32 use cublas as default + // fp16 use cublasLt as default + const void *alpha = is_fp16_computeType ? 
reinterpret_cast(&h_alpha) + : reinterpret_cast(&f_alpha); + const void *beta = is_fp16_computeType ? reinterpret_cast(&h_beta) + : reinterpret_cast(&f_beta); + + int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, + getCublasDataType(Atype_)); + + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(Atype_)); + if (findAlgo) { + if (info.stages != -1) { + using_cublasLt = true; + } else { + using_cublasLt = false; + } + } + + if (using_cublasLt) { + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cudaDataType_t scaleType; +#if (CUDART_VERSION >= 11000) + cublasComputeType_t computeType; +#else + cudaDataType_t computeType; +#endif + + if (is_fp16_computeType) { +#if (CUDART_VERSION >= 11000) + computeType = CUBLAS_COMPUTE_16F; +#else + computeType = CUDA_R_16F; +#endif + scaleType = CUDA_R_16F; + } else { +#if (CUDART_VERSION >= 11000) + computeType = CUBLAS_COMPUTE_32F; +#else + computeType = CUDA_R_32F; +#endif + scaleType = CUDA_R_32F; + } + + // -------------------------------------- + // Create descriptors for the original matrices + cublasLtMatrixLayoutCreate(&Adesc, Atype_, transa == CUBLAS_OP_N ? m : k, + transa == CUBLAS_OP_N ? k : m, lda); + cublasLtMatrixLayoutCreate(&Bdesc, Btype_, transb == CUBLAS_OP_N ? k : n, + transb == CUBLAS_OP_N ? n : k, ldb); + cublasLtMatrixLayoutCreate(&Cdesc, Ctype_, m, n, ldc); +#if (CUDART_VERSION >= 11000) + cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); +#else + cublasLtMatmulDescCreate(&operationDesc, computeType); +#endif + + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &transa, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &transb, sizeof(cublasOperation_t)); + + cublasLtMatmulAlgo_t algo; + void *workSpace = cublas_workspace_; + int workspaceSize = cublas_workspace_ == NULL ? 
0 : CUBLAS_WORKSPACE_SIZE; + if (findAlgo) { + if (info.workspaceSize > workspaceSize) { + findAlgo = 0; + } else { + cublasLtMatmulAlgoInit(cublaslt_handle_, computeType, scaleType, Atype_, + Btype_, Ctype_, Ctype_, info.algoId, &algo); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(info.customOption), + sizeof(info.customOption)); + cublasLtMatmulAlgoConfigSetAttribute(&algo, + CUBLASLT_ALGO_CONFIG_TILE_ID, + &(info.tile), sizeof(info.tile)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(info.splitK_val), + sizeof(info.splitK_val)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(info.swizzle), + sizeof(info.swizzle)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(info.reductionScheme), sizeof(info.reductionScheme)); + +#if (CUDART_VERSION >= 11000) + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(info.stages), + sizeof(info.stages)); +#endif + +#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID, &(info.inner_shapeId), + sizeof(info.inner_shapeId)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, + &(info.cluster_shapeId), sizeof(info.cluster_shapeId)); +#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3) + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_MMA_SHAPE_ID, &(info.mma_shapeId), + sizeof(info.mma_shapeId)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_CGA_SHAPE_ID, &(info.cga_shapeId), + sizeof(info.cga_shapeId)); + cublasLtMatmulAlgoConfigSetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_SCHEDULING_MODE, &(info.sche_mode), + sizeof(info.sche_mode)); +#endif + } + } + + cublasLtMatmul(cublaslt_handle_, operationDesc, alpha, A, Adesc, B, Bdesc, + beta, C, Cdesc, C, Cdesc, (findAlgo == 1 ? 
(&algo) : NULL), + workSpace, workspaceSize, stream_); + + cublasLtMatmulDescDestroy(operationDesc); + cublasLtMatrixLayoutDestroy(Adesc); + cublasLtMatrixLayoutDestroy(Bdesc); + cublasLtMatrixLayoutDestroy(Cdesc); + sync_check_cuda_error(); + } else { + int cublasAlgo = info.algoId; + check_cuda_error(cublasGemmEx(cublas_handle_, transa, transb, m, n, k, + alpha, A, Atype_, lda, B, Btype_, ldb, beta, + C, Ctype_, ldc, computeType_, + static_cast(cublasAlgo))); + sync_check_cuda_error(); + } + mu_->unlock(); +} + +void cublasMMWrapper::setFP32GemmConfig() { + Atype_ = CUDA_R_32F; + Btype_ = CUDA_R_32F; + Ctype_ = CUDA_R_32F; + computeType_ = CUDA_R_32F; +} + +void cublasMMWrapper::setFP16GemmConfig() { + Atype_ = CUDA_R_16F; + Btype_ = CUDA_R_16F; + Ctype_ = CUDA_R_16F; + computeType_ = CUDA_R_32F; +} + +#ifdef ENABLE_BF16 +void cublasMMWrapper::setBF16GemmConfig() { + Atype_ = CUDA_R_16BF; + Btype_ = CUDA_R_16BF; + Ctype_ = CUDA_R_16BF; + computeType_ = CUDA_R_32F; +} +#endif + +void cublasMMWrapper::setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, + cudaDataType_t cType, + cudaDataType_t computeType) { + Atype_ = aType; + Btype_ = bType; + Ctype_ = cType; + computeType_ = computeType; +} + +CublasDataType cublasMMWrapper::getCublasDataType(cudaDataType_t data_type) { + if (data_type == CUDA_R_16F) { + return HALF_DATATYPE; + } else if (data_type == CUDA_R_32F) { + return FLOAT_DATATYPE; + } +#ifdef ENABLE_BF16 + else if (data_type == CUDA_R_16BF) { + return BFLOAT16_DATATYPE; + } +#endif + return FLOAT_DATATYPE; +} + +#if (CUDART_VERSION >= 11000) +// input, weight, output are row-major +// only works for cublas 11.x +void cublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *A, + const int lda, const void *B, const int ldb, + const void *bias, void *C, const int ldc) { + // FT_LOG_DEBUG(__PRETTY_FUNCTION__); + cudaDataType_t Atype, Btype, Ctype; + cublasComputeType_t computeType; + cudaDataType_t scaleType; + float alpha_float = 1.0f; + float beta_float = 0.0f; + half alpha_half = half(1.0f); + half beta_half = half(0.0f); + void *alpha, *beta; + + // int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + if (Atype_ == CUDA_R_32F) { + computeType = CUBLAS_COMPUTE_32F_FAST_TF32; + Atype = CUDA_R_32F; + Btype = CUDA_R_32F; + Ctype = CUDA_R_32F; + scaleType = CUDA_R_32F; + alpha = &alpha_float; + beta = &beta_float; + } else if (Atype_ == CUDA_R_16BF) { + computeType = CUBLAS_COMPUTE_32F_FAST_TF32; + Atype = CUDA_R_16BF; + Btype = CUDA_R_16BF; + Ctype = CUDA_R_16BF; + scaleType = CUDA_R_32F; + alpha = &alpha_float; + beta = &beta_float; + } else { + computeType = CUBLAS_COMPUTE_16F; + Atype = CUDA_R_16F; + Btype = CUDA_R_16F; + Ctype = CUDA_R_16F; + scaleType = CUDA_R_16F; + alpha = &alpha_half; + beta = &beta_half; + } + + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS; + cublasLtMatrixLayoutCreate(&Adesc, Atype, (transa == CUBLAS_OP_N) ? m : k, + (transa == CUBLAS_OP_N) ? k : m, lda); + cublasLtMatrixLayoutCreate(&Bdesc, Btype, (transb == CUBLAS_OP_N) ? k : n, + (transb == CUBLAS_OP_N) ? 
n : k, ldb); + cublasLtMatrixLayoutCreate(&Cdesc, Ctype, m, n, ldc); + + cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &transa, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &transb, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epi, sizeof(cublasLtEpilogue_t)); + cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, + sizeof(const void *)); + check_cuda_error(cublasLtMatmul(cublaslt_handle_, operationDesc, alpha, A, + Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, + NULL, NULL, 0, stream_)); + cublasLtMatrixLayoutDestroy(Adesc); + cublasLtMatrixLayoutDestroy(Bdesc); + cublasLtMatrixLayoutDestroy(Cdesc); + cublasLtMatmulDescDestroy(operationDesc); +} +#endif +void cublasMMWrapper::setStream(cudaStream_t stream) { stream_ = stream; } + +void cublasMMWrapper::stridedBatchedGemm( + cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const int lda, + const int64_t strideA, const void *B, const int ldb, const int64_t strideB, + void *C, const int ldc, const int64_t strideC, const int batch_count, + const float f_alpha, const float f_beta) { + half h_alpha = (half)f_alpha; + half h_beta = (half)f_beta; + + mu_->lock(); + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + const void *alpha = is_fp16_computeType + ? reinterpret_cast(&h_alpha) + : reinterpret_cast(&f_alpha); + const void *beta = is_fp16_computeType + ? reinterpret_cast(&h_beta) + : reinterpret_cast(&f_beta); + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(Atype_)); + + check_cuda_error(cublasGemmStridedBatchedEx( + cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, strideA, + B, Btype_, ldb, strideB, beta, C, Ctype_, ldc, strideC, batch_count, + computeType_, static_cast(info.algoId))); + + mu_->unlock(); +} + +void cublasMMWrapper::stridedBatchedGemm( + cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const float f_alpha, const void *A, + cudaDataType_t AType, const int lda, const int64_t strideA, const void *B, + cudaDataType_t BType, const int ldb, const int64_t strideB, + const float f_beta, void *C, cudaDataType_t CType, const int ldc, + const int64_t strideC, const int batch_count, cudaDataType_t computeType) { + half h_alpha = (half)f_alpha; + half h_beta = (half)f_beta; + + mu_->lock(); + int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0; + const void *alpha = is_fp16_computeType + ? reinterpret_cast(&h_alpha) + : reinterpret_cast(&f_alpha); + const void *beta = is_fp16_computeType + ? 
reinterpret_cast(&h_beta) + : reinterpret_cast(&f_beta); + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(Atype_)); + + check_cuda_error(cublasGemmStridedBatchedEx( + cublas_handle_, transa, transb, m, n, k, alpha, A, AType, lda, strideA, B, + BType, ldb, strideB, beta, C, CType, ldc, strideC, batch_count, + computeType, static_cast(info.algoId))); + + mu_->unlock(); +} + +void cublasMMWrapper::batchedGemm(cublasOperation_t transa, + cublasOperation_t transb, const int m, + const int n, const int k, + const void *const *A, const int lda, + const void *const *B, const int ldb, + void *const *C, const int ldc, + const int batch_count) { + float f_alpha = static_cast(1.0f); + float f_beta = static_cast(0.0f); + + half h_alpha = (half)1.0f; + half h_beta = (half)0.0f; + + mu_->lock(); + int is_fp16_computeType = computeType_ == CUDA_R_16F ? 1 : 0; + const void *alpha = is_fp16_computeType ? reinterpret_cast(&h_alpha) + : reinterpret_cast(&f_alpha); + const void *beta = is_fp16_computeType ? reinterpret_cast(&h_beta) + : reinterpret_cast(&f_beta); + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(Atype_)); + + check_cuda_error(cublasGemmBatchedEx( + cublas_handle_, transa, transb, m, n, k, alpha, A, Atype_, lda, B, Btype_, + ldb, beta, C, Ctype_, ldc, batch_count, computeType_, + static_cast(info.algoId))); + mu_->unlock(); +} + +bool cublasMMWrapper::isFuseBatchGemm(const int batch_count, const int m, + const int k, const int n) { + CublasDataType data_type = getCublasDataType(Atype_); + + if (cublas_algo_map_->isExist(batch_count, m, k, n, data_type) == false || + cublas_algo_map_->isExist(1, m, k, n, data_type) == false) { + return false; + } else { + return cublas_algo_map_->getAlgo(batch_count, m, k, n, data_type) + .exec_time < + 3 * cublas_algo_map_->getAlgo(1, m, k, n, data_type).exec_time; + } +} + +#ifdef SPARSITY_ENABLED +void cublasMMWrapper::SpGemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, + const void *A, const void *B, void *C) { + if (Atype_ != CUDA_R_16F || Btype_ != CUDA_R_16F || Ctype_ != CUDA_R_16F) { + throw std::runtime_error( + "\n[FT][ERROR] sparse GEMM only supports FP16 data type now."); + } + static bool not_printed_fp32_accumulation_warning = true; + if (computeType_ != CUDA_R_16F && not_printed_fp32_accumulation_warning) { + printf("[FT][WARNING] cublasMMWrapper sets to FP32 compute type, " + "but sparse gemm will use FP16 compute type since cusparselt " + "supports FP16 accumulation only.\n"); + not_printed_fp32_accumulation_warning = false; + } + cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOperation_t opA = (transa == CUBLAS_OP_N) + ? CUSPARSE_OPERATION_NON_TRANSPOSE + : CUSPARSE_OPERATION_TRANSPOSE; + cusparseOperation_t opB = (transb == CUBLAS_OP_N) + ? CUSPARSE_OPERATION_NON_TRANSPOSE + : CUSPARSE_OPERATION_TRANSPOSE; + cusparseComputeType compute_type = CUSPARSE_COMPUTE_16F; + cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulAlgSelection_t alg_sel; + cusparseLtMatmulPlan_t plan; + + bool is_rowmajor = (order == CUSPARSE_ORDER_ROW); + bool isA_transposed = (opA != CUSPARSE_OPERATION_NON_TRANSPOSE); + bool isB_transposed = (opB != CUSPARSE_OPERATION_NON_TRANSPOSE); + auto num_A_rows = (isA_transposed) ? k : m; + auto num_A_cols = (isA_transposed) ? m : k; + auto num_B_rows = (isB_transposed) ? n : k; + auto num_B_cols = (isB_transposed) ? 
k : n; + auto num_C_rows = m; + auto num_C_cols = n; + unsigned alignment = 16; + auto lda = (is_rowmajor) ? num_A_cols : num_A_rows; + auto ldb = (is_rowmajor) ? num_B_cols : num_B_rows; + auto ldc = (is_rowmajor) ? num_C_cols : num_C_rows; + float _alpha(1.0f); + float _beta(0.0f); + + char mark[256]; + sprintf(mark, "%d_%d_%d_%d", 1, m, n, k); + if (sp_mat_A_desc_map_.find(mark) != sp_mat_A_desc_map_.end()) { + CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( + &cusparselt_handle_, &matmul, opA, opB, &sp_mat_A_desc_map_[mark], + &sp_mat_B_desc_map_[mark], &sp_mat_C_desc_map_[mark], + &sp_mat_C_desc_map_[mark], compute_type)) + } else { + // initializing MatDesc takes a lot of time + cusparseLtMatDescriptor_t matA, matB, matC; + sp_mat_A_desc_map_[mark] = matA; + sp_mat_B_desc_map_[mark] = matB; + sp_mat_C_desc_map_[mark] = matC; + CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( + &cusparselt_handle_, &sp_mat_A_desc_map_[mark], num_A_rows, num_A_cols, + lda, alignment, Atype_, order, CUSPARSELT_SPARSITY_50_PERCENT)) + CHECK_CUSPARSE(cusparseLtDenseDescriptorInit( + &cusparselt_handle_, &sp_mat_B_desc_map_[mark], num_B_rows, num_B_cols, + ldb, alignment, Btype_, order)) + CHECK_CUSPARSE(cusparseLtDenseDescriptorInit( + &cusparselt_handle_, &sp_mat_C_desc_map_[mark], num_C_rows, num_C_cols, + ldc, alignment, Ctype_, order)) + CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit( + &cusparselt_handle_, &matmul, opA, opB, &sp_mat_A_desc_map_[mark], + &sp_mat_B_desc_map_[mark], &sp_mat_C_desc_map_[mark], + &sp_mat_C_desc_map_[mark], compute_type)) + } + mu_->lock(); + CHECK_CUSPARSE(cusparseLtMatmulAlgSelectionInit( + &cusparselt_handle_, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)) + int alg = cublas_algo_map_->getSpAlgo(1, num_A_rows, num_B_cols, num_A_cols); + CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute( + &cusparselt_handle_, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, + sizeof(alg))) + size_t workspace_size; + CHECK_CUSPARSE(cusparseLtMatmulGetWorkspace(&cusparselt_handle_, &alg_sel, + &workspace_size)) + CHECK_CUSPARSE(cusparseLtMatmulPlanInit(&cusparselt_handle_, &plan, &matmul, + &alg_sel, workspace_size)) + + void *d_workspace = nullptr; + int num_streams = 1; + cudaStream_t streams[1] = {stream_}; + CHECK_CUSPARSE(cusparseLtMatmul(&cusparselt_handle_, &plan, &_alpha, A, B, + &_beta, C, C, d_workspace, streams, + num_streams)) + CHECK_CUSPARSE(cusparseLtMatmulPlanDestroy(&plan)) + sync_check_cuda_error(); + mu_->unlock(); +} + +size_t cublasMMWrapper::getSparseMatrixSize(int m, int k) { + // Get a compressed matrix size of shape (m, k) used in cusparselt. 
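+ // Note (summary of the calls below): with CUSPARSELT_SPARSITY_50_PERCENT the
+ // matrix is stored 2:4 structured-sparse, i.e. the kept values plus index
+ // metadata, so the compressed buffer is smaller than the dense
+ // m * k * sizeof(half); cusparseLtSpMMACompressedSize2 reports the required
+ // byte count for the descriptor initialized here. A caller would typically
+ // allocate a buffer of this size and fill it with compressMatrix() before
+ // passing it as A to SpGemm().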
+ auto Atype_ = CUDA_R_16F; + cusparseOrder_t order = CUSPARSE_ORDER_COL; + unsigned alignment = 16; + int num_A_rows = m; + int num_A_cols = k; + int lda = num_A_rows; + + cusparseLtMatDescriptor_t matA; + CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( + &cusparselt_handle_, &matA, num_A_rows, num_A_cols, lda, alignment, + Atype_, order, CUSPARSELT_SPARSITY_50_PERCENT)); + size_t compressed_size = 0; + CHECK_CUSPARSE(cusparseLtSpMMACompressedSize2(&cusparselt_handle_, &matA, + &compressed_size)); + return compressed_size; +} + +void cublasMMWrapper::compressMatrix(const void *input, void *output, + const int m, const int k) { + cusparseOrder_t order = CUSPARSE_ORDER_COL; + cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE; + cusparseLtMatDescriptor_t matA; + unsigned alignment = 16; + CHECK_CUSPARSE(cusparseLtStructuredDescriptorInit( + &cusparselt_handle_, &matA, m, k, m, alignment, CUDA_R_16F, order, + CUSPARSELT_SPARSITY_50_PERCENT)) + CHECK_CUSPARSE(cusparseLtSpMMACompress2(&cusparselt_handle_, &matA, true, opA, + input, output, stream_)) + sync_check_cuda_error(); +} + +bool cublasMMWrapper::isUseSparse(const int batch_count, const int m, + const int n, const int k) { + return cublas_algo_map_->isUseSparse(batch_count, m, n, k); +} +#endif + +std::pair cublasMMWrapper::findBestAlgo( + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, const void *A, cublasLtMatrixLayout_t Adesc, + const void *B, cublasLtMatrixLayout_t Bdesc, const void *beta, + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, cudaStream_t stream) { +#if (CUBLAS_VERSION) < 11601 + FT_CHECK_WITH_INFO(false, "CUBLAS version too low."); + return {false, cublasLtMatmulAlgo_t{}}; +#else + size_t returnSize; + int32_t pointer_mode; + cublasLtMatmulDescGetAttribute(computeDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode), + &returnSize); + + std::vector heuristics(200); + cublasLtMatmulPreference_t preference; + check_cuda_error(cublasLtMatmulPreferenceCreate(&preference)); + check_cuda_error(cublasLtMatmulPreferenceInit(preference)); + uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size))); +#if (CUBLAS_VERSION) <= 12000 + uint32_t pointer_mode_mask = 0; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, + sizeof(pointer_mode_mask))); +#endif + + int return_count = 0; + auto ret = cublasLtMatmulAlgoGetHeuristic( + lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, + heuristics.size(), heuristics.data(), &return_count); + heuristics.resize(return_count); + + std::map> algo_results; + for (const auto &heuristic : heuristics) { + cublasLtMatmulAlgo_t algo = heuristic.algo; + int32_t algo_id; + cublasLtMatmulAlgoConfigGetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize); + + cudaEvent_t start_event, stop_event; + cudaEventCreate(&start_event); + cudaEventCreate(&stop_event); + + float my_alpha = 1.0f; + float my_beta = 0.0f; + + for (int i = 0; i < 11; i++) { + float duration_ms; + cudaEventRecord(start_event, stream); + check_cuda_error(cublasLtMatmul( + lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, + D, Ddesc, &algo, cublas_workspace_, CUBLAS_WORKSPACE_SIZE, stream)); + cudaEventRecord(stop_event, stream); + 
cudaEventSynchronize(stop_event); + cudaEventElapsedTime(&duration_ms, start_event, stop_event); + + algo_results[algo_id].push_back(duration_ms); + } + std::sort(algo_results[algo_id].begin(), algo_results[algo_id].end()); + } + + cublasLtMatmulHeuristicResult_t result; + float best_time = INFINITY; + for (const auto &heuristic : heuristics) { + cublasLtMatmulAlgo_t algo = heuristic.algo; + int32_t algo_id; + cublasLtMatmulAlgoConfigGetAttribute( + &algo, CUBLASLT_ALGO_CONFIG_ID, &algo_id, sizeof(algo_id), &returnSize); + const auto &results = algo_results[algo_id]; + + if (results.size() > 0 && results[5] < best_time) { + best_time = results[5]; + result = heuristic; + } + } + + return {best_time != INFINITY, result.algo}; +#endif +} + +cublasMMWrapper::MatrixLayout +cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc) { + size_t returnSize; + MatrixLayout m_layout; + + cublasLtMatrixLayoutGetAttribute(Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, + &std::get<0>(m_layout), + sizeof(std::get<0>(m_layout)), &returnSize); + cublasLtMatrixLayoutGetAttribute(Mdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &std::get<1>(m_layout), + sizeof(std::get<1>(m_layout)), &returnSize); + cublasLtMatrixLayoutGetAttribute(Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, + &std::get<2>(m_layout), + sizeof(std::get<2>(m_layout)), &returnSize); + cublasLtMatrixLayoutGetAttribute(Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, + &std::get<3>(m_layout), + sizeof(std::get<3>(m_layout)), &returnSize); + + return m_layout; +} + +cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper( + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, const void *A, cublasLtMatrixLayout_t Adesc, + const void *B, cublasLtMatrixLayout_t Bdesc, const void *beta, + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t *algo, + void *workspace, size_t workspaceSizeInBytes, cudaStream_t stream) { + cache_idx_t cache_idx{computeDesc, + {createMatrixLayout(Adesc), createMatrixLayout(Bdesc), + createMatrixLayout(Cdesc), createMatrixLayout(Ddesc)}}; + + cublasLtMatmulAlgo_t algo_value; + bool found_algo = false; + if (algo == nullptr) { + if (algo_cache.find(cache_idx) == algo_cache.end()) { + auto result = findBestAlgo(lightHandle, computeDesc, alpha, A, Adesc, B, + Bdesc, beta, C, Cdesc, D, Ddesc, stream); + if (result.first) { + algo_cache[cache_idx] = result.second; + algo_value = result.second; + found_algo = true; + } + } else { + algo_value = algo_cache[cache_idx]; + found_algo = true; + } + } + + return cublasLtMatmul(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, + beta, C, Cdesc, D, Ddesc, + found_algo ? &algo_value : algo, workspace, + workspaceSizeInBytes, stream); +} + +void cublasMMWrapper::_Int8Gemm(const int m, const int n, const int k, + const int8_t *A, const int lda, const int8_t *B, + const int ldb, void *C, const int ldc, + const void *alpha, const int mode, + const bool per_column_scaling) { +/* mode: + * - 0: int8 * int8 -> int32 -> int8 + * - 1: int8 * int8 -> int32 -> int32 + */ +#if (CUBLAS_VERSION) < 11601 + FT_CHECK_WITH_INFO(false, "CUBLAS version too low."); +#else + + mu_->lock(); + const auto op_a = CUBLAS_OP_T; + const auto op_b = CUBLAS_OP_N; + const auto dataType = CUDA_R_8I; + const auto resultType = mode == 0 ? CUDA_R_8I : CUDA_R_32I; + const auto computeType = CUBLAS_COMPUTE_32I; + const auto scaleType = mode == 0 ? 
CUDA_R_32F : CUDA_R_32I; + const int batch_count = 1; + const void *beta; + + int findAlgo = cublas_algo_map_->isExist(batch_count, m, n, k, + getCublasDataType(dataType)); + + cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo( + batch_count, m, n, k, getCublasDataType(dataType)); + + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + + // -------------------------------------- + // Create descriptors for the original matrices + check_cuda_error(cublasLtMatrixLayoutCreate(&Adesc, dataType, k, m, lda)); + check_cuda_error(cublasLtMatrixLayoutCreate(&Bdesc, dataType, k, n, ldb)); + check_cuda_error(cublasLtMatrixLayoutCreate(&Cdesc, resultType, m, n, ldc)); + + check_cuda_error( + cublasLtMatmulDescCreate(&operationDesc, computeType, scaleType)); + + auto pointer_mode = CUBLASLT_POINTER_MODE_HOST; + if (mode == 0) { + pointer_mode = per_column_scaling + ? CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST + : CUBLASLT_POINTER_MODE_DEVICE; + } + check_cuda_error( + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &op_a, sizeof(cublasOperation_t))); + check_cuda_error( + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &op_b, sizeof(cublasOperation_t))); + check_cuda_error( + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSC, + &op_b, sizeof(cublasOperation_t))); + check_cuda_error(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, + sizeof(pointer_mode))); + + const int32_t int_one = 1; + const int32_t int_zero = 0; + const float float_zero = 0; + if (mode == 0) { + beta = per_column_scaling ? &float_zero : NULL; + } else { + alpha = &int_one; + beta = &int_zero; + } + + cublasLtMatmulAlgo_t algo; + void *workSpace = cublas_workspace_; + int workspaceSize = cublas_workspace_ == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; + + sync_check_cuda_error(); + auto ret = cublasLtMatmulWrapper(cublaslt_handle_, operationDesc, alpha, A, + Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc, + NULL, workSpace, workspaceSize, stream_); + check_cuda_error(ret); + sync_check_cuda_error(); + + cublasLtMatmulDescDestroy(operationDesc); + cublasLtMatrixLayoutDestroy(Adesc); + cublasLtMatrixLayoutDestroy(Bdesc); + cublasLtMatrixLayoutDestroy(Cdesc); + sync_check_cuda_error(); + mu_->unlock(); +#endif +} + +void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, + const int8_t *A, const int lda, const int8_t *B, + const int ldb, int8_t *C, const int ldc, + const float *alpha, + const bool per_column_scaling) { + return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, alpha, 0, + per_column_scaling); +} + +void cublasMMWrapper::Int8Gemm(const int m, const int n, const int k, + const int8_t *A, const int lda, const int8_t *B, + const int ldb, int32_t *C, const int ldc) { + return _Int8Gemm(m, n, k, A, lda, B, ldb, C, ldc, (float *)nullptr, 1, false); +} diff --git a/csrc/int8gemm/cublas/cublasMMWrapper.h b/csrc/int8gemm/cublas/cublasMMWrapper.h new file mode 100644 index 000000000000..69f229246ea9 --- /dev/null +++ b/csrc/int8gemm/cublas/cublasMMWrapper.h @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "allocator.h" +#include "cublasAlgoMap.h" +#include "cuda_utils.h" +#include +#include +#include +#include +#include +#include + +#pragma once + +class cublasMMWrapper { +protected: + cublasHandle_t cublas_handle_; + cublasLtHandle_t cublaslt_handle_; +#ifdef SPARSITY_ENABLED + cusparseLtHandle_t cusparselt_handle_; + std::map sp_mat_A_desc_map_; + std::map sp_mat_B_desc_map_; + std::map sp_mat_C_desc_map_; +#endif + + cudaDataType_t Atype_; + cudaDataType_t Btype_; + cudaDataType_t Ctype_; + cudaDataType_t computeType_; + + cudaStream_t stream_; + cublasAlgoMap *cublas_algo_map_; + std::mutex *mu_; + + IAllocator *allocator_ = nullptr; + void *cublas_workspace_ = nullptr; + + friend class cublasINT8MMWrapper; + + void _Int8Gemm(const int m, const int n, const int k, const int8_t *A, + const int lda, const int8_t *B, const int ldb, void *C, + const int ldc, const void *alpha, const int mode, + const bool per_column_scaling); + +public: + cublasMMWrapper(cublasHandle_t cublas_handle_, + cublasLtHandle_t cublaslt_handle_, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, IAllocator *allocator); + +#ifdef SPARSITY_ENABLED + cublasMMWrapper(cublasHandle_t cublas_handle_, + cublasLtHandle_t cublaslt_handle_, + cusparseLtHandle_t cusparselt_handle, cudaStream_t stream, + cublasAlgoMap *map, std::mutex *mu, IAllocator *allocator); +#endif + + ~cublasMMWrapper(); + + cublasMMWrapper(const cublasMMWrapper &wrapper); + + virtual void cublasVersionCheck() { return; }; + cublasStatus_t cublasLtMatmulWrapper( + cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, const void *A, cublasLtMatrixLayout_t Adesc, + const void *B, cublasLtMatrixLayout_t Bdesc, const void *beta, + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t *algo, + void *workspace, size_t workspaceSizeInBytes, cudaStream_t stream); + + std::pair + findBestAlgo(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, + const void *alpha, const void *A, cublasLtMatrixLayout_t Adesc, + const void *B, cublasLtMatrixLayout_t Bdesc, const void *beta, + const void *C, cublasLtMatrixLayout_t Cdesc, void *D, + cublasLtMatrixLayout_t Ddesc, cudaStream_t stream); + + using MatrixLayout = + std::tuple; + using cache_idx_t = + std::tuple>; + std::map algo_cache; + + MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *alpha, const void *A, + cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, + int ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, + cudaDataType_t computeType, cublasGemmAlgo_t algo); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const int lda, + const void *B, const int ldb, void *C, const int ldc); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const int lda, + const void *B, const int 
ldb, void *C, const int ldc, float f_alpha, + float f_beta); + + void Int8Gemm(const int m, const int n, const int k, const int8_t *A, + const int lda, const int8_t *B, const int ldb, int8_t *C, + const int ldc, const float *alpha, + const bool per_column_scaling = false); + + void Int8Gemm(const int m, const int n, const int k, const int8_t *A, + const int lda, const int8_t *B, const int ldb, int32_t *C, + const int ldc); + + void setFP32GemmConfig(); + void setFP16GemmConfig(); +#ifdef ENABLE_BF16 + void setBF16GemmConfig(); +#endif + void setStream(cudaStream_t stream); + + void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, + cudaDataType_t cType, cudaDataType_t computeType); + + CublasDataType getCublasDataType(cudaDataType_t data_type); + +#if (CUDART_VERSION >= 11000) + void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const int lda, + const void *B, const int ldb, const void *bias, void *C, + const int ldc); +#endif + + void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *A, + const int lda, const int64_t strideA, const void *B, + const int ldb, const int64_t strideB, void *C, + const int ldc, const int64_t strideC, + const int batchCount, const float f_alpha = 1.0f, + const float f_beta = 0.0f); + + void stridedBatchedGemm( + cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const float f_alpha, const void *A, + cudaDataType_t AType, const int lda, const int64_t strideA, const void *B, + cudaDataType_t BType, const int ldb, const int64_t strideB, + const float f_beta, void *C, cudaDataType_t CType, const int ldc, + const int64_t strideC, const int batch_count, cudaDataType_t computeType); + + void batchedGemm(cublasOperation_t transa, cublasOperation_t transb, + const int m, const int n, const int k, const void *const *A, + const int lda, const void *const *B, const int ldb, + void *const *C, const int ldc, const int batch_count); + + bool isFuseBatchGemm(const int batch_count, const int m, const int k, + const int n); + +#ifdef SPARSITY_ENABLED + void SpGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, + const int n, const int k, const void *A, const void *B, void *C); + + size_t getSparseMatrixSize(int m, int k); + void compressMatrix(const void *input, void *output, const int m, + const int k); + + bool isUseSparse(const int batch_count, const int m, const int n, + const int k); +#endif +}; diff --git a/csrc/int8gemm/cublas/cuda_utils.cc b/csrc/int8gemm/cublas/cuda_utils.cc new file mode 100644 index 000000000000..dc0bc509fd81 --- /dev/null +++ b/csrc/int8gemm/cublas/cuda_utils.cc @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cuda_utils.h" +// #include "cuda_fp8_utils.h" + +/* **************************** debug tools ********************************* */ + +template +void print_to_file(const T *result, const int size, const char *file, + cudaStream_t stream, std::ios::openmode open_mode) { + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + printf("[INFO] file: %s with size %d.\n", file, size); + std::ofstream outFile(file, open_mode); + if (outFile) { + T *tmp = new T[size]; + check_cuda_error(cudaMemcpyAsync(tmp, result, sizeof(T) * size, + cudaMemcpyDeviceToHost, stream)); + for (int i = 0; i < size; ++i) { + float val = (float)(tmp[i]); + outFile << val << std::endl; + } + delete[] tmp; + } else { + throw std::runtime_error(std::string("[FT][ERROR] Cannot open file: ") + + file + "\n"); + } + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +} + +template void print_to_file(const float *result, const int size, + const char *file, cudaStream_t stream, + std::ios::openmode open_mode); +template void print_to_file(const half *result, const int size, + const char *file, cudaStream_t stream, + std::ios::openmode open_mode); + +template +void print_abs_mean(const T *buf, uint size, cudaStream_t stream, + std::string name) { + if (buf == nullptr) { + // FT_LOG_WARNING("It is an nullptr, skip!"); + return; + } + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + T *h_tmp = new T[size]; + cudaMemcpyAsync(h_tmp, buf, sizeof(T) * size, cudaMemcpyDeviceToHost, stream); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + double sum = 0.0f; + uint64_t zero_count = 0; + float max_val = -1e10; + bool find_inf = false; + for (uint i = 0; i < size; i++) { + if (std::isinf((float)(h_tmp[i]))) { + find_inf = true; + continue; + } + sum += abs((double)h_tmp[i]); + if ((float)h_tmp[i] == 0.0f) { + zero_count++; + } + max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i])); + } + printf("[INFO][FT] %20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, " + "find inf: %s", + name.c_str(), size, sum / size, sum, max_val, + find_inf ? "true" : "false"); + std::cout << std::endl; + delete[] h_tmp; + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +} + +template void print_abs_mean(const float *buf, uint size, cudaStream_t stream, + std::string name); +template void print_abs_mean(const half *buf, uint size, cudaStream_t stream, + std::string name); +template void print_abs_mean(const int *buf, uint size, cudaStream_t stream, + std::string name); +template void print_abs_mean(const uint *buf, uint size, cudaStream_t stream, + std::string name); +template void print_abs_mean(const int8_t *buf, uint size, cudaStream_t stream, + std::string name); + +template void print_to_screen(const T *result, const int size) { + if (result == nullptr) { + // FT_LOG_WARNING("It is an nullptr, skip! 
\n"); + return; + } + T *tmp = reinterpret_cast(malloc(sizeof(T) * size)); + check_cuda_error( + cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); + for (int i = 0; i < size; ++i) { + printf("%d, %f\n", i, static_cast(tmp[i])); + } + free(tmp); +} + +template void print_to_screen(const float *result, const int size); +template void print_to_screen(const half *result, const int size); +template void print_to_screen(const int *result, const int size); +template void print_to_screen(const uint *result, const int size); +template void print_to_screen(const bool *result, const int size); + + +template +void printMatrix(T *ptr, int m, int k, int stride, bool is_device_ptr) { + T *tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. + tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error( + cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%7.3f ", (float)tmp[ii * stride + jj]); + } else { + printf("%7d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +template void printMatrix(float *ptr, int m, int k, int stride, + bool is_device_ptr); +template void printMatrix(half *ptr, int m, int k, int stride, + bool is_device_ptr); + +void printMatrix(unsigned long long *ptr, int m, int k, int stride, + bool is_device_ptr) { + typedef unsigned long long T; + T *tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. + tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error( + cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4llu ", tmp[ii * stride + jj]); + } else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +void printMatrix(int *ptr, int m, int k, int stride, bool is_device_ptr) { + typedef int T; + T *tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. + tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error( + cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4d ", tmp[ii * stride + jj]); + } else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +void printMatrix(size_t *ptr, int m, int k, int stride, bool is_device_ptr) { + typedef size_t T; + T *tmp; + if (is_device_ptr) { + // k < stride ; stride = col-dimension. 
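+ // device pointer: stage a blocking copy back to host (cudaMemcpy followed by
+ // cudaDeviceSynchronize) before formatting; element (ii, jj) is then read as
+ // tmp[ii * stride + jj] below.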
+ tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error( + cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } else { + tmp = ptr; + } + + for (int ii = -1; ii < m; ++ii) { + if (ii >= 0) { + printf("%02d ", ii); + } else { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) { + if (ii >= 0) { + printf("%4ld ", tmp[ii * stride + jj]); + } else { + printf("%4d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) { + free(tmp); + } +} + +template void check_max_val(const T *result, const int size) { + T *tmp = new T[size]; + cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost); + float max_val = -100000; + for (int i = 0; i < size; i++) { + float val = static_cast(tmp[i]); + if (val > max_val) { + max_val = val; + } + } + delete tmp; + printf("[INFO][CUDA] addr %p max val: %f \n", result, max_val); +} + +template void check_max_val(const float *result, const int size); +template void check_max_val(const half *result, const int size); + +template void check_abs_mean_val(const T *result, const int size) { + T *tmp = new T[size]; + cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost); + float sum = 0.0f; + for (int i = 0; i < size; i++) { + sum += abs(static_cast(tmp[i])); + } + delete tmp; + printf("[INFO][CUDA] addr %p abs mean val: %f \n", result, sum / size); +} + +template void check_abs_mean_val(const float *result, const int size); +template void check_abs_mean_val(const half *result, const int size); + +/* ***************************** common utils ****************************** */ + +cudaError_t getSetDevice(int i_device, int *o_device) { + int current_dev_id = 0; + cudaError_t err = cudaSuccess; + + if (o_device != NULL) { + err = cudaGetDevice(¤t_dev_id); + if (err != cudaSuccess) { + return err; + } + if (current_dev_id == i_device) { + *o_device = i_device; + } else { + err = cudaSetDevice(i_device); + if (err != cudaSuccess) { + return err; + } + *o_device = current_dev_id; + } + } else { + err = cudaSetDevice(i_device); + if (err != cudaSuccess) { + return err; + } + } + + return cudaSuccess; +} + +/* ************************** end of common utils ************************** */ diff --git a/csrc/int8gemm/cublas/cuda_utils.h b/csrc/int8gemm/cublas/cuda_utils.h new file mode 100644 index 000000000000..e85d38bf53d4 --- /dev/null +++ b/csrc/int8gemm/cublas/cuda_utils.h @@ -0,0 +1,436 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#ifdef SPARSITY_ENABLED +#include +#endif + +#define MAX_CONFIG_NUM 20 +#define COL32_ 32 +// workspace for cublas gemm : 32MB +#define CUBLAS_WORKSPACE_SIZE 33554432 + +typedef struct __align__(4) { + half x, y, z, w; +} +half4; + +/* **************************** type definition ***************************** */ + +enum CublasDataType { + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, + BFLOAT16_DATATYPE = 2, + INT8_DATATYPE = 3, + FP8_DATATYPE = 4 +}; + +enum FtCudaDataType { FP32 = 0, FP16 = 1, BF16 = 2, INT8 = 3, FP8 = 4 }; + +enum class OperationType { FP32, FP16, BF16, INT8, FP8 }; + +/* **************************** debug tools ********************************* */ +static const char *_cudaGetErrorEnum(cudaError_t error) { + return cudaGetErrorString(error); +} + +static const char *_cudaGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; +} + +template +void check(T result, char const *const func, const char *const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) +#define check_cuda_error_2(val, file, line) check((val), #val, file, line) + +inline void syncAndCheck(const char *const file, int const line) { + // When FT_DEBUG_LEVEL=DEBUG, must check error + static char *level_name = std::getenv("FT_DEBUG_LEVEL"); + if (level_name != nullptr) { + static std::string level = std::string(level_name); + if (level == "DEBUG") { + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + if (result) { + throw std::runtime_error( + std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } + // FT_LOG_DEBUG(fmtstr("run syncAndCheck at %s:%d", file, line)); + } + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } +#endif +} + +#define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__) + +#define checkCUDNN(expression) \ + { \ + cudnnStatus_t status = (expression); \ + if (status != CUDNN_STATUS_SUCCESS) { \ + std::cerr << "Error on file " << __FILE__ << " line " << __LINE__ \ + << ": " << cudnnGetErrorString(status) << std::endl; \ + std::exit(EXIT_FAILURE); \ + } \ + } + +template +void print_to_file(const T *result, const int size, const char *file, + cudaStream_t stream = 0, + 
std::ios::openmode open_mode = std::ios::out); + +template +void print_abs_mean(const T *buf, uint size, cudaStream_t stream, + std::string name = ""); + +template void print_to_screen(const T *result, const int size); + +template +void printMatrix(T *ptr, int m, int k, int stride, bool is_device_ptr); + +void printMatrix(unsigned long long *ptr, int m, int k, int stride, + bool is_device_ptr); +void printMatrix(int *ptr, int m, int k, int stride, bool is_device_ptr); +void printMatrix(size_t *ptr, int m, int k, int stride, bool is_device_ptr); + +template void check_max_val(const T *result, const int size); + +template void check_abs_mean_val(const T *result, const int size); + +#define PRINT_FUNC_NAME_() \ + do { \ + std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \ + } while (0) + +[[noreturn]] inline void throwRuntimeError(const char *const file, + int const line, + std::string const &info = "") { + throw std::runtime_error(std::string("[FT][ERROR] ") + info + + " Assertion fail: " + file + ":" + + std::to_string(line) + " \n"); +} + +inline void myAssert(bool result, const char *const file, int const line, + std::string const &info = "") { + if (!result) { + throwRuntimeError(file, line, info); + } +} + +#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) +#define FT_CHECK_WITH_INFO(val, info) \ + do { \ + bool is_valid_val = (val); \ + if (!is_valid_val) { \ + fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ + } \ + } while (0) + +#define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info) + +#ifdef SPARSITY_ENABLED +#define CHECK_CUSPARSE(func) \ + { \ + cusparseStatus_t status = (func); \ + if (status != CUSPARSE_STATUS_SUCCESS) { \ + throw std::runtime_error( \ + std::string("[FT][ERROR] CUSPARSE API failed at line ") + \ + std::to_string(__LINE__) + " in file " + __FILE__ + ": " + \ + cusparseGetErrorString(status) + " " + std::to_string(status)); \ + } \ + } +#endif + +/*************Time Handling**************/ +class CudaTimer { +private: + cudaEvent_t event_start_; + cudaEvent_t event_stop_; + cudaStream_t stream_; + +public: + explicit CudaTimer(cudaStream_t stream = 0) { stream_ = stream; } + void start() { + check_cuda_error(cudaEventCreate(&event_start_)); + check_cuda_error(cudaEventCreate(&event_stop_)); + check_cuda_error(cudaEventRecord(event_start_, stream_)); + } + float stop() { + float time; + check_cuda_error(cudaEventRecord(event_stop_, stream_)); + check_cuda_error(cudaEventSynchronize(event_stop_)); + check_cuda_error(cudaEventElapsedTime(&time, event_start_, event_stop_)); + check_cuda_error(cudaEventDestroy(event_start_)); + check_cuda_error(cudaEventDestroy(event_stop_)); + return time; + } + ~CudaTimer() {} +}; + +static double diffTime(timeval start, timeval end) { + return (end.tv_sec - start.tv_sec) * 1000 + + (end.tv_usec - start.tv_usec) * 0.001; +} + +/* ***************************** common utils ****************************** */ + +inline void print_mem_usage(std::string time = "after allocation") { + size_t free_bytes, total_bytes; + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); + float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; + float total = static_cast(total_bytes) / 1024.0 / 1024.0 / 1024.0; + float used = total - free; + printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n", + time.c_str(), free, total, used); +} + +inline int getSMVersion() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + 
check_cuda_error(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline int getMaxSharedMemoryPerBlock() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int max_shared_memory_size = 0; + check_cuda_error(cudaDeviceGetAttribute( + &max_shared_memory_size, cudaDevAttrMaxSharedMemoryPerBlock, device)); + return max_shared_memory_size; +} + +inline std::string getDeviceName() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + cudaDeviceProp props; + check_cuda_error(cudaGetDeviceProperties(&props, device)); + return std::string(props.name); +} + +inline int div_up(int a, int n) { return (a + n - 1) / n; } + +cudaError_t getSetDevice(int i_device, int *o_device = NULL); + +inline int getDevice() { + int current_dev_id = 0; + check_cuda_error(cudaGetDevice(¤t_dev_id)); + return current_dev_id; +} + +inline int getDeviceCount() { + int count = 0; + check_cuda_error(cudaGetDeviceCount(&count)); + return count; +} + +template CublasDataType getCublasDataType() { + if (std::is_same::value) { + return HALF_DATATYPE; + } + // #ifdef ENABLE_BF16 + // else if (std::is_same::value) { + // return BFLOAT16_DATATYPE; + // } + // #endif + else if (std::is_same::value) { + return FLOAT_DATATYPE; + } else { + FT_CHECK(false); + return FLOAT_DATATYPE; + } +} + +template cudaDataType_t getCudaDataType() { + if (std::is_same::value) { + return CUDA_R_16F; + } + // #ifdef ENABLE_BF16 + // else if (std::is_same::value) { + // return CUDA_R_16BF; + // } + // #endif + else if (std::is_same::value) { + return CUDA_R_32F; + } else { + FT_CHECK(false); + return CUDA_R_32F; + } +} + +template struct getTypeFromCudaDataType { + using Type = float; +}; + +template <> struct getTypeFromCudaDataType { + using Type = half; +}; + + +// clang-format off +template struct packed_type; +template <> struct packed_type { using type = float; }; // we don't need to pack float by default +template <> struct packed_type { using type = half2; }; + + +template struct num_elems; +template <> struct num_elems { static constexpr int value = 1; }; +template <> struct num_elems { static constexpr int value = 2; }; +template <> struct num_elems { static constexpr int value = 4; }; +template <> struct num_elems { static constexpr int value = 1; }; +template <> struct num_elems { static constexpr int value = 2; }; + + +template struct packed_as; +template struct packed_as { using type = T; }; +template<> struct packed_as { using type = half2; }; +template<> struct packed_as { using type = float2; }; +template<> struct packed_as { using type = int16_t; }; +template<> struct packed_as { using type = int2; }; +template<> struct packed_as { using type = half; }; + + +inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } +inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } +// clang-format on + +template +void compareTwoTensor(const T1 *pred, const T2 *ref, const int size, + const int print_size = 0, + const std::string filename = "") { + T1 *h_pred = new T1[size]; + T2 *h_ref = new T2[size]; + check_cuda_error( + cudaMemcpy(h_pred, pred, size * sizeof(T1), cudaMemcpyDeviceToHost)); + check_cuda_error( + cudaMemcpy(h_ref, ref, size * sizeof(T2), cudaMemcpyDeviceToHost)); + + FILE *fd = nullptr; + if (filename != "") { + fd = fopen(filename.c_str(), "w"); + 
fprintf(fd, "| %10s | %10s | %10s | %10s | \n", "pred", "ref", "abs_diff", + "rel_diff(%)"); + } + + if (print_size > 0) { + // FT_LOG_INFO(" id | pred | ref |abs diff | rel diff (%) |"); + } + float mean_abs_diff = 0.0f; + float mean_rel_diff = 0.0f; + int count = 0; + for (int i = 0; i < size; i++) { + if (i < print_size) { + // FT_LOG_INFO("%4d | % 6.4f | % 6.4f | % 6.4f | % 7.4f |", + // i, + // (float)h_pred[i], + // (float)h_ref[i], + // abs((float)h_pred[i] - (float)h_ref[i]), + // abs((float)h_pred[i] - (float)h_ref[i]) / + // (abs((float)h_ref[i]) + 1e-6f) * 100.f); + } + if ((float)h_pred[i] == 0) { + continue; + } + count += 1; + mean_abs_diff += abs((float)h_pred[i] - (float)h_ref[i]); + mean_rel_diff += abs((float)h_pred[i] - (float)h_ref[i]) / + (abs((float)h_ref[i]) + 1e-6f) * 100.f; + + if (fd != nullptr) { + fprintf(fd, "| %10.5f | %10.5f | %10.5f | %11.5f |\n", (float)h_pred[i], + (float)h_ref[i], abs((float)h_pred[i] - (float)h_ref[i]), + abs((float)h_pred[i] - (float)h_ref[i]) / + (abs((float)h_ref[i]) + 1e-6f) * 100.f); + } + } + mean_abs_diff = mean_abs_diff / (float)count; + mean_rel_diff = mean_rel_diff / (float)count; + // FT_LOG_INFO("mean_abs_diff: % 6.4f, mean_rel_diff: % 6.4f (%%)", + // mean_abs_diff, mean_rel_diff); + + if (fd != nullptr) { + fprintf(fd, "mean_abs_diff: % 6.4f, mean_rel_diff: % 6.4f (%%)", + mean_abs_diff, mean_rel_diff); + fclose(fd); + } + delete[] h_pred; + delete[] h_ref; +} + +/* ************************** end of common utils ************************** */ diff --git a/csrc/int8gemm/cublas/int8_utils.cuh b/csrc/int8gemm/cublas/int8_utils.cuh new file mode 100644 index 000000000000..55b9c4f24de2 --- /dev/null +++ b/csrc/int8gemm/cublas/int8_utils.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include +#include + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} + +static inline __device__ uint32_t float4_to_char4(float x, + float y, + float z, + float w) { + uint32_t dst; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720 + uint32_t a; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x)); + uint32_t b; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y)); + uint32_t c; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z)); + uint32_t d; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w)); + + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c)); + asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a)); +#else + char4 tmp; + tmp.x = x; + tmp.y = y; + tmp.z = z; + tmp.w = w; + dst = reinterpret_cast(tmp); +#endif + return dst; +} \ No newline at end of file diff --git a/csrc/int8gemm/cublas/transform_layout.cu b/csrc/int8gemm/cublas/transform_layout.cu new file mode 100644 index 000000000000..7bc6cfa6b95e --- /dev/null +++ b/csrc/int8gemm/cublas/transform_layout.cu @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// #include "src/fastertransformer/kernels/layout_transformer_int8_kernels.h" +#include "transform_layout.h" +#include + +// transform row-major to COL32 +// input matrix is (m, n) row-major +// output matrix is (m, n) COL32 +// n should be a multiple of 32 +// grid((n+31)/32, (m+31)/32) +// block(8, 32) +__global__ void rowMajorToCOL32_kernel(char4 *dst, const char4 *src, const int m, const int n) { + + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + + // COL32_col = n_id >> 5 ; COL32_row = (m_id << 5) + (n_id & 31); + // COL32_idx = (COL32_col << 5) * m + COL32_row = (n_id & 0xffffffe0)*m + + // (m_id << 5) + (n_id & 31) + dst[((n_id & 0xffffffe0) * m + (m_id << 5) + (n_id & 31)) >> 2] = + __ldg(src + ((m_id * n + n_id) >> 2)); + } +} + +__global__ void col32ToRowMajor_kernel(char4 *dst, const char4 *src, + const int m, const int n) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + int idx = m_id * n + n_id; + dst[(((idx >> 5) % m) * n + (((idx >> 5) / m) << 5) + (idx & 31)) >> 2] = + __ldg(src + (idx >> 2)); + } +} + +void invokeRowMajorToCOL32(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream) { + assert(n % 32 == 0); + rowMajorToCOL32_kernel<<>>((char4 *)dst, (const char4 *)src, m, n); +} + +void invokeCOL32ToRowMajor(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream) { + assert(n % 32 == 0); + col32ToRowMajor_kernel<<>>((char4 *)dst, (const char4 *)src, m, n); +} + +__global__ void rowMajorToAmpere_kernel(char4 *dst, const char4 *src, + const int m, const int n) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + int new_col = n_id >> 5; + int row_in_tile = m_id & 31; + int col_in_tile = n_id & 31; + int new_row = // CUBLASLT_ORDER_COL32_2R_4R4 + (((m_id >> 5) << 10) + + //(((row%8)/2*4+row/8)*2+row%2)*32+col + (((((((row_in_tile & 7) >> 1) << 2) + (row_in_tile >> 3)) << 1) + + (row_in_tile & 1)) + << 5) + + col_in_tile); + int idx = m_id * n + n_id; + dst[(new_col * (m << 5) + new_row) >> 2] = __ldg(src + (idx >> 2)); + } +} + +void invokeRowMajorToAmpere(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream) { + assert(n % 32 == 0); + rowMajorToAmpere_kernel<<>>((char4 *)dst, (const char4 *)src, m, n); +} + +__global__ void rowMajorToTuring_kernel(char4 *dst, const char4 *src, + const int m, const int n) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + int new_col = n_id >> 5; + int new_row = // CUBLASLT_ORDER_COL4_4R2_8C + ////m_id/8 is the number of tile of (8 rows 32 columns) -- + /// column-major /m_id%2 is even row, otherwise odd row + ////n_id%COL32_/8 is the number tile of (8 rows 8 columns) + (((((m_id >> 3) << 3) + ((m_id & 1) << 2) + ((n_id & 31) >> 3)) << 5) + + ////n_id%8 >= 4 is the right half of (8 rows 8 columns) tile + ////(m_id%8/2) is (the row id of alternating 4 rows) - 1 + (((((n_id & 7) >= 4) ? 
4 : 0) + ((m_id & 7) >> 1)) << 2) + + ////n_id%4 is the id of 4 cols + (n_id & 3)); + int idx = m_id * n + n_id; + dst[(new_col * (m << 5) + new_row) >> 2] = __ldg(src + (idx >> 2)); + } +} + +void invokeRowMajorToTuring(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream) { + assert(n % 32 == 0); + rowMajorToTuring_kernel<<>>((char4 *)dst, (const char4 *)src, m, n); +} diff --git a/csrc/int8gemm/cublas/transform_layout.h b/csrc/int8gemm/cublas/transform_layout.h new file mode 100644 index 000000000000..a695923b08f8 --- /dev/null +++ b/csrc/int8gemm/cublas/transform_layout.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "int8_utils.cuh" +#include +#include +#include + +void invokeRowMajorToCOL32(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream); +void invokeCOL32ToRowMajor(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream); +void invokeRowMajorToAmpere(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream); +void invokeRowMajorToTuring(int8_t *dst, const int8_t *src, const int m, + const int n, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/layernorm.cpp b/csrc/layernorm.cpp index 749ca5f92154..c9f917e1aab8 100644 --- a/csrc/layernorm.cpp +++ b/csrc/layernorm.cpp @@ -1,14 +1,28 @@ #include -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); +void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, + float epsilon); + +void invoke_rms_norm_quant(torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &gamma, // [hidden_size] + float epsilon); + +void invoke_dequant_add_residual_rms_norm_quant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &residual, // [num_tokens, hidden_size] + torch::Tensor &gamma, // [hidden_size] + float epsilon, float scale); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + m.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); + m.def("invoke_rms_norm_quant", &invoke_rms_norm_quant, + "Apply Root Mean Square (RMS) Normalization to the input tensor and " + "quant output."); + m.def("invoke_dequant_add_residual_rms_norm_quant", + &invoke_dequant_add_residual_rms_norm_quant, + "Add the dequanted result and residual, then use RMS norm and quant " + "output."); } diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f932b9e2d615..1b5bf933edab 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,25 +1,25 @@ -#include #include +#include #include "dispatch_utils.h" +#include "quant_utils.cuh" #include 
"reduction_utils.cuh" namespace vllm { // TODO(woosuk): Further optimize this kernel. -template -__global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [num_tokens, hidden_size] - const scalar_t* __restrict__ input, // [num_tokens, hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ void +rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size] + const scalar_t *__restrict__ input, // [num_tokens, hidden_size] + const scalar_t *__restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, + const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float) input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -29,34 +29,129 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; + } +} + +template +__global__ void rms_norm_quant_kernel(const T *__restrict__ input, + const T *__restrict__ gamma, + int8_t *__restrict__ output, + const float layernorm_eps, int m, int n) { + // layernorm module in the T5 style No bias and no subtraction of mean. + const int tid = threadIdx.x; + + __shared__ float s_variance; + float variance = 0.0f; + + float local_var_sum = 0.0f; + for (int i = tid; i < n; i += blockDim.x) { + // float diff = (float)(ldg(&input[blockIdx.x * n + i])); + float diff = (float)(input[blockIdx.x * n + i]); + local_var_sum += diff * diff; + } + variance = blockReduceSum(local_var_sum); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / (float)n + layernorm_eps); + } + __syncthreads(); + + for (int i = tid; i < n; i += blockDim.x) { + output[blockIdx.x * n + i] = float_to_int8_rn( + (((float)input[blockIdx.x * n + i]) * s_variance) * (float)(gamma[i])); + } +} + +template +__global__ void dequant_add_residual_rms_norm_quant_kernel( + const int32_t *__restrict__ input, T *__restrict__ residual, + int8_t *__restrict__ output, const T *__restrict__ gamma, + const float layernorm_eps, const float scale, int m, int n) { + // layernorm module in the T5 style No bias and no subtraction of mean. 
+ const int tid = threadIdx.x; + + __shared__ float s_variance; + float variance = 0.0f; + + float local_var_sum = 0.0f; + for (int i = tid; i < n; i += blockDim.x) { + float diff = ((((float)input[blockIdx.x * n + i]) * scale) + + (float)residual[blockIdx.x * n + i]); + residual[blockIdx.x * n + i] = (T)diff; + local_var_sum += diff * diff; + } + variance = blockReduceSum(local_var_sum); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / (float)n + layernorm_eps); + } + __syncthreads(); + + for (int i = tid; i < n; i += blockDim.x) { + output[blockIdx.x * n + i] = + float_to_int8_rn((((float)(residual[blockIdx.x * n + i])) * s_variance) * (float)(gamma[i])); } } } // namespace vllm -void rms_norm( - torch::Tensor& out, // [num_tokens, hidden_size] - torch::Tensor& input, // [num_tokens, hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +void rms_norm(torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &weight, // [hidden_size] + float epsilon) { int num_tokens = input.size(0); int hidden_size = input.size(1); dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); +} + +void invoke_rms_norm_quant(torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &gamma, // [hidden_size] + float epsilon) { + int m = input.size(0); + int n = input.size(1); + dim3 grid(m); + dim3 block(min(n, 1024)); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_quant_kernel", [&] { + vllm::rms_norm_quant_kernel<<>>( + input.data_ptr(), gamma.data_ptr(), + out.data_ptr(), epsilon, m, n); + }); +} + +void invoke_dequant_add_residual_rms_norm_quant( + torch::Tensor &out, // [num_tokens, hidden_size] + torch::Tensor &input, // [num_tokens, hidden_size] + torch::Tensor &residual, // [num_tokens, hidden_size] + torch::Tensor &gamma, // [hidden_size] + float epsilon, float scale) { + int m = input.size(0); + int n = input.size(1); + dim3 grid(m); + dim3 block(min(n, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "rms_norm_kernel", - [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size); - }); + residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel", + [&] { + vllm::dequant_add_residual_rms_norm_quant_kernel + <<>>( + input.data_ptr(), residual.data_ptr(), + out.data_ptr(), gamma.data_ptr(), epsilon, + scale, m, n); + }); } diff --git a/csrc/pos_encoding.cpp b/csrc/pos_encoding.cpp index eee0cf0d0fa0..60930025c988 100644 --- a/csrc/pos_encoding.cpp +++ b/csrc/pos_encoding.cpp @@ -1,16 +1,21 @@ #include -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox); +void invoke_dequant_rotary_embedding( + torch::Tensor &positions, // [num_tokens] + torch::Tensor &query, // [num_tokens, 
num_heads * head_size] + torch::Tensor &query_out, // [num_tokens, num_heads * head_size] + torch::Tensor &key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor &key_out, // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor &cos_sin_cache, // [max_position, rot_dim] + const float query_scale, const float key_scale, bool is_neox); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + m.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + m.def("invoke_dequant_rotary_embedding", &invoke_dequant_rotary_embedding, + "Dequant the input and apply rotary embedding."); } diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index b4351ee0d794..4a36f1de4747 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -35,6 +35,38 @@ inline __device__ void apply_rotary_embedding( arr[y_index] = y * cos + x * sin; } +template +inline __device__ void apply_dequant_rotary_embedding( + const int32_t* __restrict__ arr, + scalar_t* __restrict__ arr_out, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim, + const float scale) +{ + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = __ldg(cos_ptr + x_index); + sin = __ldg(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = __ldg(cos_ptr + x_index / 2); + sin = __ldg(sin_ptr + x_index / 2); + } + + const scalar_t x = (scalar_t)((float)arr[x_index] * scale); + const scalar_t y = (scalar_t)((float)arr[y_index] * scale); + arr_out[x_index] = x * cos - y * sin; + arr_out[y_index] = y * cos + x * sin; +} + template __global__ void rotary_embedding_kernel( const int64_t* __restrict__ positions, // [num_tokens] @@ -75,6 +107,54 @@ __global__ void rotary_embedding_kernel( } } +template +__global__ void dequant_rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [num_tokens] + const int32_t* __restrict__ query, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ query_out, // [num_tokens, num_heads, head_size] + const int32_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] + scalar_t* __restrict__ key_out, // [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int query_stride, + const int query_out_stride, + const int key_stride, + const int key_out_stride, + const int num_heads, + const int num_kv_heads, + const int head_size, + const float query_scale, + const float key_scale) { + // Each thread block is responsible for one token. 
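  // Same structure as rotary_embedding_kernel above, except that query/key arrive as
  // int32 accumulators: each element is first dequantized as x = (float)q * query_scale
  // (key_scale for keys), then the usual rotation x' = x*cos - y*sin, y' = y*cos + x*sin
  // is written to the separate floating-point output buffers query_out / key_out.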
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * query_stride + head_idx * head_size; + const int token_out_head = token_idx * query_out_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_dequant_rotary_embedding(query + token_head, query_out + token_out_head, cos_ptr, + sin_ptr, rot_offset, embed_dim, query_scale); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * key_stride + head_idx * head_size; + const int token_out_head = token_idx * key_out_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_dequant_rotary_embedding(key + token_head, key_out + token_out_head, cos_ptr, + sin_ptr, rot_offset, embed_dim, key_scale); + } +} + } // namespace vllm void rotary_embedding( @@ -125,3 +205,70 @@ void rotary_embedding( } }); } + + +void invoke_dequant_rotary_embedding( + torch::Tensor& positions, // [num_tokens] + torch::Tensor& query, // [num_tokens, num_heads * head_size] + torch::Tensor& query_out, // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor& key_out, // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + const float query_scale, + const float key_scale, + bool is_neox) { + int num_tokens = query.size(0); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(1) / head_size; + int num_kv_heads = key.size(1) / head_size; + int query_stride = query.stride(0); + int key_stride = key.stride(0); + int query_out_stride = query_out.stride(0); + int key_out_stride = key_out.stride(0); + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + query_out.scalar_type(), + "dequant_rotary_embedding_kernel", + [&] { + if (is_neox) { + vllm::dequant_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + query_out.data_ptr(), + key.data_ptr(), + key_out.data_ptr(), + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + query_out_stride, + key_stride, + key_out_stride, + num_heads, + num_kv_heads, + head_size, + query_scale, + key_scale); + } else { + vllm::dequant_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + query_out.data_ptr(), + key.data_ptr(), + key_out.data_ptr(), + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + query_out_stride, + key_stride, + key_out_stride, + num_heads, + num_kv_heads, + head_size, + query_scale, + key_scale); + } + }); +} diff --git a/csrc/quant_utils.cuh b/csrc/quant_utils.cuh new file mode 100644 index 000000000000..d26e754bb40e --- /dev/null +++ b/csrc/quant_utils.cuh @@ -0,0 +1,268 @@ +// Adated from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +#pragma once + +#include +#include +#include +#include +#include 
"attention/attention_dtypes.h" +#include "attention/dtype_float32.cuh" +using namespace vllm; + +// this function is for function matching, delete it after writing customized dispatch functions +inline __device__ int8_t quant(double a, const float scale, const float zp) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, (a - zp) / scale))); + return int8; +} + +inline __device__ int8_t quant(float a, const float scale, const float zp) +{ + int8_t int8; + int8 = round(max(-128.f, min(127.f, (a - zp) / scale))); + return int8; +} + +inline __device__ short quant(float2 a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + + int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale))); + return int16; +} + +inline __device__ int32_t quant(float4 a, const float scale, const float zp) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + + int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale))); + int8[2] = round(max(-128.f, min(127.f, (a.z - zp) / scale))); + int8[3] = round(max(-128.f, min(127.f, (a.w - zp) / scale))); + return int32; +} + +// float16 to int8 +inline __device__ int8_t quant(uint16_t a, const float scale, const float zp) +{ + int8_t int8; + float b = half_to_float(a); + int8 = round(max(-128.f, min(127.f, (b - zp) / scale))); + return int8; +} + +// float16x2 to int8x2 +inline __device__ int16_t quant(uint32_t a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + short int16; + }; + float2 b = half2_to_float2(a); + + int8[0] = round(max(-128.f, min(127.f, (b.x - zp) / scale))); + int8[1] = round(max(-128.f, min(127.f, (b.y - zp) / scale))); + return int16; +} + +// float16x4 to int8x4 +inline __device__ int32_t quant(uint2 a, const float scale, const float zp) +{ + union { + int16_t int16[2]; + int32_t int32; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + return int32; +} + +// float16x8 to int8x8 +inline __device__ int64_t quant(uint4 a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + + int16[0] = quant(a.x, scale, zp); + int16[1] = quant(a.y, scale, zp); + int16[2] = quant(a.z, scale, zp); + int16[3] = quant(a.w, scale, zp); + return int64; +} + +// int8 to float32, then `vec_conversion` to target format +inline __device__ float dequant(int8_t a, const float scale, const float zp) +{ + float b = a * scale + zp; + return b; +} + +// int8x2 to float32x2 +inline __device__ float2 dequant(int16_t a, const float scale, const float zp) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + int16 = a; + + float2 b; + b.x = int8[0] * scale + zp; + b.y = int8[1] * scale + zp; + return b; +} + +// int8x4 to float32x4 +inline __device__ Float4_ dequant(int32_t a, const float scale, const float zp) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int32 = a; + + Float4_ b; + b.x.x = (int8[0] * scale) + zp; + b.x.y = (int8[1] * scale) + zp; + b.y.x = (int8[2] * scale) + zp; + b.y.y = (int8[3] * scale) + zp; + return b; +} + +inline __device__ Float8_ dequant(int64_t a, const float scale, const float zp) +{ + union { + int16_t int16[4]; + int64_t int64; + }; + int64 = a; + + Float8_ b; + b.x = dequant(int16[0], scale, zp); + b.y = dequant(int16[1], scale, zp); + b.z = dequant(int16[2], scale, zp); + b.w = dequant(int16[3], scale, zp); + return b; +} + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) 
+{ + return x; +} + +template<> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template<> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + + return b; +} + +template<> +__inline__ __device__ float4 vec_conversion(const Float4_& a) +{ + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template<> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +template<> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { + return __float22bfloat162_rn(a); +} + +template<> +__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { + bf16_4_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + return b; +} + +template<> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { + bf16_8_t b; + b.x = vec_conversion<__nv_bfloat162, float2>(a.x); + b.y = vec_conversion<__nv_bfloat162, float2>(a.y); + b.z = vec_conversion<__nv_bfloat162, float2>(a.z); + b.w = vec_conversion<__nv_bfloat162, float2>(a.w); + return b; +} + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} + +template +inline __device__ T ldg(const T* val) { + return __ldg(val); +} + +#if ENABLE_BF16 +template<> +inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} + +template<> +inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} +#endif // ENABLE_BF16 diff --git a/examples/offline_inference_quant.py b/examples/offline_inference_quant.py new file mode 100644 index 000000000000..819eb2a4e72b --- /dev/null +++ b/examples/offline_inference_quant.py @@ -0,0 +1,111 @@ +import argparse +import os +from typing import List, Tuple, Dict + +import numpy as np +import pandas as pd +from vllm import LLM, SamplingParams, RequestOutput +from benchmarks.mmlu_template import MMLUTemplate + + +def sample_requests( + # dataset_path: str, + # num_requests: int, + # tokenizer: PreTrainedTokenizerBase, + dev_data_path: str, + test_data_path: str, + subjects: List[str], + # dataset_template: str = "mmlu", + is_analyse: bool = False, +) -> List[Tuple[str, int, int]]: + # Load the dataset. 
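    # One prompt per test question: MMLUTemplate renders the few-shot examples from
    # <subject>_dev.csv ahead of the current question from <subject>_test.csv, and the
    # gold answer is taken from the last CSV column (test_dataset.iloc[idx, -1]).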
+ nums_questions = [] + dataset = [] + labels = [] + template_class = MMLUTemplate + for subject in subjects: + test_dataset = pd.read_csv(os.path.join(test_data_path, subject + "_test.csv"), header=None) + nums_questions.append(len(test_dataset)) + template = template_class(subject, os.path.join(dev_data_path, subject + "_dev.csv"), is_analyse) + for idx in range(len(test_dataset)): + prompt = template.getTemplate(test_dataset, idx) + dataset.append(prompt) + labels.append(test_dataset.iloc[idx, -1]) + return dataset, labels, nums_questions + + +def main(args: argparse.Namespace): + subjects = ["abstract_algebra"] + llm = LLM( + model=args.model, + tokenizer=args.tokenizer, + tensor_parallel_size=args.tensor_parallel_size, + seed=args.seed, + trust_remote_code=args.trust_remote_code, + kv_cache_dtype=args.kv_cache_dtype, + kv_quant_params_path=args.kv_quant_params_path, + quantization=args.quantization + ) + requests, labels, _ = sample_requests( + args.dev_data_path, + args.test_data_path, + subjects, + args.is_analyse, + ) + prompt, label = requests[0], labels[0] + print(f"the correct answer is\n{label}") + sampling_params = SamplingParams( + n=args.n, + temperature=0.0 if args.use_beam_search else 1.0, + top_p=1.0, + use_beam_search=args.use_beam_search, + ignore_eos=True, + max_tokens=args.output_len, + ) + outputs = llm.generate(prompt, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="evaluation for quantization.") + + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument("--dev-data-path", + type=str, + default=None, + help="path to few-shot dataset") + parser.add_argument("--test-data-path", + type=str, + default=None, + help="path to test dataset") + parser.add_argument("--is-analyse", + action="store_true") + parser.add_argument("--output-len", + type=int, + default=200, + help="nums of max token for evaluation outputs") + parser.add_argument("--kv-cache-dtype", + type=str, + default="float16") + parser.add_argument("--kv-quant-params-path", + type=str, + default=None) + parser.add_argument("--quantization", + type=str, + default="smoothquant") + args = parser.parse_args() + main(args) diff --git a/setup.py b/setup.py index 047ee8d0e894..a8513eb582fc 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,24 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ext_modules = [] +# int8gemm +i8cugemm_extension = CUDAExtension( + name='vllm.i8cugemm', + sources=[ + 'csrc/int8gemm/cublas/bindings.cpp', + 'csrc/int8gemm/cublas/cublasAlgoMap.cc', + 'csrc/int8gemm/cublas/cublasINT8MMWrapper.cc', + 'csrc/int8gemm/cublas/cublasMMWrapper.cc', + 'csrc/int8gemm/cublas/cuda_utils.cc', + 'csrc/int8gemm/cublas/transform_layout.cu' + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, +) +ext_modules.append(i8cugemm_extension) + # Cache operations. 
cache_extension = CUDAExtension( name="vllm.cache_ops", @@ -102,6 +120,17 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ) ext_modules.append(cache_extension) +# Fuse kernels. +fused_extension = CUDAExtension( + name="vllm.fused_kernels", + sources=["csrc/fused.cpp", "csrc/fused_kernels.cu"], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, +) +ext_modules.append(fused_extension) + # Attention kernels. attention_extension = CUDAExtension( name="vllm.attention_ops", diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 8aa35d2b2340..cd548cb000bb 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -9,6 +9,9 @@ NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing SEEDS = [0] +SCALE_UP = [0.09, 1.2, 1.9] +SCALE_GATE = [2.17, 1.2, 1.9] +SCALE_OUT = [1.2, 1.9, 0.17] def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor: @@ -29,13 +32,51 @@ def test_silu_and_mul( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda') - out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.silu_and_mul(out, x) ref_out = ref_silu_and_mul(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale_gate", SCALE_GATE) +@pytest.mark.parametrize("scale_up", SCALE_UP) +@pytest.mark.parametrize("scale_out", SCALE_OUT) +@torch.inference_mode() +def test_dequant_silu_and_mul_quant( + num_tokens: int, + d: int, + dtype: torch.dtype, + seed: int, + scale_gate: float, + scale_up: float, + scale_out: float, +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + # x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda') + x = torch.randint( + -1000, 1000, (num_tokens, 2 * d), dtype=torch.int32, device="cuda" + ) + x_ = torch.empty_like(x, dtype=dtype) + x_[:, :d] = x[:, :d] * scale_gate + x_[:, d:] = x[:, d:] * scale_up + out1 = torch.empty(num_tokens, d, dtype=dtype, device="cuda") + activation_ops.silu_and_mul(out1, x_) + out1 = (out1 / scale_out).round().clamp(-128, 127).to(torch.int8) + # ref_out = ref_silu_and_mul(x) + + out2 = torch.empty(num_tokens, d, dtype=torch.int8, device="cuda") + activation_ops.invoke_dequant_silu_and_mul_quant( + out2, x, scale_gate, scale_up, scale_out + ) + assert torch.allclose(out1, out2, atol=2) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -49,8 +90,8 @@ def test_gelu_new( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device='cuda') - out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.gelu_new(out, x) ref_out = get_activation("gelu_new")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -68,8 +109,8 @@ def test_gelu_fast( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device='cuda') - out = 
torch.empty(num_tokens, d, dtype=dtype, device='cuda') + x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.gelu_fast(out, x) ref_out = get_activation("gelu_fast")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 18985669d159..141efdf3c8e8 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -11,13 +11,24 @@ MAX_SEQ_LEN = 8192 NUM_BLOCKS = 128 # Arbitrary values for testing -DTYPES = [torch.half, torch.bfloat16, torch.float] +DTYPES = [ + torch.half, + # torch.bfloat16, + torch.float, + ] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] -BLOCK_SIZES = [8, 16, 32] -USE_ALIBI = [False, True] +BLOCK_SIZES = [ + 8, + 16, + # 32, + ] +USE_ALIBI = [ + False, + True, + ] SEEDS = [0] @@ -284,3 +295,218 @@ def test_multi_query_kv_attention( dtype, ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +# def test_single_query_cached_kv_attention_quantized() -> None: +# torch.random.manual_seed(TEST_SEED) +# torch.cuda.manual_seed(TEST_SEED) +# for dtype in [ +# torch.half, +# torch.bfloat16, +# torch.float, +# ]: +# for block_size in [8, +# 16, +# ]: +# for head_size in [64, +# 80, +# 96, +# 112, +# 128, +# 256, +# ]: +# print(f'Testing single_query_cached_kv_attention with ' +# f'dtype={dtype}, block_size={block_size}, ' +# f'head_size={head_size}') +# run_single_query_cached_kv_attention_quantized( +# num_tokens=37, +# num_heads=3, +# head_size=head_size, +# block_size=block_size, +# num_blocks=1024, +# dtype=dtype, +# ) + + +def ref_single_query_cached_kv_attention_quantized( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + k_zp: float, + v_scale: float, + v_zp: float, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + k = k.to(torch.float32) + k = k * k_scale + k_zp + k = k.to(q.dtype) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + v = v.to(torch.float32) + v = v * v_scale + v_zp + v = v.to(q.dtype) + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. 
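            # bias[h, j] = alibi_slopes[h] * (j - context_len + 1): zero at the newest
            # position and increasingly negative for older tokens, matching the bias
            # applied per query head inside the paged attention kernel.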
+ position_ids = torch.arange(context_len, device="cuda").int() + alibi_bias = (position_ids - context_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_single_query_cached_kv_attention_quantized( + # kv_cache_factory, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + k_scale: float = 1e-2, + k_zp: float = 0.0, + v_scale: float = 1e-2, + v_zp: float = 0.0, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device="cuda") + + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + # Create the KV caches. + + # key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + # num_kv_heads, head_size, dtype, + # seed) + # key_cache, value_cache = key_caches[0], value_caches[0] + + x = 16 // torch.tensor([], dtype=torch.int8).element_size() ## use int8 dtype + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) + key_cache = torch.randint(-10, 10, size=key_cache_shape, dtype=torch.int8, device='cuda') + value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size) + value_cache = torch.randint(-10, 10, size=value_cache_shape, + dtype=torch.int8, ## change to int8 + device='cuda') + # Call the paged attention kernel. + output = torch.empty_like(query) + attention_ops.single_query_cached_kv_quantized_attention( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, # ALiBi slopes. 
+ k_scale, + k_zp, + v_scale, + v_zp, + ) + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention_quantized( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + k_scale, + k_zp, + v_scale, + v_zp, + ) + # NOTE(woosuk): Due to the difference in the data types the two + # implementations use for attention softmax logits and accumulation, + # there is a small difference in the final outputs. + # We should use a relaxed tolerance for the test. + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index cca037df235d..baa90dc675b9 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,12 +5,20 @@ from vllm import cache_ops -DTYPES = [torch.half, torch.bfloat16, torch.float] +DTYPES = [ + # torch.half, + # torch.bfloat16, + torch.float +] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing NUM_LAYERS = [5] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] -BLOCK_SIZES = [8, 16, 32] +BLOCK_SIZES = [ + 8, + 16, + 32, +] NUM_BLOCKS = [1024] # Arbitrary values for testing NUM_MAPPINGS = [32, 256] # Arbitrary values for testing SEEDS = [0] @@ -144,3 +152,74 @@ def test_reshape_and_cache( assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(value_cache, cloned_value_cache) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +# @pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_reshape_and_cache_quantized( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + k_scale: float = 3.0, + k_zp: float = 0.0, + v_scale: float = 3.0, + v_zp: float = 0.0, +) -> None: + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + _, key, value = qkv.unbind(dim=1) + + x = 16 // torch.tensor([], dtype=torch.int8).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache = torch.randint(-10, 10, size=key_cache_shape, dtype=torch.int8, device='cuda') ## change to int8 + cloned_key_cache = key_cache.clone() + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_cache = torch.randint(-10, 10, size=value_cache_shape, + dtype=torch.int8, ## change to int8 + device='cuda') + cloned_value_cache = value_cache.clone() + + cache_ops.reshape_and_cache_quantized(key, value, key_cache, value_cache, + slot_mapping, k_scale, k_zp, v_scale, v_zp) + lower_bound, upper_bound = torch.tensor([-128.0], dtype=dtype, device='cuda'), torch.tensor([127.0], dtype=dtype, device='cuda') + ## quantize and store here + ## quantize and store here + quantized_key = key.reshape(num_tokens, num_heads, head_size // x, x) + quantized_key = quantized_key.to(torch.float32) + quantized_key = torch.maximum(lower_bound, torch.minimum(upper_bound, (quantized_key - k_zp) / k_scale)) + quantized_key = torch.round(quantized_key) + quantized_key = 
quantized_key.to(torch.int8) ## change to int8 + + quantized_value = value.to(torch.float32) + quantized_value = torch.maximum(lower_bound, torch.minimum(upper_bound, (quantized_value - v_zp) / v_scale)) + quantized_value = torch.round(quantized_value) + quantized_value = quantized_value.to(torch.int8) + + for i in range(num_tokens): + block_idx = torch.div(slot_mapping[i], + block_size, + rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + cloned_key_cache[block_idx, :, :, block_offset, :] = quantized_key[i] + cloned_value_cache[block_idx, :, :, block_offset] = quantized_value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index a63ef5cc76ff..1083d88e368b 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -2,16 +2,16 @@ import torch import torch.nn as nn -from vllm import layernorm_ops +from vllm import layernorm_ops, fused_kernels DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] class RefRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): super().__init__() weight = torch.empty(hidden_size) @@ -23,8 +23,7 @@ def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) @@ -56,3 +55,155 @@ def test_rms_norm( ) ref_out = ref(x) assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_rms_norm_quant( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = float(hidden_size**-0.5) + x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x.uniform_(-scale, scale) + ref = RefRMSNorm(hidden_size).to(dtype).cuda() + + out1 = torch.empty_like(x) + layernorm_ops.rms_norm( + out1, + x, + ref.weight.data, + ref.variance_epsilon, + ) + out1 = out1.clamp(-128, 127).round().to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + layernorm_ops.invoke_rms_norm_quant(out2, x, ref.weight.data, ref.variance_epsilon) + assert torch.allclose(out1, out2, atol=1.0) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_dequant_add_residual_rms_norm_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + s = float(hidden_size**-0.5) + residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + # x = torch.randint(torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max, (num_tokens, hidden_size), 
dtype=torch.int32, device="cuda") + x = torch.randint( + -1000, 1000, (num_tokens, hidden_size), dtype=torch.int32, device="cuda" + ) + residual.uniform_(-s, s) + ref = RefRMSNorm(hidden_size).to(dtype).cuda() + x_ = (x * scale + residual).to(dtype) + + out1 = torch.empty_like(x_) + layernorm_ops.rms_norm( + out1, + x_, + ref.weight.data, + ref.variance_epsilon, + ) + out1 = out1.round().clamp(-128, 127).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + layernorm_ops.invoke_dequant_add_residual_rms_norm_quant( + out2, x, residual, ref.weight.data, ref.variance_epsilon, scale + ) + + assert torch.allclose(out1, out2, atol=1.0) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_dequant_add_residual( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + s = float(hidden_size**-0.5) + residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.randint( + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + (num_tokens, hidden_size), + dtype=torch.int32, + device="cuda", + ) + residual.uniform_(-s, s) + out1 = (x * scale + residual).to(dtype) + + out2 = torch.empty_like(x, dtype=dtype) + fused_kernels.invoke_dequant_add_residual(out2, x, residual, scale) + + assert torch.allclose(out1, out2, atol=0.001), f"diff: {torch.max(out1 - out2)}" + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_dequant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + s = float(hidden_size**-0.5) + # residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.randint( + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + (num_tokens, hidden_size), + dtype=torch.int32, + device="cuda", + ) + # residual.uniform_(-s, s) + out1 = (x * scale).to(dtype) + + out2 = torch.empty_like(x, dtype=dtype) + fused_kernels.invoke_dequant(out2, x, scale) + assert torch.allclose(out1, out2, atol=0.001) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + s = float(hidden_size**-0.5) + # residual = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + # residual.uniform_(-s, s) + out1 = (x / scale).round().clamp(-128, 127).to(torch.int8) + + out2 = torch.empty_like(x, dtype=torch.int8) + fused_kernels.invoke_quant(out2, x, scale) + assert torch.allclose(out1, out2, atol=1) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 0d255900d4c1..8c99f3bc77d4 100644 --- 
a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -10,15 +10,17 @@ IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 96, 112, 128, 256] -ROTARY_DIMS = [None, 32] # None means rotary dim == head size +ROTARY_DIMS = [None] # None means rotary dim == head size NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing SEEDS = [0] +QUERY_SCALE = [0.0002, 0.0008] +KEY_SCALE = [0.0002, 0.0008] def rotate_neox(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -58,7 +60,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings # Create cos and sin embeddings. - inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) t = torch.arange(max_position_embeddings).float() freqs = torch.einsum("i,j->ij", t, inv_freq.float()) if is_neox_style: @@ -76,18 +78,19 @@ def forward( query: torch.Tensor, # [num_tokens, num_heads, head_size] key: torch.Tensor, # [num_tokens, num_heads, head_size] ) -> Tuple[torch.Tensor, torch.Tensor]: - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] query_rot = query_rot.transpose(0, 1) key_rot = key_rot.transpose(0, 1) cos = F.embedding(positions, self.cos_cached) sin = F.embedding(positions, self.sin_cached) - query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin, - self.is_neox_style) + query_rot, key_rot = apply_rope( + query_rot, key_rot, cos, sin, self.is_neox_style + ) query_rot = query_rot.transpose(0, 1).contiguous() key_rot = key_rot.transpose(0, 1).contiguous() @@ -122,25 +125,20 @@ def test_rotary_embedding( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - positions = torch.randint(0, max_position, (num_tokens, ), device="cuda") - query = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device="cuda") - key = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device="cuda") + positions = torch.randint(0, max_position, (num_tokens,), device="cuda") + query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") + key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") # Create the rotary embedding. - inv_freq = 1.0 / (base**( - torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim) + ) t = torch.arange(max_position).float() freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() sin = freqs.sin() cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") # Run the kernel. The kernel is in-place, so we need to clone the inputs. out_query = query.clone() @@ -168,7 +166,95 @@ def test_rotary_embedding( ) ref_query = ref_query.view(num_tokens, num_heads * head_size) ref_key = ref_key.view(num_tokens, num_heads * head_size) - # Compare the results. 
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("query_scale", QUERY_SCALE) +@pytest.mark.parametrize("key_scale", KEY_SCALE) +@torch.inference_mode() +def test_dequant_rotary_embedding( + is_neox_style: bool, + num_tokens: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + query_scale: float, + key_scale: float, + max_position: int = 8192, + base: int = 10000, +) -> None: + if rotary_dim is None: + rotary_dim = head_size + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + positions = torch.randint(0, max_position, (num_tokens,), device="cuda") + query = torch.randint( + -1000, + 1000, + (num_tokens, num_heads * head_size), + dtype=torch.int32, + device="cuda", + ) + key = torch.randint( + -1000, + 1000, + (num_tokens, num_heads * head_size), + dtype=torch.int32, + device="cuda", + ) + query_ = (query * query_scale).to(dtype) + key_ = (key * key_scale).to(dtype) + + # Create the rotary embedding. + inv_freq = 1.0 / ( + base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim) + ) + t = torch.arange(max_position).float() + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cos_sin_cache = torch.cat((cos, sin), dim=-1) + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") + + ref_rotary_embedding = RefRotaryEmbedding( + dim=rotary_dim, + is_neox_style=is_neox_style, + max_position_embeddings=max_position, + base=base, + ).to(dtype=dtype, device="cuda") + ref_query, ref_key = ref_rotary_embedding( + positions, + query_.view(num_tokens, num_heads, head_size), + key_.view(num_tokens, num_heads, head_size), + ) + ref_query = ref_query.view(num_tokens, num_heads * head_size) + ref_key = ref_key.view(num_tokens, num_heads * head_size) + out2_query = torch.empty_like(query_) + out2_key = torch.empty_like(key_) + + pos_encoding_ops.invoke_dequant_rotary_embedding( + positions, + query, + out2_query, + key, + out2_key, + head_size, + cos_sin_cache, + query_scale, + key_scale, + is_neox_style, + ) + assert torch.allclose(ref_key, out2_key, atol=1e-4) + assert torch.allclose(ref_query, out2_query, atol=1e-4) diff --git a/vllm/config.py b/vllm/config.py index dd92fbccd899..3e2d50425c2f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -60,6 +60,8 @@ def __init__( revision: Optional[str], max_model_len: Optional[int] = None, quantization: Optional[str] = None, + kv_cache_dtype: str = None, ## for kv cache quantization, only for int8 right now + kv_quant_params_path: str = None, ## path for kv scales and zero points ) -> None: self.model = model self.tokenizer = tokenizer @@ -74,6 +76,10 @@ def __init__( self.hf_config = get_config(model, trust_remote_code, revision) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self._verify_load_format() + ## for kv cache quantization + self.kv_cache_dtype = _STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] if kv_cache_dtype else self.dtype + self.quant_kv_cache = not self.kv_cache_dtype == self.dtype + self.kv_quant_params_path = kv_quant_params_path self._verify_tokenizer_mode() 
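        # KV-cache quantization is enabled implicitly: quant_kv_cache (set above) is
        # True whenever kv_cache_dtype (e.g. "int8") differs from the model dtype, and
        # kv_quant_params_path then points at the calibrated scales and zero points.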
self._verify_quantization() self.max_model_len = None @@ -106,7 +112,7 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq"] + supported_quantization = ["awq", "smoothquant"] if self.quantization is None: return quantization = self.quantization.lower() @@ -296,6 +302,7 @@ def __init__(self, max_num_batched_tokens: int, max_num_seqs: int, _STR_DTYPE_TO_TORCH_DTYPE = { + "int8": torch.int8, "half": torch.float16, "float16": torch.float16, "float": torch.float32, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a03155a4929d..ad3b3883107b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -30,6 +30,8 @@ class EngineArgs: disable_log_stats: bool = False revision: Optional[str] = None quantization: Optional[str] = None + kv_cache_dtype: str = "float16" + kv_quant_params_path: str = None def __post_init__(self): if self.tokenizer is None: @@ -103,6 +105,18 @@ def add_cli_args( default=None, help='model context length. If unspecified, ' 'will be automatically derived from the model.') + # kv cache quantization + parser.add_argument( + '--kv-cache-dtype', + type=str, + default=EngineArgs.kv_cache_dtype, + help='data type for kv cache') + parser.add_argument( + '--kv-quant-params-path', + type=str, + default=EngineArgs.kv_quant_params_path, + help="path to kv scales and zero points" + ) # Parallel arguments parser.add_argument('--worker-use-ray', action='store_true', @@ -154,7 +168,7 @@ def add_cli_args( parser.add_argument('--quantization', '-q', type=str, - choices=['awq', None], + choices=['awq', "smoothquant", None], default=None, help='Method used to quantize the weights') return parser @@ -174,7 +188,8 @@ def create_engine_configs( self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, - self.max_model_len, self.quantization) + self.max_model_len, self.quantization, + self.kv_cache_dtype, self.kv_quant_params_path) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 743454301838..4214f835a2dc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -81,7 +81,9 @@ def __init__( f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"quantization={model_config.quantization}, " - f"seed={model_config.seed})") + f"seed={model_config.seed})" + f"kv_cache_type={model_config.kv_cache_dtype}" + f"use kv cache quantization: {model_config.quant_kv_cache}") # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config diff --git a/vllm/kv_quant/calib_dataloader.py b/vllm/kv_quant/calib_dataloader.py new file mode 100644 index 000000000000..bd0a86823577 --- /dev/null +++ b/vllm/kv_quant/calib_dataloader.py @@ -0,0 +1,311 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_wikitext2(tokenizer, nsamples, seed, seqlen, path=None): + """Load Wikitext-2 train and test datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. 
+ test_enc: Full tokenized Wikitext-2 test set. + """ + from datasets import load_dataset + traindata = load_dataset(path if path else 'wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset(path if path else 'wikitext', 'wikitext-2-raw-v1', split='test') + + trainenc = tokenizer('\n\n'.join(traindata['text']), return_tensors='pt') + testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(tokenizer, nsamples, seed, seqlen): + """Load PTB train and validation datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. + """ + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + valdata = load_dataset('ptb_text_only', + 'penn_treebank', + split='validation') + + trainenc = tokenizer('\n\n'.join(traindata['sentence']), + return_tensors='pt') + testenc = tokenizer('\n\n'.join(valdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(tokenizer, nsamples, seed, seqlen, path=None): + """Load C4 train and validation datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. 
+ """ + from datasets import load_dataset + traindata = load_dataset( + path if path else 'allenai/c4', + 'allenai--c4', + data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, + split='train', + use_auth_token=False) + valdata = load_dataset( + path if path else 'allenai/c4', + 'allenai--c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation', + use_auth_token=False) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]['text'], return_tensors='pt') + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + + class TokenizerWrapper: + + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_ptb_new(tokenizer, nsamples, seed, seqlen): + """Load PTB New train and validation datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. + """ + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4_new(tokenizer, nsamples, seed, seqlen): + """Load C4 New train and validation datasets and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. 
+ """ + from datasets import load_dataset + traindata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, + split='train') + valdata = load_dataset( + 'allenai/c4', + 'allenai--c4', + data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, + split='validation') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_pileval(tokenizer, nsamples, seed, seqlen=512): + """Load pileval train dataset and tokenize. + + Args: + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_enc: Full tokenized PTB validation set. + """ + from datasets import load_dataset + from datasets.builder import DatasetGenerationError + try: + dataset = load_dataset( + 'json', + data_files='https://the-eye.eu/public/AI/pile/val.jsonl.zst', + split='train') + except DatasetGenerationError: + raise InterruptedError('There have been some issues when generating ' + 'the dataset, you could try to download it ' + 'locally first, and replace the `data_files`' + 'with local addresses or use other datasets ' + '(c4, wiki, ptb).') + dataset = dataset.shuffle(seed=seed) + samples = [] + n_run = 0 + for data in dataset: + line = data['text'] + line = line.strip() + line_encoded = tokenizer.encode(line) + if len(line_encoded) > 512: + continue + sample = torch.tensor([line_encoded]) + if sample.numel() == 0: + continue + samples.append(sample) + n_run += 1 + if n_run == nsamples: + break + # now concatenate all samples and split according to block size + cat_samples = torch.cat(samples, dim=1) + n_split = cat_samples.shape[1] // seqlen + print(f' * Split into {n_split} blocks') + return [ + cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split) + ], None + + +def get_calib_loaders(name, tokenizer, nsamples=128, seed=0, seqlen=2048, path=None): + """Get calibration data loaders for a dataset. + + Args: + name: Dataset name ('wikitext2', 'ptb', 'c4', etc). + tokenizer: Tokenizer to encode text. + nsamples: Number of samples to take from train set. + seed: Random seed for sampling. + seqlen: Maximum sequence length. + + Returns: + train_loader: List of sampled and tokenized training examples. + test_data: Full tokenized validation set. 
+ """ + if 'wikitext2' in name: + return get_wikitext2(tokenizer, nsamples, seed, seqlen, path) + if 'ptb' in name: + if 'new' in name: + return get_ptb_new(tokenizer, nsamples, seed, seqlen, path) + return get_ptb(tokenizer, nsamples, seed, seqlen, path) + if 'c4' in name: + if 'new' in name: + return get_c4_new(tokenizer, nsamples, seed, seqlen, path) + return get_c4(tokenizer, nsamples, seed, seqlen, path) + + if 'pileval' in name: + return get_pileval(tokenizer, nsamples, seed, seqlen, path) diff --git a/vllm/kv_quant/calibrate.py b/vllm/kv_quant/calibrate.py new file mode 100644 index 000000000000..7097e29e9d98 --- /dev/null +++ b/vllm/kv_quant/calibrate.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Adapted from +# https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/lite/apis/calibrate.py + +# Copyright (c) OpenMMLab. All rights reserved. + +from pathlib import Path + +import fire +import torch +from accelerate import (infer_auto_device_map, init_empty_weights, + load_checkpoint_in_model) +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from vllm.kv_quant.calibration import CalibrationContext +from vllm.kv_quant.utils import collect_target_modules +from vllm.kv_quant.calib_dataloader import get_calib_loaders + +LAYER_TYPE_MAP = { + 'InternLMForCausalLM': 'InternLMDecoderLayer', + 'QWenLMHeadModel': 'QWenBlock', + 'BaiChuanForCausalLM': 'DecoderLayer', + 'LlamaForCausalLM': 'LlamaDecoderLayer', +} +NORM_TYPE_MAP = { + 'InternLMForCausalLM': 'InternLMRMSNorm', + 'QWenLMHeadModel': 'RMSNorm', + 'BaiChuanForCausalLM': 'RMSNorm', + 'LlamaForCausalLM': 'LlamaRMSNorm', +} + + +def calibrate(model: str, + calib_dataset: str = 'c4', + calib_samples: int = 128, + calib_seqlen: int = 2048, + work_dir: str = './work_dir', + device: str = 'cuda', + dataset_path: str = None) -> None: + """The main function for loading the model and performing calibration on a + given dataset. + + Args: + model (str): The model to be loaded. + calib_dataset (str, optional): The calibration dataset name. + Defaults to 'c4'. + calib_samples (int, optional): The number of samples for calibration. + Defaults to 128. + calib_seqlen (int, optional): The sequence length for calibration. + Defaults to 2048. + work_dir (str): The working directory for outputs. + Defaults to './work_dir'. + device (str, optional): The device to be used for calculation. + Defaults to 'cuda'. + """ + + assert calib_dataset in ['c4', 'ptb', 'wikitext2', 'pileval'], \ + 'Support only `c4`, `ptb`, `wikitext2` or `pileval`.' 
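Given the loaders above, a calibration pass reduces to a single call (the checkpoint and output directory are placeholders; the same arguments can be supplied on the command line through this module's fire entry point):

from vllm.kv_quant.calibrate import calibrate

# Collects input/output/key/value statistics over 128 C4 samples and writes
# inputs_stats.pth, outputs_stats.pth, key_stats.pth and value_stats.pth
# into work_dir for the export step.
calibrate(
    model="meta-llama/Llama-2-7b-hf",   # placeholder HF checkpoint
    calib_dataset="c4",
    calib_samples=128,
    calib_seqlen=2048,
    work_dir="./kv_stats",
    device="cuda",
)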
+ + # Load tokenizer and configuration + tokenizer = AutoTokenizer.from_pretrained(model, + use_fast=False, + trust_remote_code=True) + hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True) + checkpoint = hf_config._name_or_path + + with init_empty_weights(): + # Load model + model = AutoModelForCausalLM.from_pretrained(model, + torch_dtype=torch.float16, + trust_remote_code=True) + model.config.use_cache = False + + layer_type = LAYER_TYPE_MAP[type(model).__name__] + norm_type = NORM_TYPE_MAP[type(model).__name__] + + decoder_layers = collect_target_modules(model, layer_type) + + # Infer device map + device_map = infer_auto_device_map(model, + no_split_module_classes=[layer_type]) + for name in device_map.keys(): + if name in decoder_layers or 'lm_head' in name: + device_map[name] = 'cpu' + else: + device_map[name] = 0 + load_checkpoint_in_model(model, checkpoint, device_map) + + print('Loading calibrate dataset ...') + calib_loader, _ = get_calib_loaders(calib_dataset, + tokenizer, + nsamples=calib_samples, + seqlen=calib_seqlen, + path=dataset_path) + + # Initialize calibration context + calib_ctx = CalibrationContext(model, + tokenizer, + layer_type=layer_type, + norm_type=norm_type, + device=device) + + with calib_ctx: + all_data = torch.cat([ + data if isinstance(data, torch.Tensor) else data[0] + for data in calib_loader + ]).to(device) + calib_ctx.calibrate(all_data) + + # Create work directory if not exists + work_dir = Path(work_dir) + work_dir.mkdir(parents=True, exist_ok=True) + calib_ctx.export(work_dir) + + +if __name__ == '__main__': + fire.Fire(calibrate) diff --git a/vllm/kv_quant/calibration.py b/vllm/kv_quant/calibration.py new file mode 100644 index 000000000000..d38e9e486456 --- /dev/null +++ b/vllm/kv_quant/calibration.py @@ -0,0 +1,307 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial +from typing import Union + +import torch +from torch import nn +from transformers import PreTrainedTokenizer +from vllm.kv_quant.utils import (bimap_name_mod, collect_target_modules, + concat_decoder_layer_outputs, + split_decoder_layer_inputs) +from vllm.kv_quant.observer import ActivationObserver, KVCacheObserver + + +class CalibrationContext(): + """Calibration context manager for model quantization. + + Parameters: + - model: The target model to be calibrated and quantized + - tokenizer: The tokenizer used in the model training + - layer_type: Layer type to be targeted for calibration + - norm_type: Normalization type used for calibration + - device: Device on which model is to be calibrated ('cpu' or 'cuda') + """ + + inp_obs_group = 'inputs' + out_obs_group = 'outputs' + key_obs_group = 'keys' + value_obs_group = 'values' + + def __init__(self, + model: nn.Module, + tokenizer: PreTrainedTokenizer, + layer_type: Union[str, type], + norm_type: Union[str, type], + device: str = 'cuda') -> None: + """Initiate calibration context. + + Args: + model (nn.Module): Model to be calibrated. + tokenizer (PreTrainedTokenizer): Tokenizer of the given model. + layer_type (Union[str, type]): Type of the layers to be observed. + norm_type (Union[str, type]): Norm type used in the model. + device (str, optional): Device where the model should run. + Defaults to 'cuda'. 
+ """ + + self.layer_type = layer_type + self.norm_type = norm_type + + num_kv_heads, num_attn_heads = self._guess_num_heads(model) + self.num_kv_heads = num_kv_heads + self.head_dim = model.config.hidden_size // num_attn_heads + self.model = model + del self.model.lm_head + + self.tokenizer = tokenizer + + # Collect modules to observe + self.name2layer = collect_target_modules(self.model, layer_type) + self.name2fc = {} + for l_name, layer in self.name2layer.items(): + name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name) + self.name2fc.update(name2fc) + self.name2norm = collect_target_modules(self.model, norm_type) + + maps = bimap_name_mod([self.name2layer, self.name2fc, self.name2norm]) + self.name2mod, self.mod2name = maps + + # Initialize observers + self._init_input_observers(self.name2fc) + self._init_output_observers(self.name2norm) + self._init_output_observers(self.name2fc) + self._init_kv_observers(self.name2layer) + + self.device = device + + def _guess_num_heads(self, model): + + if hasattr(model.config, 'num_key_value_heads'): + num_kv_heads = model.config.num_key_value_heads + else: + num_kv_heads = model.config.num_attention_heads + + num_attn_heads = model.config.num_attention_heads + + return num_kv_heads, num_attn_heads + + def _init_input_observers(self, name2mod): + """Initialize input observers for given modules.""" + for name, mod in name2mod.items(): + obs = ActivationObserver(mod.weight.size(-1)) + obs.global_available(name, group=self.inp_obs_group) + + def _init_output_observers(self, name2mod): + """Initialize output observers for given modules.""" + for name, mod in name2mod.items(): + obs = ActivationObserver(mod.weight.size(0)) + obs.global_available(name, group=self.out_obs_group) + + def _init_kv_observers(self, name2mod): + """Initialize KV observers for given modules.""" + for name in name2mod.keys(): + k_obs = KVCacheObserver(self.num_kv_heads, self.head_dim) + v_obs = KVCacheObserver(self.num_kv_heads, self.head_dim) + k_obs.global_available(name, group=self.key_obs_group) + v_obs.global_available(name, group=self.value_obs_group) + + def _insert_input_observers(self): + """Insert input observers into the target modules. + + This function registers a forward pre-hook on each target module to + observe the inputs. + """ + + def _input_hook(mod: nn.Module, inp: torch.Tensor): + m_name = self.mod2name[mod] + obs = ActivationObserver.find(m_name, group=self.inp_obs_group) + obs.observe(inp[0]) + + group = ActivationObserver.find_group(self.inp_obs_group) + for name in group.keys(): + mod = self.name2mod[name] + hook_fn = mod.register_forward_pre_hook(_input_hook) + self._hooks.append(hook_fn) + + def _insert_output_observers(self): + """Insert output observers into the target modules. + + This function registers a forward hook on each target module to observe + the outputs. 
+ """ + + def _output_hook(mod: nn.Module, inp: torch.Tensor, out: torch.Tensor): + m_name = self.mod2name[mod] + obs = ActivationObserver.find(m_name, group=self.out_obs_group) + obs.observe(out) + + group = ActivationObserver.find_group(self.out_obs_group) + for name in group.keys(): + mod = self.name2mod[name] + hook_fn = mod.register_forward_hook(_output_hook) + self._hooks.append(hook_fn) + + def _wrap_decoder_layers(self): + """Method to wrap the decoder layers' forward functions for observing + their key/value cache during batched forward passes.""" + + def _forward(mod, *args, **kwargs): + + mod.to(self.device) + batch_args, batch_kwargs = split_decoder_layer_inputs( + *args, **kwargs) + batch_outputs = [] + samples = len(batch_args) + + m_name = self.mod2name[mod] + k_obs = KVCacheObserver.find(m_name, group=self.key_obs_group) + v_obs = KVCacheObserver.find(m_name, group=self.value_obs_group) + + for i in range(len(batch_args)): + + if k_obs and v_obs: + batch_kwargs[i]['use_cache'] = True + out = self._ori_forwards[mod](*batch_args[i], + **batch_kwargs[i]) + out = list(out) + key, value = out.pop(-1) + k_obs.observe(key) + v_obs.observe(value) + + del key, value + torch.cuda.empty_cache() + batch_outputs.append(tuple(out)) + else: + batch_outputs.append(self._ori_forwards[mod]( + *batch_args[i], **batch_kwargs[i])) + + outputs = concat_decoder_layer_outputs(batch_outputs) + + del batch_outputs, batch_args, batch_kwargs, args + mod.to('cpu') + torch.cuda.empty_cache() + max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + print(f'{m_name}, samples: {samples}, ' + f'max gpu memory: {max_memory:.2f} GB') + return outputs + + for layer in self.name2layer.values(): + self._ori_forwards[layer] = layer.forward + layer.forward = partial(_forward, layer) + + def collect_inputs_stats(self): + """Collect statistics (min, max, absmax values) of the observed inputs. + + Returns a dictionary with these collected stats. + """ + inputs_stats = { + 'max': {}, + 'min': {}, + 'mean': {}, + 'absmax': {}, + 'absmean': {} + } + obs_group = ActivationObserver.find_group(self.inp_obs_group) + for name, obs in obs_group.items(): + inputs_stats['max'][name] = obs.max_val + inputs_stats['min'][name] = obs.min_val + inputs_stats['mean'][name] = obs.mean_val + inputs_stats['absmax'][name] = obs.absmax_val + inputs_stats['absmean'][name] = obs.absmean_val + return inputs_stats + + def collect_outputs_stats(self): + """Collect statistics (min, max, absmax values) of the observed + outputs. + + Returns a dictionary with these collected stats. + """ + outputs_stats = { + 'max': {}, + 'min': {}, + 'mean': {}, + 'absmax': {}, + 'absmean': {} + } + obs_group = ActivationObserver.find_group(self.out_obs_group) + for name, obs in obs_group.items(): + outputs_stats['max'][name] = obs.max_val + outputs_stats['min'][name] = obs.min_val + outputs_stats['mean'][name] = obs.mean_val + outputs_stats['absmax'][name] = obs.absmax_val + outputs_stats['absmean'][name] = obs.absmean_val + return outputs_stats + + def collect_kv_stats(self): + """Collect statistics (min, max, absmax values) of the observed keys + and values. + + Returns a tuple of two dictionaries with these collected stats. 
+ """ + key_stats = {'max': {}, 'min': {}, 'absmax': {}} + obs_group = KVCacheObserver.find_group(self.key_obs_group) + for name, obs in obs_group.items(): + key_stats['max'][name] = obs.max_val + key_stats['min'][name] = obs.min_val + key_stats['absmax'][name] = obs.absmax_val + + value_stats = {'max': {}, 'min': {}, 'absmax': {}} + obs_group = KVCacheObserver.find_group(self.value_obs_group) + for name, obs in obs_group.items(): + value_stats['max'][name] = obs.max_val + value_stats['min'][name] = obs.min_val + value_stats['absmax'][name] = obs.absmax_val + return key_stats, value_stats + + def export(self, out_dir): + """Export the calibration statistics (inputs, outputs, keys and values) + to specified directory. + + Args: + out_dir (Union[str, Path]): The directory path where the stats + will be saved. + """ + + inp_stats = self.collect_inputs_stats() + torch.save(inp_stats, out_dir / 'inputs_stats.pth') + + out_stats = self.collect_outputs_stats() + torch.save(out_stats, out_dir / 'outputs_stats.pth') + + key_stats, value_stats = self.collect_kv_stats() + torch.save(key_stats, out_dir / 'key_stats.pth') + torch.save(value_stats, out_dir / 'value_stats.pth') + + def calibrate(self, data): + """Forward pass through the model in inference mode with given data.""" + + if type(self.model).__name__ == 'QWenLMHeadModel': + model = self.model.transformer + else: + model = self.model.model + with torch.inference_mode(): + _ = model(data.to(self.device)) + + def __enter__(self): + """Prepares the Calibration object for a 'with' statement by + registering hooks and wrapping layer forward methods.""" + + self._hooks = list() + + self._ori_forwards = {} + for layer in self.name2layer.values(): + self._ori_forwards[layer] = layer.forward + + self._insert_input_observers() + self._insert_output_observers() + self._wrap_decoder_layers() + + def __exit__(self, exc_type, exc_value, traceback): + """Clean up after a 'with' statement by removing registered hooks, + restoring original forward methods, and if no exception occurred, + collecting all gathered statistics and saving them.""" + for h in self._hooks: + h.remove() + + for layer in self.name2layer.values(): + layer.forward = self._ori_forwards[layer] diff --git a/vllm/kv_quant/export_kv_params.py b/vllm/kv_quant/export_kv_params.py new file mode 100644 index 000000000000..e0cf47d9b751 --- /dev/null +++ b/vllm/kv_quant/export_kv_params.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
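For orientation, the statistics files consumed by the exporter below are plain dictionaries of per-layer tensors written by CalibrationContext.export; the layer key and shape here are illustrative for a Llama-7B-sized model:

import torch

key_stats = torch.load("work_dir/key_stats.pth")
# {'max': {...}, 'min': {...}, 'absmax': {...}}, each keyed by decoder-layer
# name and holding one (num_kv_heads, head_dim) float16 tensor per layer.
absmax = key_stats["absmax"]["model.layers.0"]   # illustrative key; shape (32, 128)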
+from pathlib import Path +from typing import Union + +import numpy as np +import torch +import fire + + +def _export_sym(key_stats: dict, + value_stats: dict, + bits: int, + out_dir: Union[str, Path], + tp: int = 1) -> None: + """Export symmetric quantization parameters to specified directory.""" + keys_absmax = key_stats['absmax'] + values_absmax = value_stats['absmax'] + for layer_idx, name in enumerate(keys_absmax.keys()): + k_absmax = keys_absmax[name] + v_absmax = values_absmax[name] + + heads, dims = k_absmax.shape + assert heads % tp == 0 + + mp_k_absmax = torch.chunk(k_absmax, tp) + mp_v_absmax = torch.chunk(v_absmax, tp) + for i in range(tp): + # quant: q = f / scale + # dequant: f = q * scale + k_s = mp_k_absmax[i].max() / (2**(bits - 1) - 1) + v_s = mp_v_absmax[i].max() / (2**(bits - 1) - 1) + + kv_qparams = np.array([k_s, v_s], dtype=np.float32) + out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' # noqa: E501 + kv_qparams.tofile(out_path) + print(f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}') + + +def _export_asym(key_stats: dict, + value_stats: dict, + bits: int, + out_dir: Union[str, Path], + tp: int = 1) -> None: + """Export asymmetric quantization parameters to specified directory.""" + keys_min = key_stats['min'] + values_min = value_stats['min'] + + keys_max = key_stats['max'] + values_max = value_stats['max'] + for layer_idx, name in enumerate(keys_min.keys()): + k_max = keys_max[name] + v_max = values_max[name] + + k_min = keys_min[name] + v_min = values_min[name] + + heads, dims = k_min.shape + assert heads % tp == 0 + + tp_k_min = torch.chunk(k_min, tp) + tp_v_min = torch.chunk(v_min, tp) + + tp_k_max = torch.chunk(k_max, tp) + tp_v_max = torch.chunk(v_max, tp) + for i in range(tp): + # zp = (min+max) / 2 + # scale = (max-min) / 255 + # quant: q = (f-zp) / scale + # dequant: f = q * scale + zp + k_min = tp_k_min[i].min() + v_min = tp_v_min[i].min() + + k_max = tp_k_max[i].max() + v_max = tp_v_max[i].max() + + k_scale = (k_max - k_min) / (2**bits - 1) + v_scale = (v_max - v_min) / (2**bits - 1) + + k_zp = (k_max + k_min) / 2 + v_zp = (v_max + v_min) / 2 + + kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp], + dtype=np.float32) + out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' + kv_qparams.tofile(out_path) + print(f'Layer {layer_idx} MP {i} qparam: ' + f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}') + + +def main(work_dir: str, + kv_params_dir: str, + kv_bits: int = 8, + kv_sym: bool = False, + num_tp: int = 1) -> None: + """Main function to export key and value stats. + + Args: + work_dir (Union[str, Path]): Directory path where the stats are saved. + turbomind_dir (Union[str, Path]): Directory path where to + save the results. + kv_bits (int, optional): Number of bits for quantization. + Defaults to 8. + kv_sym (bool, optional): Whether to use symmetric quantizaiton. + Defaults to False. + num_tp (int, optional): Number of tensor parallelism. Defaults to 1. + """ + + work_dir = Path(work_dir) + + tm_dir = Path(kv_params_dir) + assert tm_dir.exists(), 'The specified TurboMind directory does not exist.' 
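As a concrete instance of the asymmetric formulas above: with an observed key range of [-3.0, 5.0] and 8 bits, the exported scale and zero point, and the round trip they imply, work out as follows:

k_min, k_max, bits = -3.0, 5.0, 8
k_scale = (k_max - k_min) / (2 ** bits - 1)    # 8 / 255, about 0.0314
k_zp = (k_max + k_min) / 2                     # 1.0

f = 2.5                                        # an observed key value
q = round((f - k_zp) / k_scale)                # quant: 48
assert abs(q * k_scale + k_zp - f) < k_scale   # dequant error under one step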
+ + key_stats = torch.load(work_dir / 'key_stats.pth') + value_stats = torch.load(work_dir / 'value_stats.pth') + + if kv_sym: + _export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp) + else: + _export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/vllm/kv_quant/observer.py b/vllm/kv_quant/observer.py new file mode 100644 index 000000000000..f36a63c0e0df --- /dev/null +++ b/vllm/kv_quant/observer.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Union +import torch +from torch import nn + + +class GlobalAvailMixin: + """Mixin class to make instances globally available.""" + + _instances: Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']] = { + 'default': {} + } + + def global_available(self, + key: Union[str, nn.Module] = 'default', + group: str = 'default') -> None: + """Make the instance globally available. + + Args: + key (Union[str, nn.Module], optional): Key to save the instance. + Defaults to 'default'. + group (str, optional): Group to save the instance. + Defaults to 'default'. + """ + self._save_instance(self, key, group) + + @classmethod + def _save_instance(cls, + instance: 'GlobalAvailMixin', + key: Union[str, nn.Module] = 'default', + group: str = 'default') -> None: + """Save the instance. + + Args: + instance (GlobalAvailMixin): Instance to save. + key (Union[str, nn.Module], optional): Key to save the instance. + Defaults to 'default'. + group (str, optional): Group to save the instance. + Defaults to 'default'. + """ + if group not in cls._instances: + assert isinstance(group, str) + cls._instances[group] = {} + + cls._instances[group][key] = instance + + @classmethod + def find(cls, + key: Union[str, nn.Module] = 'default', + group: str = 'default') -> Union[None, 'GlobalAvailMixin']: + """Find an instance by its key and group. + + Args: + key (Union[str, nn.Module], optional): Key of the instance. + Defaults to 'default'. + group (str, optional): Group of the instance. + Defaults to 'default'. + + Returns: + Union[None, GlobalAvailMixin]: The found instance, or None if + it does not exist. + """ + return cls._instances.get(group, {}).get(key) + + @classmethod + def find_group( + cls, + group: str) -> Dict[Union[str, nn.Module], 'GlobalAvailMixin']: + """Find all instances in a group. + + Args: + group (str): Group of the instances. + + Returns: + Dict[Union[str, nn.Module], GlobalAvailMixin]: All instances in + the group. + """ + return cls._instances.get(group, {}) + + @classmethod + def instances( + cls) -> Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']]: + """Get all instances.""" + return cls._instances + + +class KVCacheObserver(GlobalAvailMixin): + """A class to observe and record the max, min, and absolute max value of + given tensor.""" + + def __init__(self, num_head: int, head_dim: int) -> None: + """Constructor for KVCacheObserver. + + Args: + num_head : Number of heads + head_dim : Dimension of each head + """ + self.num_head = num_head + self.head_dim = head_dim + self.max_val = torch.full((num_head, head_dim), + -torch.inf, + dtype=torch.float16) + self.min_val = torch.full((num_head, head_dim), + torch.inf, + dtype=torch.float16) + self.absmax_val = torch.full((num_head, head_dim), + 0, + dtype=torch.float16) + + @torch.no_grad() + def observe(self, x: torch.Tensor) -> None: + """Function to observe the input tensor and update the max, min, and + absolute max values. 
+ + Args: + x : Input tensor + """ + assert len(x.shape) == 4 + + if x.size(2) == self.num_head and x.size(3) == self.head_dim: + # layout: (bs, seqlen, heads, dims) + x = x + elif x.size(1) == self.num_head and x.size(3) == self.head_dim: + # layout: (bs, heads, seqlen, dims) + x = x.transpose(1, 2) + else: + raise RuntimeError + + cur_max = x.flatten(0, 1).max(0)[0].cpu() + cur_min = x.flatten(0, 1).min(0)[0].cpu() + cur_absmax = x.flatten(0, 1).abs().max(0)[0].cpu() + + self.max_val = torch.maximum(self.max_val, cur_max) + self.min_val = torch.minimum(self.min_val, cur_min) + self.absmax_val = torch.maximum(self.absmax_val, cur_absmax) + + +class ActivationObserver(GlobalAvailMixin): + """A class to observe and record the max, min, mean, absolute max, and + absolute mean value of a given tensor. + + Also keeps track of the number of batches observed. + """ + + def __init__(self, dim: int) -> None: + """Constructor for ActivationObserver. + + Args: + dim : Dimension of the tensor + """ + self.dim = dim + self.max_val = torch.full((dim, ), -torch.inf, dtype=torch.float16) + self.min_val = torch.full((dim, ), torch.inf, dtype=torch.float16) + self.absmax_val = torch.full((dim, ), 0, dtype=torch.float16) + self.absmean_val = torch.full((dim, ), 0, dtype=torch.float16) + self.mean_val = torch.full((dim, ), 0, dtype=torch.float16) + self.num_batches_tracked = 0 + + @torch.no_grad() + def observe(self, x: torch.Tensor) -> None: + """Function to observe the input tensor and update the max, min, mean, + absolute max, absolute mean values and number of batches tracked. + + Args: + x : Input tensor + """ + assert len(x.shape) == 3 + assert x.size(2) == self.dim + cur_val = x.flatten(0, 1) + cur_max = cur_val.max(0)[0].cpu() + cur_min = cur_val.min(0)[0].cpu() + cur_mean = cur_val.mean(0).cpu() + + cur_abs = cur_val.abs() + cur_absmax = cur_abs.max(0)[0].cpu() + cur_absmean = cur_abs.mean(0).cpu() + + self.max_val = torch.maximum(self.max_val, cur_max) + self.min_val = torch.minimum(self.min_val, cur_min) + self.absmax_val = torch.maximum(self.absmax_val, cur_absmax) + + # Update mean and absmean value with accumulated sum divided + # by total number of batches + self.mean_val = ( + (self.mean_val * self.num_batches_tracked + cur_mean) / + (self.num_batches_tracked + 1)) + self.absmean_val = ( + (self.absmean_val * self.num_batches_tracked + cur_absmean) / + (self.num_batches_tracked + 1)) + + # Increment the count of batches tracked + self.num_batches_tracked += 1 diff --git a/vllm/kv_quant/utils.py b/vllm/kv_quant/utils.py new file mode 100644 index 000000000000..309c48e3c213 --- /dev/null +++ b/vllm/kv_quant/utils.py @@ -0,0 +1,164 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List, Tuple, Union +import torch +from torch import nn + + +def split_decoder_layer_inputs( + *args: Union[torch.Tensor, Any], **kwargs: Union[torch.Tensor, Any] +) -> Tuple[List[List[Any]], List[Dict[str, Any]]]: + """This function splits batched decoder layer inputs into individual + elements. + + Args: + *args (Union[torch.Tensor, Any]): Positional arguments which could + be a mix of tensors and other types. + **kwargs (Union[torch.Tensor, Any]): Keyword arguments which could + be a mix of tensors and other types. + + Returns: + Tuple[List[List[Any]], List[Dict[str, Any]]]: A tuple containing two + lists, one for positional arguments, one for keyword arguments. + Each list contains individual elements from the batch. 
+ """ + + if not isinstance(args[0], torch.Tensor): + raise ValueError('The first argument must be a Tensor') + + bs = args[0].size(0) + + batch_args = [] + batch_kwargs = [] + for i in range(bs): + new_args = [] + # Iterate over each argument. If it's a torch.Tensor and its first + # dimension equals the batch size, then get the value corresponding + # to the current index, else directly add the whole value. + for val in args: + if isinstance(val, torch.Tensor) and val.size(0) == bs: + new_args.append(val[i:i + 1]) + else: + new_args.append(val) + + new_kwargs = {} + # Execute the same operation for the keyword arguments. + for name, val in kwargs.items(): + if isinstance(val, torch.Tensor) and val.size(0) == bs: + new_kwargs[name] = val[i:i + 1] + else: + new_kwargs[name] = val + + batch_args.append(new_args) + batch_kwargs.append(new_kwargs) + + return batch_args, batch_kwargs + + +def concat_decoder_layer_outputs( + batch_outputs: List[Tuple[Any]]) -> Tuple[Any]: + """This function concatenates individual decoder layer outputs into a + batched output. + + Args: + batch_outputs (List[Tuple[Any]]): A list of tuples, where each tuple + represents the output from an individual element in the batch. + + Returns: + Tuple[Any]: A tuple representing the batched output. + """ + + num_returns = len(batch_outputs[0]) + + def is_past_key_value(data: Any) -> bool: + """Check whether data is a past key-value pair. + + Args: + data (Any): The data to check. + + Returns: + bool: True if data is a past key-value pair, False otherwise. + """ + flag = isinstance(data, tuple) + flag = flag and len(data) == 2 + flag = flag and isinstance(data[0], torch.Tensor) + flag = flag and isinstance(data[1], torch.Tensor) + return flag + + new_outputs = [] + + # Iterate over all types of return values. + for i in range(num_returns): + # Check if the current element is a past key-value pair. + flag = is_past_key_value(batch_outputs[0][i]) + if flag: + # Concatenate the keys and values separately. + key = torch.cat([out[i][0] for out in batch_outputs]) + value = torch.cat([out[i][1] for out in batch_outputs]) + out_i = (key, value) + else: + # If it's not a past key-value pair, concatenate directly. + out_i = torch.cat([out[i] for out in batch_outputs]) + new_outputs.append(out_i) + + return tuple(new_outputs) + + +def collect_target_modules(model: nn.Module, + # target: Union[str, type], + target: str, + skip_names: List[str] = [], + prefix: str = '') -> Dict[str, nn.Module]: + """Collects the specific target modules from the model. + + Args: + model : The PyTorch module from which to collect the target modules. + target : The specific target to be collected. It can be a class of a + module or the name of a module. + skip_names : List of names of modules to be skipped during collection. + prefix : A string to be added as a prefix to the module names. + + Returns: + A dictionary mapping from module names to module instances. 
+ """ + + # if isinstance(target, LazyAttr): + # target = target.build() + + if not isinstance(target, (type, str)): + raise TypeError('Target must be a string (name of the module) ' + 'or a type (class of the module)') + + def _is_target(n, m): + if isinstance(target, str): + return target == type(m).__name__ and n not in skip_names + return isinstance(m, target) and n not in skip_names + + name2mod = {} + for name, mod in model.named_modules(): + m_name = f'{prefix}.{name}' if prefix else name + if _is_target(name, mod): + name2mod[m_name] = mod + return name2mod + + +def bimap_name_mod( + name2mod_mappings: List[Dict[str, nn.Module]] +) -> Tuple[Dict[str, nn.Module], Dict[nn.Module, str]]: + """Generates bidirectional maps from module names to module instances and + vice versa. + + Args: + name2mod_mappings : List of dictionaries each mapping from module + names to module instances. + + Returns: + Two dictionaries providing bidirectional mappings between module + names and module instances. + """ + + name2mod = {} + mod2name = {} + for mapping in name2mod_mappings: + mod2name.update({v: k for k, v in mapping.items()}) + name2mod.update(mapping) + return name2mod, mod2name diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 36fc30f9c1e3..e1c687aa5aef 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -5,5 +5,5 @@ __all__ = [ "InputMetadata", "get_model", - "set_random_seed", + "set_random_seed" ] diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 9222fe27218c..1d3b87bca27a 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -21,7 +21,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device) activation_ops.silu_and_mul(out, x) return out + +class DequantSiluAndMulQuant(nn.Module): + """An activation function for SwiGLU. + + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2. 
+ + Shapes: + x: (num_tokens, 2 * d) + return: (num_tokens, d) + """ + def __init__(self, scale_in: float = 1.0, scale_out: float = 1.0) -> None: + super().__init__() + self.register_buffer('a', torch.tensor(scale_in, dtype=torch.float32, requires_grad=False)) + self.register_buffer('inscale', torch.tensor(scale_out, dtype=torch.float32, requires_grad=False)) + + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + self.inscale = self.inscale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + self.inscale = self.inscale.to(*args, **kwargs) + self.inscale = self.inscale.to(torch.float32) + return self + def forward(self, x: torch.Tensor) -> torch.Tensor: + num_tokens = x.shape[0] + d = x.shape[1] // 2 + out = torch.empty(num_tokens, d, dtype=torch.int8, device=x.device) + activation_ops.invoke_dequant_silu_and_mul_quant(out, x, self.a.item(), self.a.item(), self.inscale.item()) + return out class NewGELU(nn.Module): diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 5e9360a3c20e..c82df29e1af4 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -10,6 +10,7 @@ from vllm import attention_ops from vllm import cache_ops from vllm import pos_encoding_ops +from vllm import fused_kernels from vllm.model_executor.input_metadata import InputMetadata _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] @@ -56,7 +57,9 @@ def __init__(self, num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None) -> None: + num_kv_heads: Optional[int] = None, + quant_kv_cache: bool = False, + kv_quant_params: List[float] = None) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size @@ -65,6 +68,8 @@ def __init__(self, assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.quant_kv_cache = quant_kv_cache + self.kv_quant_params = kv_quant_params self.head_mapping = torch.repeat_interleave( torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"), self.num_queries_per_kv) @@ -144,19 +149,35 @@ def single_query_cached_kv_attention( input_metadata: metadata for paged attention. """ block_size = value_cache.shape[3] - attention_ops.single_query_cached_kv_attention( - output, - query, - key_cache, - value_cache, - self.head_mapping, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - None, # alibi_slopes - ) + if self.quant_kv_cache: + attention_ops.single_query_cached_kv_quantized_attention( + output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + None, # alibi_slopes + *self.kv_quant_params, + ) + else: + attention_ops.single_query_cached_kv_attention( + output, + query, + key_cache, + value_cache, + self.head_mapping, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + None, # alibi_slopes + ) def forward( self, @@ -221,13 +242,23 @@ def forward( if (num_valid_tokens > 0 and key_cache is not None and value_cache is not None): # The stride is 3 because the key and value are sliced from qkv. 
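The kv_quant_params unpacked into the quantized kernels here are the per-layer values exported above. A rough numpy sketch of the write/read round trip they stand for, assuming the default asymmetric export (the file holds just [k_scale, v_scale] when exported with kv_sym=True, and kernel internals may differ in detail):

import numpy as np

# layers.{i}.past_kv_scale.{rank}.weight holds [k_scale, k_zp, v_scale, v_zp].
k_scale, k_zp, v_scale, v_zp = np.fromfile(
    "kv_params/layers.0.past_kv_scale.0.weight", dtype=np.float32)

key = np.random.randn(16, 128).astype(np.float32)            # a toy key tile
key_i8 = np.clip(np.round((key - k_zp) / k_scale), -128, 127).astype(np.int8)
key_restored = key_i8.astype(np.float32) * k_scale + k_zp     # what attention reads back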
- cache_ops.reshape_and_cache( - key[:num_valid_tokens], - value[:num_valid_tokens], - key_cache, - value_cache, - input_metadata.slot_mapping, - ) + if self.quant_kv_cache: + cache_ops.reshape_and_cache_quantized( + key[:num_valid_tokens], + value[:num_valid_tokens], + key_cache, + value_cache, + input_metadata.slot_mapping, + *self.kv_quant_params, + ) + else: + cache_ops.reshape_and_cache( + key[:num_valid_tokens], + value[:num_valid_tokens], + key_cache, + value_cache, + input_metadata.slot_mapping, + ) if input_metadata.num_generation_tokens > 0: # Decoding run. @@ -259,8 +290,10 @@ def __init__( base: int = 10000, num_kv_heads: Optional[int] = None, is_neox_style: bool = True, + quant_kv_cache: bool = False, + kv_quant_params: torch.Tensor = None, ) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads) + super().__init__(num_heads, head_size, scale, num_kv_heads, quant_kv_cache, kv_quant_params) self.is_neox_style = is_neox_style # Create the cos and sin cache. @@ -330,6 +363,122 @@ def forward( ) +class DequantPagedAttentionWithRoPEQuant(PagedAttention): + """PagedAttention with rotary embedding.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + rotary_dim: int, + max_position: int = 8192, + base: int = 10000, + num_kv_heads: Optional[int] = None, + is_neox_style: bool = True, + quant_kv_cache: bool = False, + kv_quant_params: torch.Tensor = None, + dequant_scale: float = 1.0, + quant_scale: float = 1.0 + ) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, quant_kv_cache, kv_quant_params) + self.is_neox_style = is_neox_style + + # Create the cos and sin cache. + inv_freq = 1.0 / (base**(torch.arange( + 0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim)) + t = torch.arange(max_position, dtype=torch.float, device="cuda") + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + + # FIXME(woosuk): This assumes that we configure the default dtype when + # initializing the model. + # TODO(woosuk): Make it more robust. + torch_dtype = torch.get_default_dtype() + cache = cache.to(torch_dtype) + # Embedding size: [max_position, rotary_dim] + self.register_buffer("cos_sin_cache", cache, persistent=False) + self.register_buffer('a', torch.tensor(dequant_scale, dtype=torch.float32, requires_grad=False)) + self.register_buffer('inscale', torch.tensor(quant_scale, dtype=torch.float32, requires_grad=False)) + + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + self.inscale = self.inscale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + self.inscale = self.inscale.to(*args, **kwargs) + self.inscale = self.inscale.to(torch.float32) + return self + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + """ PagedAttention forward pass with rotary embedding. 
+ + Args: + positions: shape = [num_tokens] + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for paged attention. + cache_event: event to wait for the cache operations to finish. + + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + # Apply rotary embedding to the query and key before passing them + # to the attention op. + query_dequant = torch.empty_like(query, dtype=self.cos_sin_cache.dtype) + key_dequant = torch.empty_like(key, dtype=self.cos_sin_cache.dtype) + value_dequant = torch.empty_like(value, dtype=self.cos_sin_cache.dtype) + + fused_kernels.invoke_dequant(value_dequant, value, self.a.item()) + pos_encoding_ops.invoke_dequant_rotary_embedding( + positions, + query, + query_dequant, + key, + key_dequant, + self.head_size, + self.cos_sin_cache, + self.a.item(), + self.a.item(), + self.is_neox_style, + ) + out = super().forward( + query_dequant, + key_dequant, + value_dequant, + key_cache, + value_cache, + input_metadata, + cache_event, + ) + quant_out = torch.empty_like(out, dtype=torch.int8) + fused_kernels.invoke_quant(quant_out, out, self.inscale.item()) + return quant_out + + class PagedAttentionWithALiBi(PagedAttention): """PagedAttention with ALiBi attention bias.""" diff --git a/vllm/model_executor/layers/fusion.py b/vllm/model_executor/layers/fusion.py new file mode 100644 index 000000000000..11832e3d8738 --- /dev/null +++ b/vllm/model_executor/layers/fusion.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + +from vllm import fused_kernels + + +class DequantAddResidual(nn.Module): + def __init__(self, scale: float = 1.0) -> None: + super().__init__() + self.register_buffer( + "a", torch.tensor(scale, dtype=torch.float32, requires_grad=False) + ) + + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + return self + + def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(residual) + fused_kernels.invoke_dequant_add_residual(out, x, residual, self.a.item()) + return out \ No newline at end of file diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 731bc7cbf53f..abe8a9844ec9 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -30,3 +30,61 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.variance_epsilon, ) return out + +class I8RMSNorm(nn.Module): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. + Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x, dtype=torch.int8) + layernorm_ops.invoke_rms_norm_quant(out, x, self.weight.data, self.variance_epsilon) + return out + + +class DequantAddResidualI8RMSNormQuant(nn.Module): + """Root mean square normalization. 
+ + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. + Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + scale: float = 1.0, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.register_buffer( + "a", torch.tensor(scale, dtype=torch.float32, requires_grad=False) + ) + self.variance_epsilon = eps + + def _apply(self, fn): + super()._apply(fn) + self.a = self.a.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.a = self.a.to(*args, **kwargs) + self.a = self.a.to(torch.float32) + return self + + def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x, dtype=torch.int8) + layernorm_ops.invoke_dequant_add_residual_rms_norm_quant(out, x, residual, self.weight.data, self.variance_epsilon, self.a.item()) + return residual, out diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index bcb9a54e7a2c..e9781f1f73c6 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -1,10 +1,14 @@ from vllm.model_executor.layers.quantized_linear.awq import ( AWQColumnParallelLinear, AWQRowParallelLinear) +from vllm.model_executor.layers.quantized_linear.smoothquant import ( + SQColumnParallelLinear, SQRowParallelLinear) from vllm.model_executor.parallel_utils.tensor_parallel import ( ColumnParallelLinear, RowParallelLinear) + _QUANTIZED_LINEAR_REGISTRY = { "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), + "smoothquant": (SQColumnParallelLinear, SQRowParallelLinear) } diff --git a/vllm/model_executor/layers/quantized_linear/smoothquant.py b/vllm/model_executor/layers/quantized_linear/smoothquant.py new file mode 100644 index 000000000000..8ee3e5238d16 --- /dev/null +++ b/vllm/model_executor/layers/quantized_linear/smoothquant.py @@ -0,0 +1,56 @@ +from typing import Optional + +import torch +from torch.nn.parameter import Parameter +from vllm.model_executor.parallel_utils.tensor_parallel.layers import ( + ColumnParallelLinear, RowParallelLinear) +from vllm.i8cugemm import I8CUGEMM +i8cugemm = I8CUGEMM() + +class SQColumnParallelLinear(ColumnParallelLinear): + + def create_weights(self, dtype: torch.dtype) -> None: + assert self.input_size % self.quant_config.weight_bits == 0 + self.register_buffer('weight', + torch.randint(-127, + 127, + (self.output_size_per_partition, + self.input_size), + dtype=torch.int8, + requires_grad=False)) + + def apply_weights( + self, + x: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + + + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = torch.empty((x.shape[0], self.output_size_per_partition), dtype=torch.int32, device=x.device) + i8cugemm.linear_a8_w8_o32_(x, self.weight, y) + y = y.view(*x_shape[:-1], -1) + return y + +class SQRowParallelLinear(RowParallelLinear): + + def create_weights(self, dtype: torch.dtype) -> None: + assert (self.input_size_per_partition % + self.quant_config.weight_bits == 0) + self.register_buffer('weight', + torch.randint(-127, + 127, + (self.output_size, + self.input_size_per_partition), + dtype=torch.int8, + requires_grad=False)) + + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = torch.empty((x.shape[0], self.output_size), dtype=torch.int32, device=x.device) + i8cugemm.linear_a8_w8_o32_(x, 
self.weight, y) + y = y.view(*x_shape[:-1], -1) + return y \ No newline at end of file diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 526b4f8b5c87..51b64c7bedaa 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -2,11 +2,12 @@ import contextlib from typing import Type +import numpy as np import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig +from vllm.config import ModelConfig, ParallelConfig from vllm.model_executor.models import * # pylint: disable=wildcard-import from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -25,6 +26,8 @@ "InternLMForCausalLM": InternLMForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* + # "LlamaForCausalLM": LlamaQForCausalLM, + # "LLaMAForCausalLM": LlamaQForCausalLM, # For decapoda-research/llama-* "MPTForCausalLM": MPTForCausalLM, "OPTForCausalLM": OPTForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, @@ -56,7 +59,9 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") -def get_model(model_config: ModelConfig) -> nn.Module: +def get_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + rank: int) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the quantization config. @@ -86,10 +91,17 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. + num_layers = model_config.get_num_layers(parallel_config) + kv_quant_params_list = [] + if model_config.quant_kv_cache: + for i in range(num_layers): + path = model_config.kv_quant_params_path + f"/layers.{i}.past_kv_scale.{rank}.weight" + kv_quant_params = list(np.fromfile(path, dtype=np.float32)) + kv_quant_params_list.append(kv_quant_params) if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION: - model = model_class(model_config.hf_config, quant_config) + model = model_class(model_config.hf_config, quant_config, model_config.quant_kv_cache, kv_quant_params_list) else: - model = model_class(model_config.hf_config) + model = model_class(model_config.hf_config, None, model_config.quant_kv_cache, kv_quant_params_list) if model_config.load_format == "dummy": model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f20e5d8e6f20..4481e85fee5d 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -9,6 +9,7 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM from vllm.model_executor.models.internlm import InternLMForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM +# from vllm.model_executor.models.llamaq import LlamaQForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel @@ -25,6 +26,7 @@ "GPTNeoXForCausalLM", "InternLMForCausalLM", "LlamaForCausalLM", + # "LlamaQForCausalLM", "MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0b7f4181a150..995974e8ac9a 100644 --- 
a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py
@@ -32,11 +32,12 @@ from transformers import LlamaConfig from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.activation import SiluAndMul, DequantSiluAndMulQuant +from vllm.model_executor.layers.layernorm import RMSNorm, I8RMSNorm, DequantAddResidualI8RMSNormQuant +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, DequantPagedAttentionWithRoPEQuant from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.layers.fusion import DequantAddResidual from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -60,25 +61,33 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() + self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" + + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + + self.gate_up_proj = ParallelLinear.column(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - perform_initialization=False, - quant_config=quant_config) + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=quant_config) self.down_proj = ParallelLinear.row(intermediate_size, hidden_size, bias=False, input_is_parallel=True, perform_initialization=False, - quant_config=quant_config) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() + quant_config=quant_config) + # kernel fusion for int8 inference + if self.use_int8: + self.act_fn = DequantSiluAndMulQuant() + else: + self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) + # FIXME: gate_proj and up_proj currently share the same scale; plan to use separate scales x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x
@@ -93,6 +102,8 @@ def __init__( num_kv_heads: int, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, + quant_kv_cache: bool = False, + kv_quant_params: List[float] = None ) -> None: super().__init__() self.hidden_size = hidden_size
@@ -108,6 +119,7 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta + self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" self.qkv_proj = ParallelLinear.column( hidden_size,
@@ -126,12 +138,26 @@ def __init__( perform_initialization=False, quant_config=quant_config, ) - self.attn = PagedAttentionWithRoPE(self.num_heads, - self.head_dim, - self.scaling, - base=self.rope_theta, - rotary_dim=self.head_dim, - num_kv_heads=self.num_kv_heads) + + # kernel fusion for int8 inference + if self.use_int8: + self.attn = DequantPagedAttentionWithRoPEQuant(self.num_heads, + self.head_dim, + self.scaling, + base=self.rope_theta, + rotary_dim=self.head_dim, + num_kv_heads=self.num_kv_heads, + quant_kv_cache=quant_kv_cache, + kv_quant_params=kv_quant_params) + else: + self.attn = PagedAttentionWithRoPE(self.num_heads, + self.head_dim, + self.scaling, + base=self.rope_theta, + rotary_dim=self.head_dim, + num_kv_heads=self.num_kv_heads, + quant_kv_cache=quant_kv_cache, + kv_quant_params=kv_quant_params) def forward( self,
@@ -142,11 +168,14 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) + # qkv = qkv.half() + # FIXME: q_proj, k_proj and v_proj currently share the same scale; plan to use separate scales q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn(positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) output, _ = self.o_proj(attn_output) + # output = output.half() return output
@@ -156,9 +185,12 @@ def __init__( self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + quant_kv_cache: bool = False, + kv_quant_params: List[float] = None ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.use_int8 = quant_config is not None and quant_config.get_name() == "smoothquant" # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = LlamaAttention(
@@ -167,6 +199,8 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, quant_config=quant_config, + quant_kv_cache=quant_kv_cache, + kv_quant_params=kv_quant_params ) self.mlp = LlamaMLP( hidden_size=self.hidden_size,
@@ -174,10 +208,18 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + if self.use_int8: + self.input_layernorm = I8RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + # kernel fusion: post_attention_layernorm is fused into DequantAddResidualI8RMSNormQuant + self.dequant_add_residual_layernorm_quant = 
DequantAddResidualI8RMSNormQuant(config.hidden_size, + eps=config.rms_norm_eps) + self.dequant_add_residual = DequantAddResidual() + else: + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -197,13 +239,19 @@ def forward( input_metadata=input_metadata, cache_event=cache_event, ) - hidden_states = residual + hidden_states - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states + if self.use_int8: + residual, hidden_states = self.dequant_add_residual_layernorm_quant(residual, hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.dequant_add_residual(residual, hidden_states) + else: + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states @@ -213,6 +261,8 @@ def __init__( self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + quant_kv_cache: bool = False, + kv_quant_params_list: List[List[float]] = None ) -> None: super().__init__() self.config = config @@ -222,9 +272,10 @@ def __init__( vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( vocab_size, config.hidden_size, perform_initialization=False) + self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config, quant_config, quant_kv_cache, kv_quant_params_list[i] if quant_kv_cache else None) + for i in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -260,11 +311,13 @@ def __init__( self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + quant_kv_cache: bool = False, + kv_quant_params_list: List[List[float]] = None ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = LlamaModel(config, quant_config) + self.model = LlamaModel(config, quant_config, quant_kv_cache, kv_quant_params_list) vocab_size = ((config.vocab_size + 63) // 64) * 64 # NOTE: The LM head is not quantized. 
self.lm_head = ParallelLinear.column(config.hidden_size, @@ -291,12 +344,28 @@ def forward( _column_parallel_layers = [] _row_parallel_layers = ["o_proj", "down_proj"] + _int8_scale_params = { + "self_attn.q_proj.a": "self_attn.attn.a", + "self_attn.k_proj.a": "self_attn.attn.a", + "self_attn.v_proj.a": "self_attn.attn.a", + "self_attn.o_proj.inscale": "self_attn.attn.inscale", + "self_attn.o_proj.a": "dequant_add_residual_layernorm_quant.a", + "post_attention_layernorm.weight": "dequant_add_residual_layernorm_quant.weight", + "mlp.gate_proj.a": "mlp.act_fn.a", + "mlp.up_proj.a": "mlp.act_fn.a", + "mlp.down_proj.inscale": "mlp.act_fn.inscale", + "mlp.down_proj.a": "dequant_add_residual.a" + } def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): + int8_fusion = False + if self.quant_config is not None and self.quant_config.get_name() == "smoothquant": + int8_fusion = True + if self.quant_config is None: weight_suffixes = ["weight"] else: @@ -330,6 +399,9 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue + # bias is useless for llama + if "bias" in name: + continue is_packed = False is_transposed = False @@ -340,6 +412,17 @@ def load_weights(self, loaded_weight = convert_pyslice_to_tensor(loaded_weight) loaded_weight = loaded_weight.T + if int8_fusion: + is_fusion_weight = False + for weight_name in self._int8_scale_params.keys(): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, self._int8_scale_params[weight_name])] + param.copy_(loaded_weight) + is_fusion_weight = True + if is_fusion_weight: + continue + is_attention_weight = False for weight_name, shard_size, offset in attention_weight_specs: if weight_name not in name: @@ -397,4 +480,4 @@ def load_weights(self, load_tensor_parallel_weights(param, loaded_weight, name, column_parallel_weights, row_parallel_weights, - tensor_model_parallel_rank) + tensor_model_parallel_rank) \ No newline at end of file diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py index df67758f7110..5fb6547ffeb0 100644 --- a/vllm/model_executor/quantization_utils/__init__.py +++ b/vllm/model_executor/quantization_utils/__init__.py @@ -1,10 +1,12 @@ from typing import Type from vllm.model_executor.quantization_utils.awq import AWQConfig +from vllm.model_executor.quantization_utils.smoothquant import SmoothQuantConfig from vllm.model_executor.quantization_utils.base import QuantizationConfig _QUANTIZATION_REGISTRY = { "awq": AWQConfig, + "smoothquant": SmoothQuantConfig, } diff --git a/vllm/model_executor/quantization_utils/smoothquant.py b/vllm/model_executor/quantization_utils/smoothquant.py new file mode 100644 index 000000000000..1b4f64a94573 --- /dev/null +++ b/vllm/model_executor/quantization_utils/smoothquant.py @@ -0,0 +1,73 @@ +from typing import Any, Dict, List + +import torch + +from vllm.model_executor.quantization_utils.base import QuantizationConfig + + +class SmoothQuantConfig(QuantizationConfig): + """Config class for SmoothQuant + + Reference: https://github.com/mit-han-lab/smoothquant + """ + + def __init__( + self, + weight_bits: int = 8, + quant_type: str = "tensor" + ) -> None: + self.weight_bits = weight_bits + self.quant_type = quant_type + + if self.weight_bits != 8: + raise ValueError( + "Currently, only w8a8 quantization is supported for " + f"SmoothQuant, but got 
{self.weight_bits} bits.") + if self.quant_type != "tensor": + raise ValueError( + "Currently, only tensor wise quantization is supported for " + f"SmoothQuant, but got {self.quant_type} type quantization.") + + def __repr__(self) -> str: + return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, " + f"quant_type={self.quant_type})") + + @classmethod + def get_name(cls) -> str: + return "smoothquant" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.float] + + @classmethod + def get_min_capability(cls) -> int: + # The smoothquant kernel only supports Ampere or newer GPUs. + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + quant_type = cls.get_from_keys(config, ["quant_type", "q_type"]) + return cls(weight_bits, quant_type) + + @classmethod + def get_packed_tensor_names(cls) -> List[str]: + return [] + + @classmethod + def get_transposed_tensor_names(cls) -> List[str]: + return [] + + @classmethod + def get_tp_tensor_names(cls) -> List[str]: + return ["weight"] + diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 3d5a723d9d42..2f3fd3237042 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -34,7 +34,8 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) self.num_heads = model_config.get_num_heads(parallel_config) - self.dtype = model_config.dtype + ## for kv cache quantization + self.dtype = model_config.kv_cache_dtype self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -152,7 +153,7 @@ def get_cache_block_size( key_cache_block = block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) - dtype_size = _get_dtype_size(model_config.dtype) + dtype_size = _get_dtype_size(model_config.kv_cache_dtype) return dtype_size * total diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2d2021d9fe95..321f352ab0ad 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -64,7 +64,7 @@ def init_model(self): # Initialize the model. set_random_seed(self.model_config.seed) - self.model = get_model(self.model_config) + self.model = get_model(self.model_config, self.parallel_config, self.rank) @torch.inference_mode() def profile_num_available_blocks(
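Note on the per-layer KV-cache scales consumed by get_model() above: for each layer i and tensor-parallel rank, a raw float32 binary file named layers.{i}.past_kv_scale.{rank}.weight is expected under kv_quant_params_path and read back with np.fromfile. A minimal sketch of producing such files follows; the two-value [k_scale, v_scale] layout and the calibration numbers are assumptions for illustration, not something this diff fixes.

import os
import numpy as np

def dump_kv_quant_params(out_dir: str, num_layers: int, rank: int = 0) -> None:
    # Write one raw float32 file per layer using the naming scheme that
    # get_model() reconstructs and reads with np.fromfile(path, dtype=np.float32).
    os.makedirs(out_dir, exist_ok=True)
    for i in range(num_layers):
        kv_scales = np.array([0.05, 0.04], dtype=np.float32)  # placeholder calibration values
        kv_scales.tofile(os.path.join(out_dir, f"layers.{i}.past_kv_scale.{rank}.weight"))

dump_kv_quant_params("kv_params", num_layers=32)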
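For readers of the int8 branch in LlamaDecoderLayer.forward: DequantAddResidualI8RMSNormQuant is assumed to dequantize the int8 projection output, add it to the floating-point residual, apply RMSNorm, and re-quantize the normalized result for the next int8 GEMM, returning both the updated residual and the int8 activation. The unfused reference below uses per-tensor scales; the function name and the explicit out_scale argument are illustrative, and the fused CUDA kernel in this diff may handle rounding and scale bookkeeping differently.

import torch

def dequant_add_residual_rmsnorm_quant(residual, x_int8, a, weight, out_scale, eps=1e-6):
    # Dequantize the int8 input with scale `a` and fold it into the residual stream.
    new_residual = residual + x_int8.to(residual.dtype) * a
    # RMSNorm over the last dimension, mirroring the non-quantized path.
    variance = new_residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (new_residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype) * weight
    # Re-quantize per tensor for the next int8 GEMM.
    out_int8 = torch.clamp(torch.round(normed / out_scale), -128, 127).to(torch.int8)
    return new_residual, out_int8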
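Since CacheEngine now sizes blocks from model_config.kv_cache_dtype, an int8 KV cache halves the per-block memory relative to fp16. A quick sanity check of the get_cache_block_size arithmetic with LLaMA-7B-like shapes (32 layers, 32 heads, head size 128, block size 16; illustrative numbers only):

block_size, num_heads, head_size, num_layers = 16, 32, 128, 32
key_cache_block = block_size * num_heads * head_size        # elements per layer
value_cache_block = key_cache_block
total_elements = num_layers * (key_cache_block + value_cache_block)
print(total_elements * 2 / 2**20)  # fp16 (2-byte) KV cache: 8.0 MiB per block
print(total_elements * 1 / 2**20)  # int8 (1-byte) KV cache: 4.0 MiB per block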