From a4c78043a3b9496ec2ca4d995e6b2f6fbe8680d9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 2 Oct 2023 05:00:38 +0900 Subject: [PATCH 01/10] Add vllm kernels --- CMakeLists.txt | 1 + cmake/modules/CUDA.cmake | 2 + cmake/modules/contrib/vllm.cmake | 24 + .../contrib/vllm/attention_generic.cuh | 64 +++ src/runtime/contrib/vllm/attention_kernels.cu | 476 ++++++++++++++++++ src/runtime/contrib/vllm/attention_utils.cuh | 55 ++ src/runtime/contrib/vllm/cache_alloc.cc | 55 ++ src/runtime/contrib/vllm/cache_kernels.cu | 91 ++++ src/runtime/contrib/vllm/dtype_float16.cuh | 444 ++++++++++++++++ src/runtime/contrib/vllm/dtype_float32.cuh | 268 ++++++++++ tests/python/relax/test_contrib_vllm.py | 161 ++++++ 11 files changed, 1641 insertions(+) create mode 100644 cmake/modules/contrib/vllm.cmake create mode 100644 src/runtime/contrib/vllm/attention_generic.cuh create mode 100644 src/runtime/contrib/vllm/attention_kernels.cu create mode 100644 src/runtime/contrib/vllm/attention_utils.cuh create mode 100644 src/runtime/contrib/vllm/cache_alloc.cc create mode 100644 src/runtime/contrib/vllm/cache_kernels.cu create mode 100644 src/runtime/contrib/vllm/dtype_float16.cuh create mode 100644 src/runtime/contrib/vllm/dtype_float32.cuh create mode 100644 tests/python/relax/test_contrib_vllm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index ff1d51ce121f..5faad6d619b8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -571,6 +571,7 @@ include(cmake/modules/contrib/VitisAI.cmake) include(cmake/modules/contrib/Verilator.cmake) include(cmake/modules/contrib/UMA.cmake) include(cmake/modules/contrib/MSC.cmake) +include(cmake/modules/contrib/vllm.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) include(cmake/modules/RustExt.cmake) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index ce561c66a6a2..55ba93f6cb28 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -38,6 +38,8 @@ if(USE_CUDA) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_NVRTC_LIBRARY}) + set(CMAKE_CUDA_ARCHITECTURES "86;80") + if(USE_CUDNN) message(STATUS "Build with cuDNN support") include_directories(SYSTEM ${CUDA_CUDNN_INCLUDE_DIRS}) diff --git a/cmake/modules/contrib/vllm.cmake b/cmake/modules/contrib/vllm.cmake new file mode 100644 index 000000000000..b27184079926 --- /dev/null +++ b/cmake/modules/contrib/vllm.cmake @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
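+# NOTE: it is assumed that this module is enabled by adding `set(USE_VLLM ON)` to config.cmake
+# (or passing -DUSE_VLLM=ON), and that USE_CUDA is also enabled, since the sources below are CUDA.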
+
+if(USE_VLLM)
+  message(STATUS "Build with vllm paged attention kernel.")
+  include_directories(src/runtime/contrib/vllm)
+  set(CMAKE_CUDA_ARCHITECTURES 80) # without this, cmake tries to compile with compute_52
+  tvm_file_glob(GLOB VLLM_CONTRIB_SRC src/runtime/contrib/vllm/*.cu src/runtime/contrib/vllm/*.cc)
+  list(APPEND RUNTIME_SRCS ${VLLM_CONTRIB_SRC})
+endif(USE_VLLM)
diff --git a/src/runtime/contrib/vllm/attention_generic.cuh b/src/runtime/contrib/vllm/attention_generic.cuh
new file mode 100644
index 000000000000..31fb401cbe2c
--- /dev/null
+++ b/src/runtime/contrib/vllm/attention_generic.cuh
@@ -0,0 +1,64 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template<typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template<typename T>
+struct FloatVec {};
+
+// Template vector operations.
+template<typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+template<typename T>
+inline __device__ float sum(T v);
+
+template<typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<T, T, T>(a, b));
+}
+
+template<typename A, typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+template<typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ++ii) {
+    tmp.words[ii] = 0u;
+  }
+  dst = tmp.raw;
+}
+
+} // namespace vllm
diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu
new file mode 100644
index 000000000000..b38acd6c102b
--- /dev/null
+++ b/src/runtime/contrib/vllm/attention_kernels.cu
@@ -0,0 +1,476 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "dtype_float16.cuh"
+#include "attention_utils.cuh"
+
+#include <tvm/runtime/registry.h>
+
+#include <algorithm>
+#include <cfloat>
+#include <cmath>
+
+#define WARP_SIZE 32
+#define MAX(a, b) ((a) > (b) ?
(a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void single_query_cached_kv_attention_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int kv_head_idx = head_mapping[head_idx]; + const int seq_idx = blockIdx.y; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. 
+ // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? 
qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + + const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + if (block_idx == num_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the context, + // we should explicitly zero out the values since they may contain NaNs. + // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j <= V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. 
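+  // Lanes that processed different value vectors of the same head row (there are
+  // NUM_V_VECS_PER_ROW of them) combine their partial sums via shuffles.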
+#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +} // namespace vllm + +namespace tvm { +namespace runtime { + +#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + cudaFuncSetAttribute( \ + vllm::single_query_cached_kv_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::single_query_cached_kv_attention_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); + + +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void single_query_cached_kv_attention_launcher( + DLTensor* out, + const DLTensor* query, + const DLTensor* key_cache, + const DLTensor* value_cache, + const DLTensor* head_mapping, + float scale, + const DLTensor* block_tables, + const DLTensor* context_lens, + int max_context_len) { + int num_seqs = query->shape[0]; + int num_heads = query->shape[1]; + int head_size = query->shape[2]; + int max_num_blocks_per_seq = block_tables->shape[1]; + int q_stride = query->shape[1] * query->shape[2]; + + int kv_head_stride = key_cache->shape[2] * key_cache->shape[3] * key_cache->shape[4]; + int kv_block_stride = kv_head_stride * key_cache->shape[1]; + + // int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + // assert(head_size % thread_group_size == 0); + const float* alibi_slopes_ptr = nullptr; + + T* out_ptr = static_cast(out->data); + T* query_ptr = static_cast(query->data); + T* key_cache_ptr = static_cast(key_cache->data); + T* value_cache_ptr = static_cast(value_cache->data); + int* head_mapping_ptr = static_cast(head_mapping->data); + int* block_tables_ptr = static_cast(block_tables->data); + int* context_lens_ptr = 
static_cast(context_lens->data); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + switch (head_size) { + case 64: + LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 112: + LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + case 256: + LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + // TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_KERNEL_LAUNCHER(BLOCK_SIZE) \ + single_query_cached_kv_attention_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len); + +TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") + .set_body_typed([](const DLTensor* query, + const DLTensor* key_cache, + const DLTensor* value_cache, + const DLTensor* head_mapping, + const DLTensor* block_tables, + const DLTensor* context_lens, + int block_size, + const DLTensor* max_context_len_tensor, // TODO: pass integer + DLTensor* out) { + float scale = 1.0 / sqrt(query->shape[2]); + int max_context_len = ((int*)max_context_len_tensor->data)[0]; + + if (block_size == 8) { + CALL_KERNEL_LAUNCHER(8); + } else if (block_size == 16) { + CALL_KERNEL_LAUNCHER(16); + } else if (block_size == 32) { + CALL_KERNEL_LAUNCHER(32); + } else { + // TORCH_CHECK(false, "Unsupported block size: ", block_size); + } + }); +} // namespace runtime +} // namespace tvm + +#undef WARP_SIZE +#undef MAX +#undef MIN diff --git a/src/runtime/contrib/vllm/attention_utils.cuh b/src/runtime/contrib/vllm/attention_utils.cuh new file mode 100644 index 000000000000..8855eb279324 --- /dev/null +++ b/src/runtime/contrib/vllm/attention_utils.cuh @@ -0,0 +1,55 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "dtype_float16.cuh" + +#include +#include + +namespace vllm { + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). 
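+  // Partial products are accumulated in the wider FloatVec type; the per-lane results are then
+  // reduced across the THREAD_GROUP_SIZE lanes by the shuffle loop below.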
+ A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +} // namespace vllm diff --git a/src/runtime/contrib/vllm/cache_alloc.cc b/src/runtime/contrib/vllm/cache_alloc.cc new file mode 100644 index 000000000000..aea50aa47a5c --- /dev/null +++ b/src/runtime/contrib/vllm/cache_alloc.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace vllm { + +Array AllocateKVCache(int head_size, int num_layers, int num_heads, int block_size, + int num_blocks) { + Array cache; + int element_size = 2; + int vec_size = 16 / element_size; + + int device_id; + cudaGetDevice(&device_id); + + DLDevice dev{DLDeviceType::kDLCUDA, device_id}; + + for (int i = 0; i < num_layers; ++i) { + NDArray key_blocks = + NDArray::Empty({num_blocks, num_heads, head_size / vec_size, block_size, vec_size}, + runtime::DataType::Float(16), dev); + NDArray value_blocks = NDArray::Empty({num_blocks, num_heads, head_size, block_size}, + runtime::DataType::Float(16), dev); + cache.push_back(key_blocks); + cache.push_back(value_blocks); + } + + return cache; +} + +TVM_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache); + +} // namespace vllm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu new file mode 100644 index 000000000000..be1625362044 --- /dev/null +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -0,0 +1,91 @@ +#include +#include +#include +#include + +#include +#include +#include + +namespace vllm { + +template +__global__ void reshape_and_cache_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x) { + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int block_idx = slot_idx / block_size; + const int 
block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int src_key_idx = token_idx * key_stride + i; + const int src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int tgt_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; + key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); + value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); + } +} + +} // namespace vllm + +namespace tvm { +namespace runtime { + +TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") + .set_body_typed([](NDArray key, NDArray value, NDArray key_cache, + NDArray value_cache, NDArray slot_mapping) { + int num_tokens = key->shape[0]; + int num_heads = key->shape[1]; + int head_size = key->shape[2]; + int block_size = key_cache->shape[3]; + int vec_size = key_cache->shape[4]; + + int key_stride = key->shape[1] * key->shape[2]; + int value_stride = value->shape[1] * value->shape[2]; + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + + using scalar_t = uint16_t; + vllm::reshape_and_cache_kernel<<>>( + static_cast(key->data), + static_cast(value->data), + static_cast(key_cache->data), + static_cast(value_cache->data), + static_cast(slot_mapping->data), + key_stride, + value_stride, + num_heads, + head_size, + block_size, + vec_size); + + return Array({key_cache, value_cache}); + }); +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/vllm/dtype_float16.cuh b/src/runtime/contrib/vllm/dtype_float16.cuh new file mode 100644 index 000000000000..e67921128d52 --- /dev/null +++ b/src/runtime/contrib/vllm/dtype_float16.cuh @@ -0,0 +1,444 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#include + +namespace vllm { + +// FP16 vector types for Q, K, V. 
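+// Each specialization packs N half values (N = 1, 2, 4, 8) into uint16_t, uint32_t, uint2 and
+// uint4 respectively, so FP16 loads and stores can be issued as single wide accesses.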
+template<> +struct Vec { + using Type = uint16_t; +}; +template<> +struct Vec { + using Type = uint32_t; +}; +template<> +struct Vec { + using Type = uint2; +}; +template<> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = Float4_; +}; +template<> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. 
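+// The specializations below multiply packed FP16 data: half2 pairs use the mul.f16x2 PTX
+// instruction, and the mixed scalar/vector overloads broadcast the uint16_t operand to both
+// lanes with h0_h0 first. Variants returning float types convert to FP32 for accumulation.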
+template<> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +template<> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +template<> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template<> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template<> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template<> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template<> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template<> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template<> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template<> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template<> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template<> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template<> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template<> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. 
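+// fma(a, b, c) computes a * b + c element-wise on packed half pairs (fma.rn.f16x2); overloads
+// taking a uint16_t scalar broadcast it with h0_h0, and the float-accumulator overloads
+// convert to FP32 before the multiply-add.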
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template<> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template<> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template<> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template<> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. 
+inline __device__ float to_float(uint16_t u) { + return half_to_float(u); +} + +inline __device__ float2 to_float(uint32_t u) { + return half2_to_float2(u); +} + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +// Zero-out a variable. +inline __device__ void zero(uint16_t& dst) { + dst = uint16_t(0); +} + +} // namespace vllm diff --git a/src/runtime/contrib/vllm/dtype_float32.cuh b/src/runtime/contrib/vllm/dtype_float32.cuh new file mode 100644 index 000000000000..960cf48e2643 --- /dev/null +++ b/src/runtime/contrib/vllm/dtype_float32.cuh @@ -0,0 +1,268 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" + +#include + +namespace vllm { + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template<> +struct Vec { + using Type = float; +}; +template<> +struct Vec { + using Type = float2; +}; +template<> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { + return a + b; +} + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. 
+template<> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template<> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template<> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template<> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template<> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template<> +inline __device__ float sum(float v) { + return v; +} + +template<> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template<> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template<> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template<> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { + return a * b; +} + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { + dst = src; +} + +inline __device__ void from_float(float2& dst, float2 src) { + dst = src; +} + +inline __device__ void from_float(float4& dst, float4 src) { + dst = src; +} + +// From float to float. 
+inline __device__ float to_float(float u) { + return u; +} + +inline __device__ float2 to_float(float2 u) { + return u; +} + +inline __device__ float4 to_float(float4 u) { + return u; +} + +inline __device__ Float4_ to_float(Float4_ u) { + return u; +} + +inline __device__ Float8_ to_float(Float8_ u) { + return u; +} + +} // namespace vllm diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py new file mode 100644 index 000000000000..bde07178392b --- /dev/null +++ b/tests/python/relax/test_contrib_vllm.py @@ -0,0 +1,161 @@ +import numpy as np + +import torch + +import tvm +from tvm import relax +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def to_torch(arr): + return torch.from_numpy(arr).to("cuda") + + +def build_and_run(mod, inputs_np, target, legalize=True): + if legalize: + mod = relax.transform.LegalizeOps()(mod) + + with tvm.target.Target("cuda"): + mod = tvm.tir.transform.DefaultGPUSchedule()(mod) + + with tvm.transform.PassContext(): + ex = relax.build(mod, target) + + dev = tvm.device(target, 0) + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + + out = f(*inputs) + + if isinstance(out, tvm.ir.container.Array): + return [arr.numpy() for arr in out] + + return out.numpy() + + +def test_attention(): + @I.ir_module + class Module: + @R.function + def main( + query: R.Tensor(("num_seqs", 12, 64), dtype="float16"), + key_cache: R.Tensor(("num_blocks", 12, 8, 16, 8), dtype="float16"), + value_cache: R.Tensor(("num_blocks", 12, 64, 16), dtype="float16"), + head_mapping: R.Tensor((12,), dtype="int32"), + block_tables: R.Tensor(("num_seqs", "max_num_blocks_per_seq"), dtype="int32"), + context_lens: R.Tensor(("num_seqs",), dtype="int32"), + ) -> R.Tensor(("num_seqs", 12, 64), dtype="float16"): + with R.dataflow(): + max_len = R.max(context_lens) + out = R.call_dps_packed( + "tvm.contrib.vllm.single_query_cached_kv_attention", + [ + query, + key_cache, + value_cache, + head_mapping, + block_tables, + context_lens, + 16, + max_len, + ], + out_sinfo=query.struct_info, + ) + R.output(out) + return out + + query = np.load("vllm_attention_inputs/query.npy") + key_cache = np.load("vllm_attention_inputs/key_cache.npy") + value_cache = np.load("vllm_attention_inputs/value_cache.npy") + block_tables = np.load("vllm_attention_inputs/block_tables.npy") + head_mapping = np.load("vllm_attention_inputs/head_mapping.npy") + context_lens = np.load("vllm_attention_inputs/context_lens.npy") + + out = build_and_run( + Module, + [query, key_cache, value_cache, head_mapping, block_tables, context_lens], + "cuda", + legalize=True, + ) + + ref = to_torch(np.zeros_like(query)) + + from vllm import attention_ops + + attention_ops.single_query_cached_kv_attention( + ref, + to_torch(query), + to_torch(key_cache), + to_torch(value_cache), + to_torch(head_mapping), + query.shape[-1] ** -0.5, # scale + to_torch(block_tables), + to_torch(context_lens), + value_cache.shape[-1], # block_size, + np.max(context_lens), + None, + ) + + assert np.max(np.abs(ref.cpu().numpy() - out)) == 0.0 + + +def test_cache(): + @I.ir_module + class Module: + @R.function + def main( + key: R.Tensor(("num_tokens", 12, 64), dtype="float16"), + value: R.Tensor(("num_tokens", 12, 64), dtype="float16"), + key_cache: R.Tensor(("num_blocks", 12, 8, 16, 8), dtype="float16"), + value_cache: R.Tensor(("num_blocks", 12, 64, 16), dtype="float16"), + slot_mapping: R.Tensor(("num_tokens",), 
dtype="int32"), + ) -> R.Tuple( + [ + R.Tensor(("num_blocks", 12, 8, 16, 8), dtype="float16"), + R.Tensor(("num_blocks", 12, 64, 16), dtype="float16"), + ] + ): + with R.dataflow(): + kv = R.call_pure_packed( + "tvm.contrib.vllm.reshape_and_cache", + key, value, key_cache, value_cache, slot_mapping, + sinfo_args=[key_cache.struct_info, value_cache.struct_info] + ) + out = (kv[0], kv[1]) + R.output(out) + return out + + key = np.load("vllm_cache_inputs/key_to_cache.npy") + value = np.load("vllm_cache_inputs/value_to_cache.npy") + + key_cache = np.load("vllm_cache_inputs/key_cache_before.npy") + value_cache = np.load("vllm_cache_inputs/value_cache_before.npy") + slot_mapping = np.load("vllm_cache_inputs/slot_mapping.npy") + + out_key_cache, out_value_cache = build_and_run( + Module, + [key, value, key_cache, value_cache, slot_mapping], + "cuda", + ) + + from vllm import cache_ops + + key_cache = to_torch(np.load("vllm_cache_inputs/key_cache_before.npy")) + value_cache = to_torch(np.load("vllm_cache_inputs/value_cache_before.npy")) + + cache_ops.reshape_and_cache( + to_torch(key), + to_torch(value), + key_cache, + value_cache, + to_torch(slot_mapping), + ) + + assert np.max(np.abs(out_key_cache - key_cache.cpu().numpy())) == 0 + assert np.max(np.abs(out_value_cache - value_cache.cpu().numpy())) == 0 + + +test_cache() From c301728020eb19ab663f45429e4e3f46c3d5c9cf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Oct 2023 11:21:19 +0000 Subject: [PATCH 02/10] add license --- licenses/LICENSE.vllm.txt | 201 ++++++++++++++++++++++ src/runtime/contrib/vllm/cache_kernels.cu | 19 ++ tests/python/relax/test_contrib_vllm.py | 18 +- 3 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 licenses/LICENSE.vllm.txt diff --git a/licenses/LICENSE.vllm.txt b/licenses/LICENSE.vllm.txt new file mode 100644 index 000000000000..261eeb9e9f8b --- /dev/null +++ b/licenses/LICENSE.vllm.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index be1625362044..2eab6d4b07d7 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + #include #include #include diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py index bde07178392b..59007403f4c5 100644 --- a/tests/python/relax/test_contrib_vllm.py +++ b/tests/python/relax/test_contrib_vllm.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import numpy as np import torch @@ -6,7 +23,6 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T def to_torch(arr): From 0e54b1012c9e081d97abd6b6edf9d359f6b03914 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Oct 2023 18:02:26 +0000 Subject: [PATCH 03/10] clean --- .../contrib/vllm/attention_generic.cuh | 64 ---- src/runtime/contrib/vllm/attention_kernels.cu | 35 ++- src/runtime/contrib/vllm/attention_utils.cuh | 55 ---- src/runtime/contrib/vllm/cache_kernels.cu | 2 +- src/runtime/contrib/vllm/dtype_float16.cuh | 284 +++++++++++++++++- src/runtime/contrib/vllm/dtype_float32.cuh | 268 ----------------- 6 files changed, 314 insertions(+), 394 deletions(-) delete mode 100644 src/runtime/contrib/vllm/attention_generic.cuh delete mode 100644 src/runtime/contrib/vllm/attention_utils.cuh delete mode 100644 src/runtime/contrib/vllm/dtype_float32.cuh diff --git a/src/runtime/contrib/vllm/attention_generic.cuh b/src/runtime/contrib/vllm/attention_generic.cuh deleted file mode 100644 index 31fb401cbe2c..000000000000 --- a/src/runtime/contrib/vllm/attention_generic.cuh +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -namespace vllm { - -// A vector type to store Q, K, V elements. -template -struct Vec {}; - -// A vector type to store FP32 accumulators. -template -struct FloatVec {}; - -// Template vector operations. 
-template -inline __device__ Acc mul(A a, B b); - -template -inline __device__ float sum(T v); - -template -inline __device__ float dot(T a, T b) { - return sum(mul(a, b)); -} - -template -inline __device__ float dot(T a, T b) { - return sum(mul(a, b)); -} - -template -inline __device__ void zero(T& dst) { - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; - -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - -} // namespace vllm diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index b38acd6c102b..657dfa0cb2bf 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -16,9 +16,10 @@ * limitations under the License. */ #include "dtype_float16.cuh" -#include "attention_utils.cuh" #include +#include +#include #include #include @@ -30,6 +31,34 @@ namespace vllm { +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + // Utility function for attention softmax. template inline __device__ float block_sum(float* red_smem, float sum) { @@ -453,7 +482,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") const DLTensor* block_tables, const DLTensor* context_lens, int block_size, - const DLTensor* max_context_len_tensor, // TODO: pass integer + const DLTensor* max_context_len_tensor, // TODO(masahi): pass integer DLTensor* out) { float scale = 1.0 / sqrt(query->shape[2]); int max_context_len = ((int*)max_context_len_tensor->data)[0]; @@ -465,7 +494,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") } else if (block_size == 32) { CALL_KERNEL_LAUNCHER(32); } else { - // TORCH_CHECK(false, "Unsupported block size: ", block_size); + LOG(FATAL) << "Unsupported block size: " << block_size; } }); } // namespace runtime diff --git a/src/runtime/contrib/vllm/attention_utils.cuh b/src/runtime/contrib/vllm/attention_utils.cuh deleted file mode 100644 index 8855eb279324..000000000000 --- a/src/runtime/contrib/vllm/attention_utils.cuh +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "dtype_float16.cuh" - -#include -#include - -namespace vllm { - -// Q*K^T operation. -template -inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { - using A_vec = typename FloatVec::Type; - // Compute the parallel products for Q*K^T (treat vector lanes separately). - A_vec qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. - float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -template -struct Qk_dot { - template - static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { - return qk_dot_(q, k); - } -}; - -} // namespace vllm diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index 2eab6d4b07d7..e6fa5879d253 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -104,7 +104,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") block_size, vec_size); - return Array({key_cache, value_cache}); + return Array{key_cache, value_cache}; }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/vllm/dtype_float16.cuh b/src/runtime/contrib/vllm/dtype_float16.cuh index e67921128d52..2c02aa346a2a 100644 --- a/src/runtime/contrib/vllm/dtype_float16.cuh +++ b/src/runtime/contrib/vllm/dtype_float16.cuh @@ -18,13 +18,291 @@ */ #pragma once -#include "attention_generic.cuh" -#include "dtype_float32.cuh" - #include namespace vllm { +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + +// Template vector operations. +template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template<> +struct Vec { + using Type = float; +}; +template<> +struct Vec { + using Type = float2; +}; +template<> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. 
+inline __device__ float add(float a, float b) { + return a + b; +} + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. +template<> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template<> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template<> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template<> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template<> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template<> +inline __device__ float sum(float v) { + return v; +} + +template<> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template<> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template<> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template<> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { + return a * b; +} + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { + dst = src; +} + +inline __device__ void from_float(float2& dst, float2 src) { + dst = src; +} + +inline __device__ void from_float(float4& dst, float4 src) { + dst = src; +} + +// From float to float. 
+inline __device__ float to_float(float u) { + return u; +} + +inline __device__ float2 to_float(float2 u) { + return u; +} + +inline __device__ float4 to_float(float4 u) { + return u; +} + +inline __device__ Float4_ to_float(Float4_ u) { + return u; +} + +inline __device__ Float8_ to_float(Float8_ u) { + return u; +} + // FP16 vector types for Q, K, V. template<> struct Vec { diff --git a/src/runtime/contrib/vllm/dtype_float32.cuh b/src/runtime/contrib/vllm/dtype_float32.cuh deleted file mode 100644 index 960cf48e2643..000000000000 --- a/src/runtime/contrib/vllm/dtype_float32.cuh +++ /dev/null @@ -1,268 +0,0 @@ -/* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "attention_generic.cuh" - -#include - -namespace vllm { - -// Define custom FP32 vector data types. -struct Float4_ { - float2 x; - float2 y; -}; - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -// FP32 vector types for Q, K, V. -template<> -struct Vec { - using Type = float; -}; -template<> -struct Vec { - using Type = float2; -}; -template<> -struct Vec { - using Type = float4; -}; - -// FP32 accumulator vector types corresponding to Vec. -template<> -struct FloatVec { - using Type = float; -}; -template<> -struct FloatVec { - using Type = float2; -}; -template<> -struct FloatVec { - using Type = float4; -}; - -// Vector addition. -inline __device__ float add(float a, float b) { - return a + b; -} - -inline __device__ float2 add(float2 a, float2 b) { - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -inline __device__ float4 add(float4 a, float4 b) { - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -// Vector multiplication. -template<> -inline __device__ float mul(float a, float b) { - return a * b; -} - -template<> -inline __device__ float2 mul(float2 a, float2 b) { - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -template<> -inline __device__ float2 mul(float a, float2 b) { - float2 c; - c.x = a * b.x; - c.y = a * b.y; - return c; -} - -template<> -inline __device__ float4 mul(float4 a, float4 b) { - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -template<> -inline __device__ float4 mul(float a, float4 b) { - float4 c; - c.x = a * b.x; - c.y = a * b.y; - c.z = a * b.z; - c.w = a * b.w; - return c; -} - -// Vector fused multiply-add. 
-inline __device__ float fma(float a, float b, float c) { - return a * b + c; -} - -inline __device__ float2 fma(float2 a, float2 b, float2 c) { - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -inline __device__ float2 fma(float a, float2 b, float2 c) { - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -inline __device__ float4 fma(float4 a, float4 b, float4 c) { - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -inline __device__ float4 fma(float a, float4 b, float4 c) { - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { - Float4_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -// Vector sum. -template<> -inline __device__ float sum(float v) { - return v; -} - -template<> -inline __device__ float sum(float2 v) { - return v.x + v.y; -} - -template<> -inline __device__ float sum(float4 v) { - return v.x + v.y + v.z + v.w; -} - -template<> -inline __device__ float sum(Float4_ v) { - return v.x.x + v.x.y + v.y.x + v.y.y; -} - -template<> -inline __device__ float sum(Float8_ v) { - return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; -} - -// Vector dot product. -inline __device__ float dot(float a, float b) { - return a * b; -} - -inline __device__ float dot(float2 a, float2 b) { - float2 c = mul(a, b); - return c.x + c.y; -} - -inline __device__ float dot(Float4_ a, Float4_ b) { - float2 acc = mul(a.x, b.x); - acc = fma(a.y, b.y, acc); - return acc.x + acc.y; -} - -inline __device__ float dot(Float8_ a, Float8_ b) { - float2 acc = mul(a.x, b.x); - acc = fma(a.y, b.y, acc); - acc = fma(a.z, b.z, acc); - acc = fma(a.w, b.w, acc); - return acc.x + acc.y; -} - -// From float to float. -inline __device__ void from_float(float& dst, float src) { - dst = src; -} - -inline __device__ void from_float(float2& dst, float2 src) { - dst = src; -} - -inline __device__ void from_float(float4& dst, float4 src) { - dst = src; -} - -// From float to float. 
-inline __device__ float to_float(float u) { - return u; -} - -inline __device__ float2 to_float(float2 u) { - return u; -} - -inline __device__ float4 to_float(float4 u) { - return u; -} - -inline __device__ Float4_ to_float(Float4_ u) { - return u; -} - -inline __device__ Float8_ to_float(Float8_ u) { - return u; -} - -} // namespace vllm From 4817f3eb45c96f688afc6fe1efd33e26e91d73a7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 Oct 2023 03:37:59 +0900 Subject: [PATCH 04/10] fix cmake --- cmake/modules/CUDA.cmake | 3 +-- cmake/modules/contrib/vllm.cmake | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 55ba93f6cb28..d7c7f1bf530c 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -38,8 +38,6 @@ if(USE_CUDA) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_NVRTC_LIBRARY}) - set(CMAKE_CUDA_ARCHITECTURES "86;80") - if(USE_CUDNN) message(STATUS "Build with cuDNN support") include_directories(SYSTEM ${CUDA_CUDNN_INCLUDE_DIRS}) @@ -66,6 +64,7 @@ if(USE_CUDA) message(STATUS "Build with Thrust support") cmake_minimum_required(VERSION 3.13) # to compile CUDA code enable_language(CUDA) + set(CMAKE_CUDA_ARCHITECTURES "80;75") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu) list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) diff --git a/cmake/modules/contrib/vllm.cmake b/cmake/modules/contrib/vllm.cmake index b27184079926..551ea6c8db18 100644 --- a/cmake/modules/contrib/vllm.cmake +++ b/cmake/modules/contrib/vllm.cmake @@ -18,7 +18,8 @@ if(USE_VLLM) message(STATUS "Build with vllm paged attention kernel.") include_directories(src/runtime/contrib/vllm) - set(CMAKE_CUDA_ARCHITECTURES 80) # without this, cmake tries to compile with compute_52 + enable_language(CUDA) + set(CMAKE_CUDA_ARCHITECTURES "80;75") tvm_file_glob(GLOB VLLM_CONTRIB_SRC src/runtime/contrib/vllm/*.cu src/runtime/contrib/vllm/*.cc) list(APPEND RUNTIME_SRCS ${VLLM_CONTRIB_SRC}) endif(USE_VLLM) From 99f14e59fcd5cfe474d1d1c8498f9672c92228e5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Oct 2023 20:01:18 +0000 Subject: [PATCH 05/10] update test --- tests/python/relax/test_contrib_vllm.py | 642 +++++++++++++++++++++--- 1 file changed, 580 insertions(+), 62 deletions(-) diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py index 59007403f4c5..7ca60bbaca3a 100644 --- a/tests/python/relax/test_contrib_vllm.py +++ b/tests/python/relax/test_contrib_vllm.py @@ -14,19 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- import numpy as np +import pytest -import torch - +import tvm.testing import tvm from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -def to_torch(arr): - return torch.from_numpy(arr).to("cuda") +has_vllm = tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True) + +vllm_enabled = pytest.mark.skipif( + not has_vllm, + reason="VLLM not enabled.", +) + +pytestmark = [vllm_enabled] def build_and_run(mod, inputs_np, target, legalize=True): @@ -55,17 +60,25 @@ def build_and_run(mod, inputs_np, target, legalize=True): def test_attention(): @I.ir_module class Module: + I.module_global_infos( + { + "vdevice": [ + I.vdevice("llvm"), + ] + } + ) + @R.function def main( - query: R.Tensor(("num_seqs", 12, 64), dtype="float16"), - key_cache: R.Tensor(("num_blocks", 12, 8, 16, 8), dtype="float16"), - value_cache: R.Tensor(("num_blocks", 12, 64, 16), dtype="float16"), - head_mapping: R.Tensor((12,), dtype="int32"), + query: R.Tensor(("num_seqs", 1, 64), dtype="float16"), + key_cache: R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"), + value_cache: R.Tensor(("num_blocks", 1, 64, 16), dtype="float16"), + head_mapping: R.Tensor((1,), dtype="int32"), block_tables: R.Tensor(("num_seqs", "max_num_blocks_per_seq"), dtype="int32"), context_lens: R.Tensor(("num_seqs",), dtype="int32"), - ) -> R.Tensor(("num_seqs", 12, 64), dtype="float16"): + ) -> R.Tensor(("num_seqs", 1, 64), dtype="float16"): with R.dataflow(): - max_len = R.max(context_lens) + max_len = R.to_vdevice(R.max(context_lens), "llvm:0") out = R.call_dps_packed( "tvm.contrib.vllm.single_query_cached_kv_attention", [ @@ -83,12 +96,21 @@ def main( R.output(out) return out - query = np.load("vllm_attention_inputs/query.npy") - key_cache = np.load("vllm_attention_inputs/key_cache.npy") - value_cache = np.load("vllm_attention_inputs/value_cache.npy") - block_tables = np.load("vllm_attention_inputs/block_tables.npy") - head_mapping = np.load("vllm_attention_inputs/head_mapping.npy") - context_lens = np.load("vllm_attention_inputs/context_lens.npy") + np.random.seed(0) + num_heads = 1 + head_dim = 64 + vec_size = 8 + block_size = 16 + num_seqs = 2 + num_blocks = 1 + query = np.random.randn(num_seqs, num_heads, head_dim).astype("float16") + key_cache = np.random.randn( + num_blocks, num_heads, head_dim // vec_size, block_size, vec_size + ).astype("float16") + value_cache = np.random.randn(num_blocks, num_heads, head_dim, block_size).astype("float16") + block_tables = np.array([[0], [0]]).astype("int32") + head_mapping = np.array([0]).astype("int32") + context_lens = np.array([3, 5]).astype("int32") out = build_and_run( Module, @@ -97,25 +119,172 @@ def main( legalize=True, ) - ref = to_torch(np.zeros_like(query)) - - from vllm import attention_ops - - attention_ops.single_query_cached_kv_attention( - ref, - to_torch(query), - to_torch(key_cache), - to_torch(value_cache), - to_torch(head_mapping), - query.shape[-1] ** -0.5, # scale - to_torch(block_tables), - to_torch(context_lens), - value_cache.shape[-1], # block_size, - np.max(context_lens), - None, - ) - - assert np.max(np.abs(ref.cpu().numpy() - out)) == 0.0 + ref = np.array( + [ + [ + [ + 0.28271484375, + 0.197021484375, + -0.278564453125, + 0.444580078125, + -0.47802734375, + -0.7548828125, + -0.84228515625, + -0.80322265625, + 0.478759765625, + 0.195068359375, + -0.59521484375, + 0.779296875, + 0.35888671875, + -0.158935546875, + -0.6103515625, + 0.188720703125, + 0.410400390625, + 0.28662109375, + 0.40283203125, + -1.23046875, + 
-0.01043701171875, + -0.0794677734375, + -0.0350341796875, + 0.12005615234375, + 0.63671875, + 0.368896484375, + -0.58642578125, + -0.34228515625, + -0.552734375, + 0.947265625, + -0.079833984375, + 0.85302734375, + 0.1947021484375, + 0.16748046875, + -0.083984375, + -0.75244140625, + -0.568359375, + -1.45703125, + -1.021484375, + -0.2235107421875, + -0.98828125, + -0.87109375, + -0.43359375, + -0.3271484375, + 0.0557861328125, + -0.269287109375, + -1.009765625, + 0.1387939453125, + -0.0831298828125, + 0.27978515625, + -0.9736328125, + 0.7802734375, + -0.1329345703125, + -0.5927734375, + -1.6923828125, + 1.1904296875, + -1.3759765625, + -1.080078125, + -0.53173828125, + 0.28466796875, + -2.02734375, + -0.377685546875, + -0.81201171875, + -0.7412109375, + ] + ], + [ + [ + 0.482177734375, + 0.114501953125, + -0.265869140625, + -1.154296875, + 0.28857421875, + 0.71240234375, + -1.1767578125, + 0.187744140625, + -0.23486328125, + 0.07135009765625, + -0.34521484375, + 0.444091796875, + -0.09130859375, + 0.900390625, + -0.043701171875, + 0.61279296875, + 0.1201171875, + 0.443603515625, + -0.4150390625, + -0.9560546875, + -0.1917724609375, + 0.0863037109375, + 0.267578125, + 0.04931640625, + -0.32666015625, + 0.5859375, + -0.57421875, + 0.29541015625, + -0.26220703125, + 1.177734375, + 0.11309814453125, + 0.81201171875, + 0.346435546875, + 0.53271484375, + -0.0009765625, + -0.35205078125, + -0.1298828125, + -1.2431640625, + -0.2196044921875, + 0.31640625, + -0.40869140625, + 0.25244140625, + -0.9853515625, + 0.284912109375, + 0.399169921875, + -1.1435546875, + 0.305419921875, + 0.300048828125, + -0.84521484375, + -0.5166015625, + -0.787109375, + 0.1011962890625, + -1.0302734375, + -1.35546875, + -0.0556640625, + 1.0791015625, + -0.047607421875, + -0.498046875, + -0.055999755859375, + -0.35009765625, + -1.4296875, + 0.350341796875, + -1.16796875, + -0.576171875, + ] + ], + ] + ).astype("float16") + + # from vllm import attention_ops + # import torch + # + # def to_torch(arr): + # return torch.from_numpy(arr).to("cuda") + # + # ref = to_torch(np.zeros_like(query)) + # attention_ops.single_query_cached_kv_attention( + # ref, + # to_torch(query), + # to_torch(key_cache), + # to_torch(value_cache), + # to_torch(head_mapping), + # query.shape[-1] ** -0.5, # scale + # to_torch(block_tables), + # to_torch(context_lens), + # value_cache.shape[-1], # block_size, + # np.max(context_lens), + # None, + # ) + # ref = ref.cpu().numpy() + + # print(ref.tolist()) + + assert np.max(np.abs(ref - out)) == 0.0 def test_cache(): @@ -123,33 +292,50 @@ def test_cache(): class Module: @R.function def main( - key: R.Tensor(("num_tokens", 12, 64), dtype="float16"), - value: R.Tensor(("num_tokens", 12, 64), dtype="float16"), - key_cache: R.Tensor(("num_blocks", 12, 8, 16, 8), dtype="float16"), - value_cache: R.Tensor(("num_blocks", 12, 64, 16), dtype="float16"), + key: R.Tensor(("num_tokens", 1, 8), dtype="float16"), + value: R.Tensor(("num_tokens", 1, 8), dtype="float16"), + key_cache: R.Tensor(("num_blocks", 1, 1, 16, 8), dtype="float16"), + value_cache: R.Tensor(("num_blocks", 1, 8, 16), dtype="float16"), slot_mapping: R.Tensor(("num_tokens",), dtype="int32"), ) -> R.Tuple( [ - R.Tensor(("num_blocks", 12, 8, 16, 8), dtype="float16"), - R.Tensor(("num_blocks", 12, 64, 16), dtype="float16"), + R.Tensor(("num_blocks", 1, 8, 16, 8), dtype="float16"), + R.Tensor(("num_blocks", 1, 8, 16), dtype="float16"), ] ): with R.dataflow(): kv = R.call_pure_packed( "tvm.contrib.vllm.reshape_and_cache", - key, value, key_cache, 
value_cache, slot_mapping, - sinfo_args=[key_cache.struct_info, value_cache.struct_info] + key, + value, + key_cache, + value_cache, + slot_mapping, + sinfo_args=[key_cache.struct_info, value_cache.struct_info], ) out = (kv[0], kv[1]) R.output(out) return out - key = np.load("vllm_cache_inputs/key_to_cache.npy") - value = np.load("vllm_cache_inputs/value_to_cache.npy") + np.random.seed(0) + num_heads = 1 + head_dim = 8 + vec_size = 8 + block_size = 16 + num_tokens = 8 + num_blocks = 1 + key = np.random.randn(num_tokens, num_heads, head_dim).astype("float16") + value = np.random.randn(num_tokens, num_heads, head_dim).astype("float16") + key_cache_before = np.random.randn( + num_blocks, num_heads, head_dim // vec_size, block_size, vec_size + ).astype("float16") + value_cache_before = np.random.randn(num_blocks, num_heads, head_dim, block_size).astype( + "float16" + ) + slot_mapping = np.arange(num_tokens).astype("int32") - key_cache = np.load("vllm_cache_inputs/key_cache_before.npy") - value_cache = np.load("vllm_cache_inputs/value_cache_before.npy") - slot_mapping = np.load("vllm_cache_inputs/slot_mapping.npy") + key_cache = key_cache_before.copy() + value_cache = value_cache_before.copy() out_key_cache, out_value_cache = build_and_run( Module, @@ -157,21 +343,353 @@ def main( "cuda", ) - from vllm import cache_ops + ref_key_cache = np.array( + [ + [ + [ + [ + [ + 1.763671875, + 0.400146484375, + 0.978515625, + 2.240234375, + 1.8671875, + -0.97705078125, + 0.9501953125, + -0.1513671875, + ], + [ + -0.10321044921875, + 0.41064453125, + 0.14404296875, + 1.4541015625, + 0.76123046875, + 0.1217041015625, + 0.44384765625, + 0.333740234375, + ], + [ + 1.494140625, + -0.2052001953125, + 0.31298828125, + -0.85400390625, + -2.552734375, + 0.65380859375, + 0.8642578125, + -0.7421875, + ], + [ + 2.26953125, + -1.4541015625, + 0.045745849609375, + -0.1871337890625, + 1.533203125, + 1.4697265625, + 0.1549072265625, + 0.378173828125, + ], + [ + -0.8876953125, + -1.98046875, + -0.347900390625, + 0.1563720703125, + 1.23046875, + 1.2021484375, + -0.38720703125, + -0.30224609375, + ], + [ + -1.048828125, + -1.419921875, + -1.7060546875, + 1.951171875, + -0.509765625, + -0.43798828125, + -1.2529296875, + 0.77734375, + ], + [ + -1.6142578125, + -0.2127685546875, + -0.8955078125, + 0.386962890625, + -0.5107421875, + -1.1806640625, + -0.0281829833984375, + 0.42822265625, + ], + [ + 0.0665283203125, + 0.302490234375, + -0.63427734375, + -0.36279296875, + -0.67236328125, + -0.359619140625, + -0.81298828125, + -1.7265625, + ], + [ + -0.039276123046875, + -1.16796875, + 0.5234375, + -0.1715087890625, + 0.77197265625, + 0.82373046875, + 2.1640625, + 1.3369140625, + ], + [ + -0.369140625, + -0.2393798828125, + 1.099609375, + 0.6552734375, + 0.64013671875, + -1.6171875, + -0.024322509765625, + -0.73779296875, + ], + [ + 0.280029296875, + -0.09814453125, + 0.91015625, + 0.317138671875, + 0.7861328125, + -0.46630859375, + -0.9443359375, + -0.41015625, + ], + [ + -0.0170135498046875, + 0.379150390625, + 2.259765625, + -0.042266845703125, + -0.9560546875, + -0.345947265625, + -0.463623046875, + 0.4814453125, + ], + [ + -1.541015625, + 0.063232421875, + 0.156494140625, + 0.232177734375, + -0.59716796875, + -0.2379150390625, + -1.423828125, + -0.493408203125, + ], + [ + -0.54296875, + 0.416015625, + -1.15625, + 0.78125, + 1.494140625, + -2.0703125, + 0.42626953125, + 0.6767578125, + ], + [ + -0.63720703125, + -0.397216796875, + -0.1329345703125, + -0.2978515625, + -0.30908203125, + -1.67578125, + 1.15234375, + 1.080078125, + 
], + [ + -0.8134765625, + -1.466796875, + 0.52099609375, + -0.57568359375, + 0.1419677734375, + -0.3193359375, + 0.69140625, + 0.69482421875, + ], + ] + ] + ] + ] + ).astype("float16") + + ref_value_cache = np.array( + [ + [ + [ + [ + 0.1773681640625, + 1.1396484375, + -1.1650390625, + -1.0703125, + 0.010498046875, + -1.1728515625, + -0.861328125, + 0.37646484375, + -1.9365234375, + 0.188720703125, + 0.52392578125, + 0.08843994140625, + -0.310791015625, + 0.097412109375, + 0.39892578125, + -2.7734375, + ], + [ + -0.40185546875, + -1.234375, + 0.90087890625, + 1.0546875, + 1.7861328125, + 1.943359375, + 1.91015625, + -1.099609375, + -0.11053466796875, + 1.0205078125, + -0.69189453125, + 1.5361328125, + 0.286376953125, + 0.60888671875, + -1.044921875, + 1.2109375, + ], + [ + -1.6298828125, + 0.40234375, + 0.465576171875, + -0.403076171875, + 0.126953125, + -0.41357421875, + -0.26806640625, + 0.29833984375, + 0.09771728515625, + 0.5830078125, + -0.3994140625, + 0.3701171875, + -1.306640625, + 1.658203125, + -0.1181640625, + -0.68017578125, + ], + [ + 0.462890625, + -0.6845703125, + -1.5361328125, + 1.22265625, + 0.402099609375, + -0.74755859375, + 0.80224609375, + 1.326171875, + -1.126953125, + -0.73046875, + -0.384765625, + 0.0943603515625, + -0.04217529296875, + -0.286865234375, + -0.061614990234375, + -0.1072998046875, + ], + [ + -0.9072265625, + -0.87060546875, + 1.48828125, + 0.208251953125, + 1.8828125, + 1.9228515625, + 0.947265625, + -0.6943359375, + -0.70458984375, + 0.943359375, + 0.7470703125, + -1.1884765625, + 0.7734375, + -1.18359375, + -2.658203125, + 0.6064453125, + ], + [ + 0.05194091796875, + -0.57861328125, + 1.8955078125, + 0.9765625, + -1.34765625, + 1.48046875, + -0.155029296875, + -0.149658203125, + -0.44091796875, + -0.2802734375, + -0.36474609375, + 0.15673828125, + 0.57861328125, + 0.349609375, + -0.76416015625, + -1.4375, + ], + [ + 0.72900390625, + -0.3115234375, + 1.1787109375, + 0.3564453125, + -1.2705078125, + 1.8671875, + 0.6142578125, + -0.43505859375, + 0.6982421875, + 0.0037708282470703125, + 0.931640625, + 0.33984375, + -0.01568603515625, + 0.160888671875, + -0.190673828125, + -0.394775390625, + ], + [ + 0.1290283203125, + 0.05615234375, + -0.179931640625, + 0.70654296875, + 0.96923828125, + 0.90625, + 0.92236328125, + 1.849609375, + 0.6435546875, + -1.5703125, + -0.2069091796875, + 0.88037109375, + -1.6982421875, + 0.38720703125, + -2.255859375, + -1.0224609375, + ], + ] + ] + ] + ).astype("float16") + + # from vllm import cache_ops + # import torch - key_cache = to_torch(np.load("vllm_cache_inputs/key_cache_before.npy")) - value_cache = to_torch(np.load("vllm_cache_inputs/value_cache_before.npy")) + # def to_torch(arr): + # return torch.from_numpy(arr).to("cuda") - cache_ops.reshape_and_cache( - to_torch(key), - to_torch(value), - key_cache, - value_cache, - to_torch(slot_mapping), - ) + # ref_key_cache = to_torch(key_cache_before.copy()) + # ref_value_cache = to_torch(value_cache_before.copy()) + + # cache_ops.reshape_and_cache( + # to_torch(key), + # to_torch(value), + # ref_key_cache, + # ref_value_cache, + # to_torch(slot_mapping), + # ) + + # ref_key_cache = ref_key_cache.cpu().numpy() + # ref_value_cache = ref_value_cache.cpu().numpy() - assert np.max(np.abs(out_key_cache - key_cache.cpu().numpy())) == 0 - assert np.max(np.abs(out_value_cache - value_cache.cpu().numpy())) == 0 + assert np.max(np.abs(out_key_cache - ref_key_cache)) == 0 + assert np.max(np.abs(out_value_cache - ref_value_cache)) == 0 -test_cache() +if __name__ == "__main__": + 
tvm.testing.main() From c3c6bdfa5fc8cd76827a6d58d7bcaa03d0ccbb7d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 Oct 2023 01:53:10 +0000 Subject: [PATCH 06/10] cuh -> h --- src/runtime/contrib/vllm/attention_kernels.cu | 2 +- src/runtime/contrib/vllm/{dtype_float16.cuh => dtype_float16.h} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/runtime/contrib/vllm/{dtype_float16.cuh => dtype_float16.h} (100%) diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index 657dfa0cb2bf..10a07fde6ac2 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dtype_float16.cuh" +#include "dtype_float16.h" #include #include diff --git a/src/runtime/contrib/vllm/dtype_float16.cuh b/src/runtime/contrib/vllm/dtype_float16.h similarity index 100% rename from src/runtime/contrib/vllm/dtype_float16.cuh rename to src/runtime/contrib/vllm/dtype_float16.h From 851c156b4149e98b25699ebbb55ecc5b2625328b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 Oct 2023 10:38:24 +0000 Subject: [PATCH 07/10] reduce lint errors --- src/runtime/contrib/vllm/attention_kernels.cu | 70 +++++++++---------- src/runtime/contrib/vllm/cache_kernels.cu | 19 +++-- src/runtime/contrib/vllm/dtype_float16.h | 2 +- 3 files changed, 45 insertions(+), 46 deletions(-) diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index 10a07fde6ac2..4ed1a64096ae 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -17,14 +17,14 @@ */ #include "dtype_float16.h" -#include -#include -#include - #include #include #include +#include +#include +#include + #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -111,12 +111,12 @@ __global__ void single_query_cached_kv_attention_kernel( const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride) { constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; @@ -158,7 +158,7 @@ __global__ void single_query_cached_kv_attention_kernel( const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs // Memory planning. 
extern __shared__ char shared_mem[]; @@ -368,7 +368,7 @@ __global__ void single_query_cached_kv_attention_kernel( } } -} // namespace vllm +} // namespace vllm namespace tvm { namespace runtime { @@ -400,14 +400,14 @@ template< int NUM_THREADS = 128> void single_query_cached_kv_attention_launcher( DLTensor* out, - const DLTensor* query, - const DLTensor* key_cache, - const DLTensor* value_cache, - const DLTensor* head_mapping, - float scale, - const DLTensor* block_tables, - const DLTensor* context_lens, - int max_context_len) { + const DLTensor* query, + const DLTensor* key_cache, + const DLTensor* value_cache, + const DLTensor* head_mapping, + float scale, + const DLTensor* block_tables, + const DLTensor* context_lens, + int max_context_len) { int num_seqs = query->shape[0]; int num_heads = query->shape[1]; int head_size = query->shape[2]; @@ -476,26 +476,26 @@ void single_query_cached_kv_attention_launcher( TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") .set_body_typed([](const DLTensor* query, - const DLTensor* key_cache, - const DLTensor* value_cache, - const DLTensor* head_mapping, - const DLTensor* block_tables, - const DLTensor* context_lens, - int block_size, - const DLTensor* max_context_len_tensor, // TODO(masahi): pass integer - DLTensor* out) { + const DLTensor* key_cache, + const DLTensor* value_cache, + const DLTensor* head_mapping, + const DLTensor* block_tables, + const DLTensor* context_lens, + int block_size, + const DLTensor* max_context_len_tensor, // TODO(masahi): pass integer + DLTensor* out) { float scale = 1.0 / sqrt(query->shape[2]); - int max_context_len = ((int*)max_context_len_tensor->data)[0]; - - if (block_size == 8) { - CALL_KERNEL_LAUNCHER(8); - } else if (block_size == 16) { - CALL_KERNEL_LAUNCHER(16); - } else if (block_size == 32) { - CALL_KERNEL_LAUNCHER(32); - } else { - LOG(FATAL) << "Unsupported block size: " << block_size; - } + int max_context_len = static_cast(max_context_len_tensor->data)[0]; + + if (block_size == 8) { + CALL_KERNEL_LAUNCHER(8); + } else if (block_size == 16) { + CALL_KERNEL_LAUNCHER(16); + } else if (block_size == 32) { + CALL_KERNEL_LAUNCHER(32); + } else { + LOG(FATAL) << "Unsupported block size: " << block_size; + } }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index e6fa5879d253..ef99a3647b82 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -16,16 +16,15 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include +#include +#include #include #include #include #include -#include -#include -#include - namespace vllm { template @@ -77,7 +76,7 @@ namespace runtime { TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") .set_body_typed([](NDArray key, NDArray value, NDArray key_cache, - NDArray value_cache, NDArray slot_mapping) { + NDArray value_cache, NDArray slot_mapping) { int num_tokens = key->shape[0]; int num_heads = key->shape[1]; int head_size = key->shape[2]; @@ -92,11 +91,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") using scalar_t = uint16_t; vllm::reshape_and_cache_kernel<<>>( - static_cast(key->data), - static_cast(value->data), - static_cast(key_cache->data), - static_cast(value_cache->data), - static_cast(slot_mapping->data), + static_cast(key->data), + static_cast(value->data), + static_cast(key_cache->data), + static_cast(value_cache->data), + static_cast(slot_mapping->data), key_stride, value_stride, num_heads, diff --git a/src/runtime/contrib/vllm/dtype_float16.h b/src/runtime/contrib/vllm/dtype_float16.h index 2c02aa346a2a..c978c68231e8 100644 --- a/src/runtime/contrib/vllm/dtype_float16.h +++ b/src/runtime/contrib/vllm/dtype_float16.h @@ -719,4 +719,4 @@ inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm From b6546f1715ebad2c01e129dd1d52f57196576aae Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 31 Oct 2023 14:55:50 +0900 Subject: [PATCH 08/10] more lint fix --- src/runtime/contrib/vllm/attention_kernels.cu | 12 +- src/runtime/contrib/vllm/cache_kernels.cu | 12 +- src/runtime/contrib/vllm/dtype_float16.h | 176 +++++++----------- 3 files changed, 84 insertions(+), 116 deletions(-) diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index 4ed1a64096ae..debb68543619 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -15,8 +15,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dtype_float16.h" - #include #include #include @@ -25,6 +23,8 @@ #include #include +#include "dtype_float16.h" + #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -116,7 +116,8 @@ __global__ void single_query_cached_kv_attention_kernel( const int kv_block_stride, const int kv_head_stride) { constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; @@ -158,7 +159,7 @@ __global__ void single_query_cached_kv_attention_kernel( const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // Memory planning. extern __shared__ char shared_mem[]; @@ -206,7 +207,8 @@ __global__ void single_query_cached_kv_attention_kernel( // Compute dot product. 
// This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], + k_vecs); // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index ef99a3647b82..29ab9bfa2e48 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -29,11 +29,11 @@ namespace vllm { template __global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int* __restrict__ slot_mapping, // [num_tokens] + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, @@ -69,7 +69,7 @@ __global__ void reshape_and_cache_kernel( } } -} // namespace vllm +} // namespace vllm namespace tvm { namespace runtime { diff --git a/src/runtime/contrib/vllm/dtype_float16.h b/src/runtime/contrib/vllm/dtype_float16.h index c978c68231e8..e16c10468b76 100644 --- a/src/runtime/contrib/vllm/dtype_float16.h +++ b/src/runtime/contrib/vllm/dtype_float16.h @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -23,31 +25,31 @@ namespace vllm { // A vector type to store Q, K, V elements. -template +template struct Vec {}; // A vector type to store FP32 accumulators. -template +template struct FloatVec {}; // Template vector operations. -template +template inline __device__ Acc mul(A a, B b); -template +template inline __device__ float sum(T v); -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { @@ -76,37 +78,35 @@ struct Float8_ { }; // FP32 vector types for Q, K, V. 
-template<> +template <> struct Vec { using Type = float; }; -template<> +template <> struct Vec { using Type = float2; }; -template<> +template <> struct Vec { using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = float4; }; // Vector addition. -inline __device__ float add(float a, float b) { - return a + b; -} +inline __device__ float add(float a, float b) { return a + b; } inline __device__ float2 add(float2 a, float2 b) { float2 c; @@ -125,12 +125,12 @@ inline __device__ float4 add(float4 a, float4 b) { } // Vector multiplication. -template<> +template <> inline __device__ float mul(float a, float b) { return a * b; } -template<> +template <> inline __device__ float2 mul(float2 a, float2 b) { float2 c; c.x = a.x * b.x; @@ -138,7 +138,7 @@ inline __device__ float2 mul(float2 a, float2 b) { return c; } -template<> +template <> inline __device__ float2 mul(float a, float2 b) { float2 c; c.x = a * b.x; @@ -146,7 +146,7 @@ inline __device__ float2 mul(float a, float2 b) { return c; } -template<> +template <> inline __device__ float4 mul(float4 a, float4 b) { float4 c; c.x = a.x * b.x; @@ -156,7 +156,7 @@ inline __device__ float4 mul(float4 a, float4 b) { return c; } -template<> +template <> inline __device__ float4 mul(float a, float4 b) { float4 c; c.x = a * b.x; @@ -167,9 +167,7 @@ inline __device__ float4 mul(float a, float4 b) { } // Vector fused multiply-add. -inline __device__ float fma(float a, float b, float c) { - return a * b + c; -} +inline __device__ float fma(float a, float b, float c) { return a * b + c; } inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; @@ -220,35 +218,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { } // Vector sum. -template<> +template <> inline __device__ float sum(float v) { return v; } -template<> +template <> inline __device__ float sum(float2 v) { return v.x + v.y; } -template<> +template <> inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -template<> +template <> inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } -template<> +template <> inline __device__ float sum(Float8_ v) { return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; } // Vector dot product. -inline __device__ float dot(float a, float b) { - return a * b; -} +inline __device__ float dot(float a, float b) { return a * b; } inline __device__ float dot(float2 a, float2 b) { float2 c = mul(a, b); @@ -270,71 +266,55 @@ inline __device__ float dot(Float8_ a, Float8_ b) { } // From float to float. -inline __device__ void from_float(float& dst, float src) { - dst = src; -} +inline __device__ void from_float(float& dst, float src) { dst = src; } -inline __device__ void from_float(float2& dst, float2 src) { - dst = src; -} +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } -inline __device__ void from_float(float4& dst, float4 src) { - dst = src; -} +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } // From float to float. 
-inline __device__ float to_float(float u) { - return u; -} +inline __device__ float to_float(float u) { return u; } -inline __device__ float2 to_float(float2 u) { - return u; -} +inline __device__ float2 to_float(float2 u) { return u; } -inline __device__ float4 to_float(float4 u) { - return u; -} +inline __device__ float4 to_float(float4 u) { return u; } -inline __device__ Float4_ to_float(Float4_ u) { - return u; -} +inline __device__ Float4_ to_float(Float4_ u) { return u; } -inline __device__ Float8_ to_float(Float8_ u) { - return u; -} +inline __device__ Float8_ to_float(Float8_ u) { return u; } // FP16 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = uint16_t; }; -template<> +template <> struct Vec { using Type = uint32_t; }; -template<> +template <> struct Vec { using Type = uint2; }; -template<> +template <> struct Vec { using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -433,26 +413,26 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; } -template<> +template <> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; } -template<> +template <> inline __device__ uint32_t mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ uint2 mul(uint2 a, uint2 b) { uint2 c; c.x = mul(a.x, b.x); @@ -460,7 +440,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { return c; } -template<> +template <> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); uint2 c; @@ -469,7 +449,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint4 a, uint4 b) { uint4 c; c.x = mul(a.x, b.x); @@ -479,7 +459,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); uint4 c; @@ -490,26 +470,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { return c; } -template<> +template <> inline __device__ float mul(uint16_t a, uint16_t b) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(uint32_t a, uint32_t b) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ Float4_ mul(uint2 a, uint2 b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -517,7 +497,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); Float4_ fc; @@ -526,7 +506,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint4 a, uint4 b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -536,7 +516,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { return fc; } -template<> 
+template <> inline __device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); Float8_ fc; @@ -554,9 +534,7 @@ inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { return d; } -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { - return fma(h0_h0(a), b, c); -} +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { return fma(h0_h0(a), b, c); } inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { uint2 d; @@ -604,9 +582,7 @@ inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { return fma(fa, fb, fc); } -inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { - return fma(h0_h0(a), b, fc); -} +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { return fma(h0_h0(a), b, fc); } inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { Float4_ fd; @@ -643,24 +619,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(uint16_t v) { return half_to_float(v); } -template<> +template <> inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } -template<> +template <> inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } -template<> +template <> inline __device__ float sum(uint4 v) { uint32_t c = add(v.x, v.y); c = add(c, v.z); @@ -669,13 +645,9 @@ inline __device__ float sum(uint4 v) { } // From float32 to float16. -inline __device__ void from_float(uint16_t& dst, float src) { - dst = float_to_half(src); -} +inline __device__ void from_float(uint16_t& dst, float src) { dst = float_to_half(src); } -inline __device__ void from_float(uint32_t& dst, float2 src) { - dst = float2_to_half2(src); -} +inline __device__ void from_float(uint32_t& dst, float2 src) { dst = float2_to_half2(src); } inline __device__ void from_float(uint2& dst, Float4_ src) { dst.x = float2_to_half2(src.x); @@ -690,13 +662,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { } // From float16 to float32. -inline __device__ float to_float(uint16_t u) { - return half_to_float(u); -} +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } -inline __device__ float2 to_float(uint32_t u) { - return half2_to_float2(u); -} +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } inline __device__ Float4_ to_float(uint2 u) { Float4_ tmp; @@ -715,8 +683,6 @@ inline __device__ Float8_ to_float(uint4 u) { } // Zero-out a variable. -inline __device__ void zero(uint16_t& dst) { - dst = uint16_t(0); -} +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } } // namespace vllm From 724d5e92a3a73d5dbb71cfa2b83df7ddf2f3e57b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 31 Oct 2023 06:10:20 +0000 Subject: [PATCH 09/10] Avoid hard-coded CMAKE_CUDA_ARCHITECTURES --- cmake/config.cmake | 4 ++++ cmake/modules/CUDA.cmake | 7 ++++++- cmake/modules/contrib/vllm.cmake | 7 ++++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index bf0a49b1aa18..0ef8952ea4fd 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -442,3 +442,7 @@ set(USE_UMA OFF) # Set custom Alloc Alignment for device allocated memory ndarray points to set(USE_KALLOC_ALIGNMENT 64) + +# List of architectures to generate CUDA device code for, only used for +# compiling external kernels from Thrust and vLLM. 
+set(CMAKE_CUDA_ARCHITECTURES "80;75") diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index d7c7f1bf530c..f2370648e1a6 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -64,7 +64,12 @@ if(USE_CUDA) message(STATUS "Build with Thrust support") cmake_minimum_required(VERSION 3.13) # to compile CUDA code enable_language(CUDA) - set(CMAKE_CUDA_ARCHITECTURES "80;75") + + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + message(WARNING "CMAKE_CUDA_ARCHITECTURES not set, compiling Thrust for sm80 and sm75.") + set(CMAKE_CUDA_ARCHITECTURES "80;75") + endif() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu) list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) diff --git a/cmake/modules/contrib/vllm.cmake b/cmake/modules/contrib/vllm.cmake index 551ea6c8db18..ae9474cb8b39 100644 --- a/cmake/modules/contrib/vllm.cmake +++ b/cmake/modules/contrib/vllm.cmake @@ -19,7 +19,12 @@ if(USE_VLLM) message(STATUS "Build with vllm paged attention kernel.") include_directories(src/runtime/contrib/vllm) enable_language(CUDA) - set(CMAKE_CUDA_ARCHITECTURES "80;75") + + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + message(WARNING "CMAKE_CUDA_ARCHITECTURES not set, compiling vLLM kernels for sm80 and sm75.") + set(CMAKE_CUDA_ARCHITECTURES "80;75") + endif() + tvm_file_glob(GLOB VLLM_CONTRIB_SRC src/runtime/contrib/vllm/*.cu src/runtime/contrib/vllm/*.cc) list(APPEND RUNTIME_SRCS ${VLLM_CONTRIB_SRC}) endif(USE_VLLM) From 971483d0d369d80ac2d82a291edaf70af7832e30 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 31 Oct 2023 11:12:25 +0000 Subject: [PATCH 10/10] more lint fix --- src/runtime/contrib/vllm/attention_kernels.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index debb68543619..8176aea7c186 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -15,12 +15,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include + #include #include #include #include -#include #include #include "dtype_float16.h" @@ -208,7 +210,7 @@ __global__ void single_query_cached_kv_attention_kernel( // Compute dot product. // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], - k_vecs); + k_vecs); // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
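
The final hunk above touches the query-key logit computation: qk is the scaled dot product of the query vector with one cached key, plus an ALiBi term proportional to the offset token_idx - context_len + 1 (zero for the newest token). Below is a minimal scalar reference of that formula; the function name and float-vector signature are illustrative only, since the real kernel computes the dot product through the vectorized Qk_dot path with a cross-thread reduction.

#include <cstddef>
#include <vector>

// Reference (non-vectorized) version of the logit computed in the hunk above:
//   qk = scale * dot(q, k) + alibi_slope * (token_idx - context_len + 1)
// Illustrative sketch, not the kernel's actual API.
inline float attention_logit(const std::vector<float>& q, const std::vector<float>& k,
                             float scale, float alibi_slope, int token_idx, int context_len) {
  float qk = 0.0f;
  for (std::size_t i = 0; i < q.size(); ++i) {
    qk += q[i] * k[i];
  }
  qk *= scale;
  // ALiBi bias: token_idx - context_len + 1 is <= 0 for cached tokens and
  // exactly 0 for the newest token, so the newest token gets no bias.
  if (alibi_slope != 0.0f) {
    qk += alibi_slope * static_cast<float>(token_idx - context_len + 1);
  }
  return qk;
}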
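
The float2/float4 dot overloads in the dtype_float32.cuh hunks above all follow the same shape: an elementwise mul followed by a horizontal sum. A self-contained device-side sketch of that composition for float4 follows; mul4, sum4, and dot4_kernel are illustrative names, not part of the patch, and the kernel would be launched as dot4_kernel<<<blocks, threads>>>(a, b, out, n) from host code.

#include <cuda_runtime.h>

// Elementwise product of two float4 values, mirroring the mul overloads above.
__device__ __forceinline__ float4 mul4(float4 a, float4 b) {
  return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

// Horizontal reduction, mirroring the sum overloads above.
__device__ __forceinline__ float sum4(float4 v) { return v.x + v.y + v.z + v.w; }

// dot(a, b) == sum(mul(a, b)): the pattern the FP32 vector overloads above
// provide at each vector width used by the attention kernel.
__global__ void dot4_kernel(const float4* a, const float4* b, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = sum4(mul4(a[i], b[i]));
  }
}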
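
The dtype_float16.cuh hunks above keep half-precision data packed in integer words (uint16_t, uint32_t, uint2, uint4), multiply them with f16x2 PTX instructions, and widen the products to the FP32 accumulator types before reducing. Below is a self-contained device-side sketch of both steps, assuming an sm_53 or newer target; mul_f16x2 and dot_h2 are illustrative names, not the patch's API.

#include <cuda_fp16.h>
#include <cstdint>
#include <cstring>

// One packed multiply of two half2 values held in 32-bit registers, the same
// idiom as the inline-PTX mul/fma overloads above (requires sm_53+).
__device__ __forceinline__ uint32_t mul_f16x2(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

// Per-element half2 dot product accumulated in FP32, the role played by the
// FP32 accumulator vector types above.
__global__ void dot_h2(const __half2* x, const __half2* y, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    uint32_t a, b, c;
    memcpy(&a, &x[i], sizeof(a));  // reinterpret packed halves as raw words
    memcpy(&b, &y[i], sizeof(b));
    c = mul_f16x2(a, b);
    __half2 ch;
    memcpy(&ch, &c, sizeof(ch));
    float2 cf = __half22float2(ch);  // widen to FP32 before reducing
    out[i] = cf.x + cf.y;
  }
}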