From ae336d5e261b9312994bf263379144ca5ec193b4 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 01/54] init library --- .gitignore | 51 ++ CMakeLists.txt | 47 + include/llm_fc.hpp | 64 ++ include/llm_mm.hpp | 35 + include/llm_types.hpp | 36 + src/CMakeLists.txt | 13 + src/bf16.hpp | 249 ++++++ src/fc_interface.cpp | 316 +++++++ src/mm_interface.cpp | 122 +++ src/mm_kernel_amx.hpp | 1631 ++++++++++++++++++++++++++++++++++ src/tensor2d.hpp | 178 ++++ src/tensor2d_helper.hpp | 200 +++++ src/utility_amx.hpp | 107 +++ src/utility_avx512.hpp | 94 ++ tests/CMakeLists.txt | 22 + tests/src/test_common.cpp | 61 ++ tests/src/test_common.h | 218 +++++ tests/src/test_fc_kernel.cpp | 217 +++++ tests/src/test_mm_kernel.cpp | 137 +++ 19 files changed, 3798 insertions(+) create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 include/llm_fc.hpp create mode 100644 include/llm_mm.hpp create mode 100644 include/llm_types.hpp create mode 100644 src/CMakeLists.txt create mode 100644 src/bf16.hpp create mode 100644 src/fc_interface.cpp create mode 100644 src/mm_interface.cpp create mode 100644 src/mm_kernel_amx.hpp create mode 100644 src/tensor2d.hpp create mode 100644 src/tensor2d_helper.hpp create mode 100644 src/utility_amx.hpp create mode 100644 src/utility_avx512.hpp create mode 100644 tests/CMakeLists.txt create mode 100644 tests/src/test_common.cpp create mode 100644 tests/src/test_common.h create mode 100644 tests/src/test_fc_kernel.cpp create mode 100644 tests/src/test_mm_kernel.cpp diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d1f16b8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,51 @@ +# Compiled Object files +**/.DS_Store +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +**/cmake-build-debug +**/CMakeCache.txt +**/cmake_install.cmake +**/install_manifest.txt +**/CMakeFiles/ +**/CTestTestfile.cmake +**/Makefile +**/*.cbp +**/CMakeScripts +**/compile_commands.json + + +## Local + +build/**/* +out/* +lib/* +bin/* +test/test_runner +.vs \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..4cd1f2e --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.18) + +project(root) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +option(BUILD_TESTS "Build with tests" ON) + +message(INFO "--------------------------------") +message(STATUS "Build with tests: ${BUILD_TESTS}") +message(INFO "--------------------------------") + +set(CMAKE_CXX_STANDARD 17) +if(MSVC) + # TODO + message(FATAL_ERROR "Not support yet. Use intel compiler 2023.0+.") + # Force to always compile with W4 + if(CMAKE_CXX_FLAGS MATCHES "/W[0-4]") + string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") + endif() +elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) + # TODO + message(FATAL_ERROR "Not support yet. Use intel compiler 2023.0+.") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "12.0") + message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 12.0.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "2023.0") + message(FATAL_ERROR "Insufficient intel compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 2023.0.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") +endif() + +include_directories(${PROJECT_SOURCE_DIR}/include) +include_directories(${PROJECT_SOURCE_DIR}/src/) + +if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY) + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +endif() +add_subdirectory(src) +if (BUILD_TESTS) + add_subdirectory(tests) +endif() diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp new file mode 100644 index 0000000..d503d8a --- /dev/null +++ b/include/llm_fc.hpp @@ -0,0 +1,64 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "llm_types.hpp" + +namespace llmdnn { + +typedef enum { + NONE = 0, + DEQUANT = 1 << 0, + BIAS = 1 << 1, + GELU = 1 << 2, + QUANT = 1 << 3, + + BIAS_GELU = BIAS | GELU, + DEQUANT_BIAS_GELU = DEQUANT | BIAS_GELU, + DEQUANT_BIAS_GELU_QUANT = DEQUANT_BIAS_GELU | QUANT, + DEQUANT_BIAS_QUANT = DEQUANT | BIAS | QUANT, + DEQUANT_GELU_QUANT = DEQUANT | GELU | QUANT, + DEQUANT_QUANT = DEQUANT | QUANT, + + DEQUANT_GELU = DEQUANT | GELU, + DEQUANT_BIAS = DEQUANT | BIAS +} postops_types; + +struct fc_create_param { + data_type_t dt_a; + data_type_t dt_b; + data_type_t dt_c; + bool b_is_trans; + postops_types postops_type; +}; + +struct fc_kernel; + +/// Generates a mm kernel based on param +/// +/// @param mm Output kernel +/// @param param kernel parameters, supported: +/// fc: (s8,s8,s8),dq,[bias],[gelu],q +/// fc: (s8,s8,bf16),dq,[bias],[gelu] +/// fc: (s8,s8,f32),dq,[bias],[gelu] +/// fc: (bf16,bf16,bf16),[bias],[gelu] +/// fc: (bf16,bf16,f32),[bias],[gelu] +/// fc: (bf16,s8,f32),dq,[bias],[gelu] +/// fc: (bf16,s8,bf16),dq,[bias],[gelu] +/// +bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param); +void fc_kernel_destroy(const fc_kernel* mm); +void fc_kernel_execute(const fc_kernel* mm, + void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, + float* dq=nullptr, float* q=nullptr, float* bias=nullptr); + +/// weight compression +/// compute weight min/max once, set q, dq for each fc_kernel instance +void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); +/// set q, dq for each fc_kernel instance, must call before first fc_kernel_execute +void fc_kernel_bf16w8_set_q_dq(const fc_kernel* mm, float q, float dq); + +} diff --git a/include/llm_mm.hpp b/include/llm_mm.hpp new file mode 100644 index 0000000..bb187f7 --- /dev/null +++ b/include/llm_mm.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "llm_types.hpp" + +namespace llmdnn { + +struct mm_create_param { + data_type_t dt_a; + data_type_t dt_b; + bool b_is_gemv; // true if matrix b is vector. Shape: a[M,K], b[K,1], c[M,1] + bool b_is_trans; +}; + +struct mm_kernel; + +/// Generates a mm kernel based on param +/// +/// @param mm Output kernel +/// @param param kernel parameters, supported: +/// matmul: (u8/s8,s8,f32) +/// gemv: (s8,s8,f32) +/// matmul: (bf16,bf16,f32) +/// gemv: (bf16,bf16,f32) +/// +bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param); +void mm_kernel_destroy(const mm_kernel* mm); + +void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K); + +} diff --git a/include/llm_types.hpp b/include/llm_types.hpp new file mode 100644 index 0000000..3ace6d6 --- /dev/null +++ b/include/llm_types.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace llmdnn { + +// from oneDNN +/// Data type specification +typedef enum { + /// Undefined data type, used for empty memory descriptors. + dnnl_data_type_undef = 0, + /// 16-bit/half-precision floating point. + dnnl_f16 = 1, + /// non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point. + dnnl_bf16 = 2, + /// 32-bit/single-precision floating point. + dnnl_f32 = 3, + /// 32-bit signed integer. + dnnl_s32 = 4, + /// 8-bit signed integer. + dnnl_s8 = 5, + /// 8-bit unsigned integer. + dnnl_u8 = 6, + /// 64-bit/double-precision floating point. + dnnl_f64 = 7, + + /// Parameter to allow internal only data_types without undefined behavior. + /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. + dnnl_data_type_max = 0x7fff, +} data_type_t; + +} \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..5d118ab --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,13 @@ +cmake_minimum_required(VERSION 3.18) +project(llmdnn) + +file(GLOB_RECURSE SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) + +add_library(llmdnn STATIC ${SOURCE_FILES}) +add_compile_definitions(DNNL_CPU_THREADING_RUNTIME=DNNL_RUNTIME_TBB) +target_compile_definitions(llmdnn PRIVATE LLMDNN_EXPORT) +target_include_directories(llmdnn INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) +set_target_properties(llmdnn PROPERTIES + POSITION_INDEPENDENT_CODE ON + ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) +install(TARGETS llmdnn DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) diff --git a/src/bf16.hpp b/src/bf16.hpp new file mode 100644 index 0000000..35f42cc --- /dev/null +++ b/src/bf16.hpp @@ -0,0 +1,249 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#define ROUND_MODE_TO_NEAREST_EVEN + +#define OPENVINO_API + +namespace ov { +class OPENVINO_API bfloat16 { +public: + constexpr bfloat16() : m_value{0} {} + bfloat16(float value) : m_value { +#if defined ROUND_MODE_TO_NEAREST + round_to_nearest(value) +#elif defined ROUND_MODE_TO_NEAREST_EVEN + round_to_nearest_even(value) +#elif defined ROUND_MODE_TRUNCATE + truncate(value) +#else +# error "ROUNDING_MODE must be one of ROUND_MODE_TO_NEAREST, ROUND_MODE_TO_NEAREST_EVEN, or ROUND_MODE_TRUNCATE" +#endif + } + {} + + template + explicit bfloat16(I value) : m_value{bfloat16{static_cast(value)}.m_value} {} + + std::string to_string() const; + size_t size() const; + template + bool operator==(const T& other) const; + template + bool operator!=(const T& other) const { + return !(*this == other); + } + template + bool operator<(const T& other) const; + template + bool operator<=(const T& other) const; + template + bool operator>(const T& other) const; + template + bool operator>=(const T& other) const; + template + bfloat16 operator+(const T& other) const; + template + bfloat16 operator+=(const T& other); + template + bfloat16 operator-(const T& other) const; + template + bfloat16 operator-=(const T& other); + template + bfloat16 operator*(const T& other) const; + template + bfloat16 operator*=(const T& other); + template + bfloat16 operator/(const T& other) const; + template + bfloat16 operator/=(const T& other); + operator float() const { + uint32_t tmp = 0; + uint32_t* ptmp = &tmp; + *ptmp = (static_cast(m_value) << 16); + const float* f = reinterpret_cast(ptmp); + return *f; + } + + static std::vector to_float_vector(const std::vector&); + static std::vector from_float_vector(const std::vector&); + static constexpr bfloat16 from_bits(uint16_t bits) { + return bfloat16(bits, true); + } + uint16_t to_bits() const; + friend std::ostream& operator<<(std::ostream& out, const bfloat16& obj) { + out << static_cast(obj); + return out; + } + +#define cu32(x) (F32(x).i) + + static uint16_t round_to_nearest_even(float x) { + return static_cast((cu32(x) + ((cu32(x) & 0x00010000) >> 1)) >> 16); + } + + static uint16_t round_to_nearest(float x) { + return static_cast((cu32(x) + 0x8000) >> 16); + } + + static uint16_t truncate(float x) { + return static_cast((cu32(x)) >> 16); + } + +private: + constexpr bfloat16(uint16_t x, bool) : m_value{x} {} + union F32 { + F32(float val) : f{val} {} + F32(uint32_t val) : i{val} {} + float f; + uint32_t i; + }; + + uint16_t m_value; +}; + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4756) +#endif +template +bool bfloat16::operator==(const T& other) const { +#if defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + return (static_cast(*this) == static_cast(other)); +#if defined(__GNUC__) +# pragma GCC diagnostic pop +#endif +} + +template +bool bfloat16::operator<(const T& other) const { + return (static_cast(*this) < static_cast(other)); +} + +template +bool bfloat16::operator<=(const T& other) const { + return (static_cast(*this) <= static_cast(other)); +} + +template +bool bfloat16::operator>(const T& other) const { + return (static_cast(*this) > static_cast(other)); +} + +template +bool bfloat16::operator>=(const T& other) const { + return (static_cast(*this) >= static_cast(other)); +} + +template +bfloat16 bfloat16::operator+(const T& other) const { + return {static_cast(*this) + static_cast(other)}; +} + +template +bfloat16 bfloat16::operator+=(const T& other) { + return *this = *this + other; +} + +template +bfloat16 bfloat16::operator-(const T& other) const { + return {static_cast(*this) - static_cast(other)}; +} + +template +bfloat16 bfloat16::operator-=(const T& other) { + return *this = *this - other; +} + +template +bfloat16 bfloat16::operator*(const T& other) const { + return {static_cast(*this) * static_cast(other)}; +} + +template +bfloat16 bfloat16::operator*=(const T& other) { + return *this = *this * other; +} + +template +bfloat16 bfloat16::operator/(const T& other) const { + return {static_cast(*this) / static_cast(other)}; +} + +template +bfloat16 bfloat16::operator/=(const T& other) { + return *this = *this / other; +} +#if defined(_MSC_VER) +# pragma warning(pop) +#endif +} // namespace ov + +namespace std { +template <> +class numeric_limits { +public: + static constexpr bool is_specialized = true; + static constexpr ov::bfloat16 min() noexcept { + return ov::bfloat16::from_bits(0x007F); + } + static constexpr ov::bfloat16 max() noexcept { + return ov::bfloat16::from_bits(0x7F7F); + } + static constexpr ov::bfloat16 lowest() noexcept { + return ov::bfloat16::from_bits(0xFF7F); + } + static constexpr int digits = 7; + static constexpr int digits10 = 2; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr int radix = 2; + static constexpr ov::bfloat16 epsilon() noexcept { + return ov::bfloat16::from_bits(0x3C00); + } + static constexpr ov::bfloat16 round_error() noexcept { + return ov::bfloat16::from_bits(0x3F00); + } + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = denorm_absent; + static constexpr bool has_denorm_loss = false; + static constexpr ov::bfloat16 infinity() noexcept { + return ov::bfloat16::from_bits(0x7F80); + } + static constexpr ov::bfloat16 quiet_NaN() noexcept { + return ov::bfloat16::from_bits(0x7FC0); + } + static constexpr ov::bfloat16 signaling_NaN() noexcept { + return ov::bfloat16::from_bits(0x7FC0); + } + static constexpr ov::bfloat16 denorm_min() noexcept { + return ov::bfloat16::from_bits(0); + } + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = false; + static constexpr bool is_modulo = false; + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; +} // namespace std diff --git a/src/fc_interface.cpp b/src/fc_interface.cpp new file mode 100644 index 0000000..6a34a4e --- /dev/null +++ b/src/fc_interface.cpp @@ -0,0 +1,316 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_fc.hpp" +#include "mm_kernel_amx.hpp" +#include "utility_avx512.hpp" + +namespace llmdnn { + +using ov::bfloat16; +struct fc_kernel { + std::shared_ptr> bf16xbf16; + std::shared_ptr> bf16xi8; + std::shared_ptr> i8xi8; + std::shared_ptr> u8xi8; + + data_type_t dt_a; + data_type_t dt_b; + data_type_t dt_c; + postops_types postops_type; + bool b_is_transpose; +}; + +using supported_key = std::tuple; +using supported_value = std::pair; +static std::map supported_postops = { + { { dnnl_s8, dnnl_s8, dnnl_s8 }, { DEQUANT | QUANT, BIAS | GELU } }, + { { dnnl_s8, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU } }, + { { dnnl_s8, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU } }, + { { dnnl_bf16, dnnl_bf16, dnnl_bf16 }, { 0, BIAS | GELU } }, + { { dnnl_bf16, dnnl_bf16, dnnl_f32 }, { 0, BIAS | GELU } }, + { { dnnl_bf16, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU } }, + { { dnnl_bf16, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU } }, +}; + +static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b, data_type_t dt_c) { + auto it = supported_postops.find(std::make_tuple(dt_a, dt_b, dt_c)); + if (it == supported_postops.end()) { + return false; + } + + size_t must_have; + size_t opt_have; + must_have = (*it).second.first; + opt_have = (*it).second.second; + + if ((value & must_have) != must_have) + return false; + // value must in must_have and opt_have + if ((value & ~(must_have | opt_have)) != 0) + return false; + + return true; +} + +// interface +bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { + fc_kernel* m = nullptr; + if (param == nullptr || mm == nullptr) { + std::cout << "fc_kernel_create: invalid input parameter.\n"; + goto ERR; + } + + if (!check_valid_postops(static_cast(param->postops_type), param->dt_a, param->dt_b, param->dt_c)) { + std::cout << "fc_kernel_create: unsupported data type, a: " << param->dt_a <<", b: " << param->dt_b << ", c: " << param->dt_c << + ", postops type: " << param->postops_type << ".\n"; + goto ERR; + } + + m = new fc_kernel; + if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + m->i8xi8 = std::make_shared>(true, param->b_is_trans); + } else if (param->dt_a == dnnl_u8 && param->dt_b == dnnl_s8) { + m->u8xi8 = std::make_shared>(true, param->b_is_trans); + } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + m->bf16xbf16 = std::make_shared>(true, param->b_is_trans); + } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_s8) { + m->bf16xi8 = std::make_shared>(true, param->b_is_trans); + } else { + std::cout << "fc_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + + m->dt_a = param->dt_a; + m->dt_b = param->dt_b; + m->dt_c = param->dt_c; + m->b_is_transpose = param->b_is_trans; + m->postops_type = param->postops_type; + + *mm = m; + return true; +ERR: + delete m; + return false; +} + +void fc_kernel_destroy(const fc_kernel* mm) { + if (mm) { + delete mm; + } +} + +void fc_kernel_execute(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + + if (mm->dt_c == dnnl_s8) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == dnnl_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!bias) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == dnnl_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!bias) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } + } else if (mm->u8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->u8xi8)(a, b, n_start, n_end, pp); + } else if (mm->bf16xbf16) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + + if (mm->dt_c == dnnl_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == dnnl_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } + } + } else { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(N, K, reinterpret_cast(ptr_b), ldb); + + if (mm->dt_c == dnnl_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == dnnl_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } + } + } +} + +void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { + float min, max; + tensor2D B(K, N, reinterpret_cast(ptr), stride); + amx_kernel::functional::get_min_max(B, min, max); + max = std::max(std::abs(max), std::abs(min)); + *q = 127 / max; + *dq = max / 127; +} + +/// set q, dq for each fc_kernel instance +void fc_kernel_bf16w8_set_q_dq(const fc_kernel* mm, float q, float dq) { + if (!mm || !mm->bf16xi8) { + std::cout << "fc_kernel_bf16w8_set_q_dq: created kernel is not int8 weight.\n"; + return; + } + mm->bf16xi8->quant_scale_B = q; + mm->bf16xi8->dequant_scale_B = dq; +} + + +} \ No newline at end of file diff --git a/src/mm_interface.cpp b/src/mm_interface.cpp new file mode 100644 index 0000000..76c48aa --- /dev/null +++ b/src/mm_interface.cpp @@ -0,0 +1,122 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_mm.hpp" +#include "mm_kernel_amx.hpp" +#include "utility_avx512.hpp" + +namespace llmdnn { + +using ov::bfloat16; +struct mm_kernel { + std::shared_ptr> bf16xbf16; + std::shared_ptr> i8xi8; + std::shared_ptr> u8xi8; + + std::shared_ptr> i8xi8_gemv; + std::shared_ptr> bf16xbf16_gemv; + + data_type_t dt_a; + data_type_t dt_b; + bool b_is_transpose; +}; + +// interface +bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param) { + mm_kernel* m = nullptr; + if (param == nullptr || mm == nullptr) { + std::cout << "mm_kernel_create: invalid input parameter.\n"; + goto ERR; + } + + m = new mm_kernel; + if (param->b_is_gemv) { + if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + m->i8xi8_gemv = std::make_shared>(); + } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + m->bf16xbf16_gemv = std::make_shared>(); + } else { + std::cout << "mm_kernel_create: unsupport gemv input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + } else { + if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + m->i8xi8 = std::make_shared>(false, param->b_is_trans); + } else if (param->dt_a == dnnl_u8 && param->dt_b == dnnl_s8) { + m->u8xi8 = std::make_shared>(false, param->b_is_trans); + } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + m->bf16xbf16 = std::make_shared>(false, param->b_is_trans); + } else { + std::cout << "mm_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + } + m->dt_a = param->dt_a; + m->dt_b = param->dt_b; + m->b_is_transpose = param->b_is_trans; + + *mm = m; + return true; +ERR: + delete m; + return false; +} + +void mm_kernel_destroy(const mm_kernel* mm) { + if (mm) { + delete mm; + } +} + +void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K) { + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8_gemv) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + (*mm->i8xi8_gemv)(a, reinterpret_cast(ptr_b), reinterpret_cast(ptr_c)); + cvt_i32_f32(reinterpret_cast(ptr_c), reinterpret_cast(ptr_c), M); + } else if (mm->i8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->i8xi8)(a, b, 0, N, pp); + } else if (mm->u8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->u8xi8)(a, b, 0, N, pp); + } else if (mm->bf16xbf16_gemv) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + (*mm->bf16xbf16_gemv)(a, reinterpret_cast(ptr_b), reinterpret_cast(ptr_c)); + } else if (mm->bf16xbf16) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->bf16xbf16)(a, b, 0, N, pp); + } else { + std::cout << "mm_kernel_execute: no valid kernel created, call create first.\n"; + } +} + + +} \ No newline at end of file diff --git a/src/mm_kernel_amx.hpp b/src/mm_kernel_amx.hpp new file mode 100644 index 0000000..8f41915 --- /dev/null +++ b/src/mm_kernel_amx.hpp @@ -0,0 +1,1631 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "utility_amx.hpp" +#include "tensor2d.hpp" + +#ifdef _WIN32 +#include +#else +#include +#endif + +#include "bf16.hpp" +#ifdef ENABLE_NUMA +#include "numa.h" +#endif + +namespace amx_kernel { + +namespace functional { + + inline void transpose_m512i_16x16(__m512i &r0, __m512i &r1, __m512i &r2, __m512i &r3, + __m512i &r4, __m512i &r5, __m512i &r6, __m512i &r7, + __m512i &r8, __m512i &r9, __m512i &ra, __m512i &rb, + __m512i &rc, __m512i &rd, __m512i &re, __m512i &rf) { + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + + t0 = _mm512_unpacklo_epi32(r0,r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 + t1 = _mm512_unpackhi_epi32(r0,r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 + t2 = _mm512_unpacklo_epi32(r2,r3); // 32 48 33 49 ... + t3 = _mm512_unpackhi_epi32(r2,r3); // 34 50 35 51 ... + t4 = _mm512_unpacklo_epi32(r4,r5); // 64 80 65 81 ... + t5 = _mm512_unpackhi_epi32(r4,r5); // 66 82 67 83 ... + t6 = _mm512_unpacklo_epi32(r6,r7); // 96 112 97 113 ... + t7 = _mm512_unpackhi_epi32(r6,r7); // 98 114 99 115 ... + t8 = _mm512_unpacklo_epi32(r8,r9); // 128 ... + t9 = _mm512_unpackhi_epi32(r8,r9); // 130 ... + ta = _mm512_unpacklo_epi32(ra,rb); // 160 ... + tb = _mm512_unpackhi_epi32(ra,rb); // 162 ... + tc = _mm512_unpacklo_epi32(rc,rd); // 196 ... + td = _mm512_unpackhi_epi32(rc,rd); // 198 ... + te = _mm512_unpacklo_epi32(re,rf); // 228 ... + tf = _mm512_unpackhi_epi32(re,rf); // 230 ... + + r0 = _mm512_unpacklo_epi64(t0,t2); // 0 16 32 48 ... + r1 = _mm512_unpackhi_epi64(t0,t2); // 1 17 33 49 ... + r2 = _mm512_unpacklo_epi64(t1,t3); // 2 18 34 49 ... + r3 = _mm512_unpackhi_epi64(t1,t3); // 3 19 35 51 ... + r4 = _mm512_unpacklo_epi64(t4,t6); // 64 80 96 112 ... + r5 = _mm512_unpackhi_epi64(t4,t6); // 65 81 97 114 ... + r6 = _mm512_unpacklo_epi64(t5,t7); // 66 82 98 113 ... + r7 = _mm512_unpackhi_epi64(t5,t7); // 67 83 99 115 ... + r8 = _mm512_unpacklo_epi64(t8,ta); // 128 144 160 176 ... + r9 = _mm512_unpackhi_epi64(t8,ta); // 129 145 161 178 ... + ra = _mm512_unpacklo_epi64(t9,tb); // 130 146 162 177 ... + rb = _mm512_unpackhi_epi64(t9,tb); // 131 147 163 179 ... + rc = _mm512_unpacklo_epi64(tc,te); // 192 208 228 240 ... + rd = _mm512_unpackhi_epi64(tc,te); // 193 209 229 241 ... + re = _mm512_unpacklo_epi64(td,tf); // 194 210 230 242 ... + rf = _mm512_unpackhi_epi64(td,tf); // 195 211 231 243 ... + + t0 = _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... + t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... + t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... + t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... + t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... + t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... + t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... + t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... + t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... + t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... + ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... + tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... + tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... + td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... + te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... + tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... + + r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 + r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 + r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 + r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 + r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... + r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... + r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... + r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... + r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... + r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... + ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... + rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... + rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... + rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... + re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... + rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 + } + + inline void transpose_epi32_16x16(void * _dst, const void * src, int stride) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_loadu_epi32(pA + 15*stride); + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // 16xN, N<=16, non-valid part is filled with zeros + inline void transpose_epi32_16xN(void * _dst, const void * src, int stride, int valid_bytes) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull >> (64 - valid_bytes); + __mmask64 mask = _cvtu64_mask64(mask_value); + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // 16xN, N<=16, non-valid part is on the left, filled with zeros + inline void transpose_epi32_16xN_right_align(void * _dst, const void * src, int stride, int valid_bytes) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + int invalid_bytes = 64 - valid_bytes; + auto * pA = reinterpret_cast(src) - invalid_bytes; + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; + __mmask64 mask = _cvtu64_mask64(mask_value); + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // gelu_erf_minimax_approx_compute_vector_fwd in oneDNN + // x*0.5*(1+erf(x/sqrt(2))) = x*0.5*(1 + x*Polynomial(x^2)) + inline __m512 gelu_erf_minmax_approx(__m512 & x) { + auto x2 = _mm512_mul_ps(x, x); // x^2 + + auto x_positive = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(x), _mm512_set1_epi32(0x7FFFFFFF))); // clear sign mask + auto x_half = _mm512_mul_ps(x, _mm512_set1_ps(0.5f)); + + auto poly = _mm512_castsi512_ps(_mm512_set1_epi32(0x1f1c83fd)); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa3198977))); // poly * x^2 + xxx + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x268a7927))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa998c963))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x2c67ddb2))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xaf013b2c))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x315d4a4f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb3969b11))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x35a776e9))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb79b0914))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3970b255))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbb1b7399))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3ca3621f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbe082bc7))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3f4c4228))); + + // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2) + poly = _mm512_fmadd_ps(poly, x, _mm512_set1_ps(1.0f)); + // x*0.5*(1 + x*Polynomial(x^2)) + poly = _mm512_mul_ps(poly, x_half); + + // combine: + // zone_id + // 1 -inf; -saturation_lbound : 0.0f + // 2 -saturation_lbound; -linear_ubound : x*0.5*(1 + x*Polynomial(x^2)) + // 3 -linear_ubound, linear_ubound : x*0.5 + // 4 linear_ubound : saturation_lbound : x*0.5*(1 + x*Polynomial(x^2)) + // 5 saturation_lbound: +inf : x + constexpr int neg_saturation_lbound = 0xc0a00000; + constexpr int linear_ubound = 0x33800000; + constexpr int saturation_lbound = 0x40a00000; + + auto mask_x_not_zone1 = _mm512_cmpnlt_ps_mask(x, _mm512_castsi512_ps(_mm512_set1_epi32(neg_saturation_lbound))); + x = _mm512_maskz_mov_ps(mask_x_not_zone1, x); + + auto mask_x_in_zone5 = _mm512_cmpnle_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(saturation_lbound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone5, x); + + auto mask_x_in_zone3 = _mm512_cmple_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(linear_ubound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone3, x_half); + return poly; + } + + inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows) { + #define FROM_B(i) ((1<<4)|(i)) + static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), + 1,5,FROM_B(1),FROM_B(5), + 2,6,FROM_B(2),FROM_B(6), + 3,7,FROM_B(3),FROM_B(7)}; + auto midx = _mm512_loadu_epi64(idx); + __mmask16 mask = _cvtu32_mask16(0xFFFFu); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + if (src_rows == 64) { + for (int row = 0; row < 16; row++) { + // each element (a? or b?) is 32-bits, two lanes in each ymm register + auto a256 = _mm256_loadu_epi8(src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 + auto b256 = _mm256_loadu_epi8(src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 + auto c256 = _mm256_loadu_epi8(src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 + auto d256 = _mm256_loadu_epi8(src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } else { + // less than 64 source lines, + int allzero_dst_rows = (64-src_rows)/4; + int allnonzero_dst_rows = src_rows/4; + // padding zeros at the top + auto rowB0 = _mm512_setzero_si512(); + auto rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + int tails_nz = (src_rows & 3); + if (tails_nz) { + __mmask32 kmask1 = _cvtu32_mask32(0xFFFFFFFF); + auto a256 = _mm256_setzero_si256(); // must be zero + auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 + auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 + auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) + if (tails_nz > 2) { + b256 = _mm256_loadu_epi8 (src); src += stride; + } + if (tails_nz > 1) { + c256 = _mm256_loadu_epi8 (src); src += stride; + } + d256 = _mm256_loadu_epi8 (src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non zeros + for (int i = 0; i < allnonzero_dst_rows; i++) { + auto a256 = _mm256_loadu_epi8 (src); src += stride; + auto b256 = _mm256_loadu_epi8 (src); src += stride; + auto c256 = _mm256_loadu_epi8 (src); src += stride; + auto d256 = _mm256_loadu_epi8 (src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows) { + static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; + auto midx = _mm512_loadu_epi64(idx); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __m512i a,b,rowB0, rowB1; + if (src_rows == 32) { + for (int row = 0; row < 16; row++) { + a = _mm512_loadu_epi16(src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + b = _mm512_loadu_epi16(src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } else { + int allzero_dst_rows = (32-src_rows)/2; + int allnonzero_dst_rows = src_rows/2; + + rowB0 = _mm512_setzero_si512(); + rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + if (src_rows & 1) { + a = _mm512_setzero_si512(); + b = _mm512_loadu_epi16(src); src += stride; + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + auto rowB0 = _mm512_unpacklo_epi16(a, b); + auto rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non-zero rows + for (int i = 0; i < allnonzero_dst_rows; i++) { + a = _mm512_loadu_epi16(src); + b = _mm512_loadu_epi16(src + stride); + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + rowB0 = _mm512_unpacklo_epi16(a, b); + rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } + } + + // prepare B matrix for C matrix 2x2 blocking (B matrix + // will be accessed in 1x2) + // given 2x2 blocking scheme, Kx32 blocks are always + // accessed sequentially: + // transpose/repack each 32xK ov::bfloat16 submatrix + // into Kx32 slices (each number is a 16x32 bf16-block): + // 0 2 4 6 ... ... + // 1 3 5 7 ... ... + + inline void get_min_max(tensor2D & matB, float& min, float& max) { + int K = matB.dims[0]; + int N = matB.dims[1]; + auto m_max = _mm512_set1_ps(-__FLT_MAX__); + auto m_min = _mm512_set1_ps(__FLT_MAX__); + for (int k = 0; k < K; k++) { + int n = 0; + for (; n < N / 16 * 16; n += 16) { + auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(&matB(k, n))); + a = _mm512_slli_epi32(a, 16); + m_max = _mm512_max_ps((__m512)a, m_max); + m_min = _mm512_min_ps((__m512)a, m_min); + } + if (n != N) { + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - (N - n))); + auto a = _mm512_cvtepi16_epi32(_mm256_maskz_loadu_epi16(msk, &matB(k, n))); + a = _mm512_slli_epi32(a, 16); + m_max = _mm512_mask_max_ps(m_max, msk, (__m512)a, m_max); + m_min = _mm512_mask_min_ps(m_min, msk, (__m512)a, m_min); + } + } + max = _mm512_reduce_max_ps(m_max); + min = _mm512_reduce_min_ps(m_min); + } + + template + void i8_to_bf16_Kx32(int8_t *&src, ov::bfloat16 *dst) + { + for (int k = 0; k < K; k++) + { + auto a = _mm_load_si128((__m128i *)src); // 16 int8 + auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 + auto a_512 = _mm512_cvtepi8_epi32(a); // 16 int32 + auto b_512 = _mm512_cvtepi8_epi32(b); // 16 int32 + auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps + auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps + auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 + _mm512_store_epi32(dst, (__m512i)reg_out); // + src += 32; // 32 int8_t dequantized into 32 bf16 + dst += 32; + } + } + + inline void bf16_to_i8_tensor(tensor2D& dst, tensor2D& src, float quant_scale) { + dst.resize(src.dims[0], src.dims[1]); + auto scale = _mm512_set1_ps(quant_scale); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + for (int n = 0; n < src.dims[1]; n += 16, p_src += 16, p_dst += 16) { + auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(p_src)); // load packed 16 x bf16 + a = _mm512_slli_epi32(a, 16); // bf16 zero-extend to f32 + auto a_f = _mm512_mul_ps((__m512)a, scale); // scale + a = _mm512_cvtps_epi32(a_f); // ps to dw + auto a_128 = _mm512_cvtsepi32_epi8(a); // saturate convert into int8 + _mm_store_si128((__m128i*)(p_dst), a_128); + } + } + } +}; + +// 2x2 tiles post process kernels + +// 4 tiles located at C matrix (m,n) of size (valid_m, valid_n) +// tC00/tC01 +// tC10/tC11 + +namespace PP { + template + struct is_f32i32 : std::false_type {}; + template<> + struct is_f32i32 : std::true_type {}; + template<> + struct is_f32i32 : std::true_type {}; + + enum Steps { + NONE = 0, + DEQUANT = 1<<0, + BIAS = 1<<1, + GELU = 1<<2, + QUANT = 1<<3, + + BIAS_GELU = BIAS | GELU, + DEQUANT_BIAS_GELU = DEQUANT | BIAS_GELU, + DEQUANT_BIAS_GELU_QUANT = DEQUANT_BIAS_GELU | QUANT, + DEQUANT_BIAS_QUANT = DEQUANT | BIAS | QUANT, + DEQUANT_GELU_QUANT = DEQUANT | GELU | QUANT, + DEQUANT_QUANT = DEQUANT | QUANT, + + DEQUANT_GELU = DEQUANT | GELU, + DEQUANT_BIAS = DEQUANT | BIAS + }; + + template + struct BiasGeluStore { + static_assert(std::is_same::value || std::is_same::value || std::is_same::value, + "BiasGeluStore only support output data types ov::bfloat16/int8_t/float"); + + BiasGeluStore(tensor2D & C, float * bias = nullptr) : C(C), bias(bias) {} + + tensor2D & C; + float * bias; + void set_bias(float * _bias) { + assert (steps & BIAS); + bias = _bias; + } + + float deq_scale_common = 1.0f; + float * deq_scale_per_oc = nullptr; + void set_deq_scale(float scale = 1.0f) { + assert (steps & DEQUANT); + deq_scale_common = scale; + deq_scale_per_oc = nullptr; + } + void set_deq_scale(float * scale_per_oc) { + assert (steps & DEQUANT); + deq_scale_common = 0; + deq_scale_per_oc = scale_per_oc; + } + + float q_scale_common = 0.0f; + float * q_scale_per_oc = nullptr; + void set_q_scale(float scale) { + assert (steps & QUANT); + q_scale_common = scale; + q_scale_per_oc = nullptr; + } + void set_q_scale(float * scale_per_oc) { + assert (steps & QUANT); + q_scale_common = 0; + q_scale_per_oc = scale_per_oc; + } + + // source buffC can be i32 or f32 + template::value, bool>::type = true> + void operator()(tensor2D & buffC, int m, int n, int valid_m, int valid_n) { + auto * psrc = &buffC(0,0); + int8_t * pdst = reinterpret_cast(&(C(m, n))); + int stride = C.stride; + + __m512 bias0, bias1; + if (steps & BIAS) { + bias0 = _mm512_loadu_ps(bias + n); + bias1 = _mm512_loadu_ps(bias + n + 16); + } + + __m512 m512_q_scale0; + __m512 m512_q_scale1; + __m512 m512_deq_scale0; + __m512 m512_deq_scale1; + if (steps & DEQUANT) { + if (deq_scale_per_oc) { + m512_deq_scale0 = _mm512_loadu_ps(deq_scale_per_oc + n); + m512_deq_scale1 = _mm512_loadu_ps(deq_scale_per_oc + n + 16); + } else { + m512_deq_scale0 = _mm512_set1_ps(deq_scale_common); + m512_deq_scale1 = _mm512_set1_ps(deq_scale_common); + } + } + if (steps & QUANT) { + if (q_scale_per_oc) { + m512_q_scale0 = _mm512_loadu_ps(q_scale_per_oc + n); + m512_q_scale1 = _mm512_loadu_ps(q_scale_per_oc + n + 16); + } else { + m512_q_scale0 = _mm512_set1_ps(q_scale_common); + m512_q_scale1 = _mm512_set1_ps(q_scale_common); + } + } + + __mmask32 kall; + __mmask16 k0, k1; + if (std::is_same::value) { + if (valid_n >= 16) { + k0 = _cvtu32_mask16(0xFFFF); + k1 = _cvtu32_mask16(0xFFFF >> (32-valid_n)); + } else { + k0 = _cvtu32_mask16(0xFFFF >> (16-valid_n)); + k1 = _cvtu32_mask16(0); + } + } else { + kall = _cvtu32_mask32(0xFFFFFFFF >> (32-valid_n)); + } + + for(int i = 0; i < valid_m; i ++) { + auto r0 = _mm512_loadu_ps(psrc); + auto r1 = _mm512_loadu_ps(psrc + 16); + if (std::is_same::value) { + r0 = _mm512_cvtepi32_ps(_mm512_castps_si512(r0)); // cvt i32=>f32 + r1 = _mm512_cvtepi32_ps(_mm512_castps_si512(r1)); // cvt i32=>f32 + } + if (steps & DEQUANT) { + r0 = _mm512_mul_ps(r0, m512_deq_scale0); // dequantize + r1 = _mm512_mul_ps(r1, m512_deq_scale1); // dequantize + } + if (steps & BIAS) { + r0 = _mm512_add_ps(r0, bias0); + r1 = _mm512_add_ps(r1, bias1); + } + if (steps & GELU) { + r0 = functional::gelu_erf_minmax_approx(r0); + r1 = functional::gelu_erf_minmax_approx(r1); + } + + // quantize & store + if (steps & QUANT) { + r0 = _mm512_mul_ps(r0, m512_q_scale0); + r1 = _mm512_mul_ps(r1, m512_q_scale1); + } + if (std::is_same::value) { + auto c = _mm512_cvtne2ps_pbh(r1, r0); // convert to bf16 + _mm512_mask_storeu_epi16(pdst, kall, c); // store bf16 + } + if (std::is_same::value) { + auto d0 = _mm512_cvtps_epi32(r0); // convert to dword(i32) + auto d1 = _mm512_cvtps_epi32(r1); // convert to dword(i32) + auto b0 = _mm512_cvtsepi32_epi8 (d0); // dword => int8 with Saturate8 + auto b1 = _mm512_cvtsepi32_epi8 (d1); // dword => int8 with Saturate8 + auto b0b1 = _mm256_inserti32x4(_mm256_castsi128_si256(b0), b1, 1); // combine two int8 xmm into a ymm + _mm256_mask_storeu_epi8(pdst, kall, b0b1); // masked store + } + if (std::is_same::value) { + _mm512_mask_storeu_ps(pdst, k0, r0); // store float + _mm512_mask_storeu_ps(pdst + 64, k1, r1); // store float + } + pdst += stride; + psrc += 32; + } + } + }; +} + +template +void prefetch_bytes(void *src) +{ + int8_t *p = reinterpret_cast(src); + for (int i = 0; i < bytes; i+=64) + _mm_prefetch(p + i + advance, sel); +} +template +void zero_tiles() { int dummy[sizeof...(tmm)] = {(_tile_zero(tmm), 0)...}; } + +// matmul (FC) +// +// constB constrols whether it's FC or not +// store precision for weight compression, only for BF16 AMX + +template +tensor2D getSubMatB(tensor2D & _matB, int n0, int n1, bool transposeB) { + int Bd0 = transposeB ? (n1-n0) : _matB.dims[0]; + int Bd1 = transposeB ? _matB.dims[1] : (n1-n0); + T * pbase = transposeB ? (&_matB(n0, 0)):(&_matB(0, n0)); + return tensor2D(Bd0, Bd1, pbase, _matB.stride); +} + +template +void loop2D_no_bM(int M, int N, F f) { + for(int n=0; n +void loop2D(int M, int N, int mc, F f) { + for(int m0=0; m0= bM) +template +void loop2D_opt_Mtail(int M, int N, int mc, F f) { + int tailM = (M % (mc*bM)) % bM; + assert(M > bM); + for(int m0=0; m0 +void repackB_1x2(const tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { + int K = Bi.dims[transpose?1:0]; + int N = Bi.dims[transpose?0:1]; + + // K_padded : round up to multiple of 32/64 + int kStep = 64 / sizeof(T); + int K_padded = (K + kStep - 1)/kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2*16; + int N_padded = (N + N_unit - 1)/N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + Bo.resize(N_padded/N_unit, K_padded * N_unit, false, is_const); + + if (transpose) { + for(int n = 0; n < N; n += N_unit) { + // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially + auto * dst = reinterpret_cast(&Bo(n/N_unit, 0)); + auto * src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_16x16(dst, src0 + 0*16*Bi.stride + k*sizeof(T), Bi.stride); + dst += 1024; + functional::transpose_epi32_16x16(dst, src0 + 1*16*Bi.stride + k*sizeof(T), Bi.stride); + dst += 1024; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_16xN_right_align(dst, src0 + 0*16*Bi.stride + k*sizeof(T), Bi.stride, (K-k)*sizeof(T)); + dst += 1024; + functional::transpose_epi32_16xN_right_align(dst, src0 + 1*16*Bi.stride + k*sizeof(T), Bi.stride, (K-k)*sizeof(T)); + dst += 1024; + } + } + } else { + // pack & layout sequentially + for(int n = 0; n < N; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n/N_unit, 0)); + for(int k = 0; k < K; k+=kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + // int8: B0 B1 64x(16+16) => repack as two 16x16x4 + int src_rows = std::min(K - k, kStep); + functional::kpack_tile_B0B1(dst, dst + (1024), &Bi(k, n), Bi.stride, src_rows); + dst += 2048; + } + } + } +} + +template +struct acc_type {}; +template<> +struct acc_type { typedef float type; }; +template<> +struct acc_type { typedef int32_t type; }; +template<> +struct acc_type { typedef int32_t type; }; + +template +using acc_type_t = typename acc_type::type; + +// matrix multiply with vector +// C_Nx1 = A_MxK * b_Kx1 + +template> +struct MatmulVector { + MatmulVector() {} + constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; + constexpr static int kStep = is_i8_mode ? 64 : 32; + +#define TILE_DP(dst, a, b) \ + if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ + if (is_s8s8) _tile_dpbssd(dst, a, b); \ + if (is_s8u8) _tile_dpbsud(dst, a, b); \ + if (is_u8s8) _tile_dpbusd(dst, a, b); \ + if (is_u8u8) _tile_dpbuud(dst, a, b); + + alignas(64) int8_t KtailBuff[64]; + + template + void kernel(int M, int K, const void * pA, int strideA, const void * vB, void * vC) { + static_assert(tmmN >= 1 && tmmN <= 6, "tmmN must be within [1-6] range"); + const auto * pA0 = reinterpret_cast(pA); + int KLastOffBytes = (K - kStep) * sizeof(TA); + const auto * pB0 = reinterpret_cast(vB); + auto * pC0 = reinterpret_cast(vC); + + const auto * pBLast = pB0 + 64*(tmmN - 1); + int Ktail = K & (kStep - 1); + if (Ktail) { + if (bFallbackKtails) { + // if bContainMtails, the last submatrix needs to use special to prevent A matrix read overflow + // K tails is handled by: + // - zero-padding the last tile of vector B, at the top + // - right-align last tile load from matA + __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull << (kStep - Ktail)*sizeof(TB)); + auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); + _mm512_storeu_epi8(KtailBuff, r); + } else { + // each row of A can be read overflow w/o worrying NaN numbers + // zero-padding the last tile of vector B as bottom is enough + __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull >> (kStep - Ktail)*sizeof(TB)); + KLastOffBytes = (K - Ktail)*sizeof(TA); + auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); + _mm512_storeu_epi8(KtailBuff, r); + } + pBLast = KtailBuff; + } + + // load B tiles outside of loop + if (tmmN == 1) { + _tile_loadd(2, pB0, 4); + } + if (tmmN == 2) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pBLast, 4); + } + if (tmmN == 3) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pBLast, 4); + } + if (tmmN == 4) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pBLast, 4); + } + if (tmmN == 5) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pB0 + 64*3, 4); + _tile_loadd(6, pBLast, 4); + } + if (tmmN == 6) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pB0 + 64*3, 4); + _tile_loadd(6, pB0 + 64*4, 4); + _tile_loadd(7, pBLast, 4); + } + //asm("int3"); + for(int m = 0; m < M; m+=16) { + zero_tiles<0>(); + if (tmmN == 1) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + } + if (tmmN == 2) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 3); + } + if (tmmN == 3) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 4); + } + if (tmmN == 4) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 5); + } + if (tmmN == 6) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 6); + } + if (tmmN == 7) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 7); + } + _tile_stored(0, pC0, 4); pC0 += 16*4; // C is single column, always take 4 bytes + pA0 += 16 * strideA; + } + } + + void operator()(tensor2D & matA, const TB * vB, TC * vC) { + int M = matA.dims[0]; + int K = matA.dims[1]; + TA * pA = &matA[0]; + int strideA = matA.stride; + + // M tails is handled + assert(K >= kStep && K <= 6*kStep); + + int Ktail = K & (kStep - 1); + int Mtail = M & (16 - 1); + int Mbody = M - Mtail; + int numBtiles = (K + kStep - 1)/kStep; + + // if we have Ktails, then it will always be handled in Mtail, so we split + // Mtail out even if it's zero + if (Ktail) { + if (Mtail == 0) { + Mtail = 16; + Mbody -= 16; + } + } + + if (Mbody) { + tileconfig_t tfg(1, 0, { + {16, 4}, // C:0 M x 1 (4b) + {16, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); + // Ktail fallback will always be done at Mtails loop + switch(numBtiles) { + case 1: kernel<1, false>(Mbody, K, pA, strideA, vB, vC); break; + case 2: kernel<2, false>(Mbody, K, pA, strideA, vB, vC); break; + case 3: kernel<3, false>(Mbody, K, pA, strideA, vB, vC); break; + case 4: kernel<4, false>(Mbody, K, pA, strideA, vB, vC); break; + case 5: kernel<5, false>(Mbody, K, pA, strideA, vB, vC); break; + case 6: kernel<6, false>(Mbody, K, pA, strideA, vB, vC); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } + + if (Mtail) { + pA = &matA(Mbody, 0); + tileconfig_t tfg(1, 0, { + {Mtail, 4}, // C:0 M x 1 (4b) + {Mtail, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); + if (Ktail) { + switch(numBtiles) { + case 1: kernel<1, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 2: kernel<2, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 3: kernel<3, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 4: kernel<4, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 5: kernel<5, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 6: kernel<6, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } else { + switch(numBtiles) { + case 1: kernel<1, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 2: kernel<2, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 3: kernel<3, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 4: kernel<4, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 5: kernel<5, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 6: kernel<6, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } + } + } +}; + +template::type> +struct Matmul { + // B matrix is orgnized as tensor2D of shape axb where a=round_up_div(N, 32), b=round_up(K,32/64)*32 + // so b is size of submatrix of Kx32 composed of two columns of B0/B1 tiles. + tensor2D internalB; + + bool constB; + bool transposeB; + + constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; + + // AMX bf16 & int8 has same M(=16) in A,C tile and same N(=16) in B tile + // but only different K(32 vs 64) in A,C & B tiles + constexpr static int kStep = is_i8_mode ? 64 : 32; + + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + tensor2D buffC; + + Matmul(bool constB = false, bool transposeB = false) : + constB(constB), transposeB(transposeB), buffC(32, 32) {} + + // ppkernel is a callable which captures the runtime args + // by itself, so no need to pass in any post-process related + // runtime args through this API + // + // n0/n1 allows us for calculating only partial results, so it + // can be used to run on multi-cores in parallel + // + // ppkernel will be invoked with true (m,n) with n0-offset added + // so ppkernel don't need to know which sub-matrix it's working on. + // + // for most ppkernels w/o runtime state, a single ppkernel can be + // shared among all threads. + // + // but for ppkernels doing reductions, it needs separate instance + // for each thread, also a post-merging process to combine the results. + // + // ppkernels are simple to write, further wrapping or structurelize only + // makes the design more complex, so we stop doing that. + + // I cannot find a way to call TDP intrinsic polymophically using overload or template. + // have to use old-macro-tricks, hopefully these compile-time checks can be optimized + // by compiler. +#define TILE_DP(dst, a, b) \ + if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ + if (is_s8s8) _tile_dpbssd(dst, a, b); \ + if (is_s8u8) _tile_dpbsud(dst, a, b); \ + if (is_u8s8) _tile_dpbusd(dst, a, b); \ + if (is_u8u8) _tile_dpbuud(dst, a, b); + + template + void kernel_slimB(int M, int N, int K, + tensor2D & A, + void * B, + tensor2D & buffC, + PP ppkernel) { + auto * pB0 = reinterpret_cast(B); + auto * pC0 = &buffC[0]; + int8_t * pA0 = reinterpret_cast(&A[0]); + int strideA = A.stride; + int KlastOffBytes = (K - kStep)* sizeof(TA); + // load B tiles outside of loop + if (tmmN > 0) _tile_loadd(2, pB0, 64); pB0 += 1024*2; + if (tmmN > 1) _tile_loadd(3, pB0, 64); pB0 += 1024*2; + if (tmmN > 2) _tile_loadd(4, pB0, 64); pB0 += 1024*2; + if (tmmN > 3) _tile_loadd(5, pB0, 64); pB0 += 1024*2; + if (tmmN > 4) _tile_loadd(6, pB0, 64); pB0 += 1024*2; + if (tmmN > 5) _tile_loadd(7, pB0, 64); pB0 += 1024*2; + //asm("int3"); + for(int m0 = 0; m0 < M; m0+=16) { + int m = m0; + if (M - m0 < 16) { + // shift up to prevent M-tails + pA0 -= (16 - (M - m0))*A.stride; + m = M - 16; + } + zero_tiles<0>(); + if (tmmN == 1) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + } + if (tmmN == 2) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 3); + } + if (tmmN == 3) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 4); + } + if (tmmN == 4) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 5); + } + if (tmmN == 5) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 6); + } + if (tmmN == 6) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 7); + } + _tile_stored(0, pC0, buffC.stride); + (ppkernel)(buffC, m, 0, 16, N); + pA0 += 16*A.stride; + } + } + + template + void operator()(tensor2D & matA, + tensor2D & _matB, + int n0, int n1, + PP ppkernel, + bool skip_repack = false) { + auto matB = getSubMatB(_matB, n0, n1, transposeB); + int M = matA.dims[0]; + int K = matA.dims[1]; + int N = matB.dims[transposeB ? 0 : 1]; + assert(K == matB.dims[transposeB ? 1 : 0]); + // Due to the fact that we load a full tile at tails of K dimension + // we may access memory address beyond the limit of A matrix + // to avoid read in nan values, we backoff to the left to ensure A tile + // contain valid numbers and no overflow access happens, but it requires K>=kStep; + assert(K >= kStep); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int KbackoffBytes = (kStep - Ktails)*sizeof(TA); + + // for non-constB, internalB is updated every time + // for constB, internalB is updated once + if ((!constB && !skip_repack) || (internalB.capacity == 0)) { + repackB_1x2(matB, transposeB, internalB, constB); + } + + // special case when whole B matrix can fit in 6 tiles + // we can load B only once + if (M >= 16 && N <= 16 && K <= 6*kStep) { + // B is zero-padded + // C:0 + // A:1 + // B:2,3,4,5,6,7 + auto * pB0 = reinterpret_cast(&internalB[0]); + tileconfig_t tfg(1, 0, 8, 16, 64); + switch((K + kStep - 1)/kStep) { + case 1: kernel_slimB<1>(M, N, K, matA, pB0, buffC, ppkernel); break; + case 2: kernel_slimB<2>(M, N, K, matA, pB0, buffC, ppkernel); break; + case 3: kernel_slimB<3>(M, N, K, matA, pB0, buffC, ppkernel); break; + case 4: kernel_slimB<4>(M, N, K, matA, pB0, buffC, ppkernel); break; + case 5: kernel_slimB<5>(M, N, K, matA, pB0, buffC, ppkernel); break; + case 6: kernel_slimB<6>(M, N, K, matA, pB0, buffC, ppkernel); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + return; + } + + if (M <= 16) { + // register/cache blocking scheme is simplified when M <= 16 + // C_MxN: 0,1 + // A_MxK: 2, + // B_KxN: 3, 4 + tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); + auto * pB0 = reinterpret_cast(&internalB[0]); + auto * const pC0 = &buffC[0]; + int k; + const auto strideA = matA.stride; + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + zero_tiles<0, 1>(); + int8_t * pA0 = reinterpret_cast(&matA[0]); + for(k=0; k(pB0); + _tile_loadd(3, pB0, 64); pB0 += 1024; // tile B0 32x16(16x16x2)/64x16(16x16x4) is always 1KB + // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + _tile_loadd(4, pB0, 64); pB0 += 1024; // tile B1 32x16(16x16x2)/64x16(16x16x4) is always 1KB + TILE_DP(0, 2, 3); // C0 += A*B0 + TILE_DP(1, 2, 4); // C1 += A*B1 + } + if (Ktails) { + _tile_loadd(2, pA0 - KbackoffBytes, strideA); + // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + _tile_loadd(3, pB0, 64); pB0 += 1024; + // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + _tile_loadd(4, pB0, 64); pB0 += 1024; + TILE_DP(0, 2, 3); // C0 += A*B0 + TILE_DP(1, 2, 4); // C1 += A*B1 + } + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + return; + } + + auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { + auto * pA0 = reinterpret_cast(&matA(m, 0)); + auto * pA1 = reinterpret_cast(&matA(m + 16, 0)); + auto strideA = matA.stride; + auto * pB = reinterpret_cast(&internalB(n>>5, 0)); + zero_tiles<0, 1, 2, 3>(); + // 2x2 + for (int k = 0; k < Kbody; k += kStep) { + _tile_loadd(4, pA0, strideA); pA0 += 64; + _tile_loadd(6, pB, 64); pB += 1024; + // prefetch_bytes<1024>(pB); + TILE_DP(0, 4, 6); + + _tile_loadd(5, pA1, strideA); pA1 += 64; + TILE_DP(2, 5, 6); + _tile_loadd(7, pB, 64); pB += 1024; + // prefetch_bytes<1024>(pB); + TILE_DP(1, 4, 7); + + TILE_DP(3, 5, 7); + } + if (Ktails) { + _tile_loadd(4, pA0 - KbackoffBytes, strideA); + _tile_loadd(6, pB, 64); pB += 1024; + // prefetch_bytes<1024>(pB); + TILE_DP(0, 4, 6); + + _tile_loadd(5, pA1 - KbackoffBytes, strideA); + TILE_DP(2, 5, 6); + _tile_loadd(7, pB, 64); pB += 1024; + // prefetch_bytes<1024>(pB); + TILE_DP(1, 4, 7); + + TILE_DP(3, 5, 7); + } + _tile_stored(0, &buffC(0,0), buffC.stride); + _tile_stored(1, &buffC(0,16), buffC.stride); + _tile_stored(2, &buffC(16,0), buffC.stride); + _tile_stored(3, &buffC(16,16), buffC.stride); + (ppkernel)(buffC, m, n + n0, valid_m, valid_n); + }; + + if (M <= 32 && M >16) { + // 2x2 tile, C:0/1/2/3 A:4/5 B:6/7 no blocking along M dimension + tileconfig_t tfg(1, 0, {16,16,M-16,M-16,16,M-16,16,16}, 64); + loop2D_no_bM<32>(M, N, kernel_2x2); + return; + } + + // generic input shapes with M > 32 + // determine cache blocking scheme + int elesz = sizeof(TA); + int L2 = 2048*1024; // 2MB + int slice_size = 32*rndup(K, 32)*elesz; + int mc = std::max(1, L2/slice_size - 1); + + // M > bM + tileconfig_t tfg(1, 0, 8, 16, 64); + loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); + } +}; + +// specialization: +// TA is ov::bfloat16 and TB is int8_t, decompressed on the fly into ov::bfloat16 by simply convert +template<> +struct Matmul { + tensor2D internalBI8; + + // wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly. + tensor2D weiBuff; + + bool constB; + bool transposeB; + + constexpr static int kStep = 32; + + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + tensor2D buffC; + + Matmul(bool constB = false, bool transposeB = false) : + constB(constB), transposeB(transposeB), buffC(32, 32) {} + + float quant_scale_B; + float dequant_scale_B; + + template + void operator()(tensor2D & matA, + tensor2D & _matB, + int n0, int n1, + PP ppkernel) { + auto matB = getSubMatB(_matB, n0, n1, transposeB); + int M = matA.dims[0]; + int K = matA.dims[1]; + int N = matB.dims[transposeB ? 0 : 1]; + assert(K == matB.dims[transposeB ? 1 : 0]); + // Due to the fact that we load a full tile at tails of K dimension + // we may access memory address beyond the limit of A matrix + // to avoid read in nan values, we backoff to the left to ensure A tile + // contain valid numbers and no overflow access happens, but it requires K>=kStep; + assert(K >= kStep); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int Kbackoff = (kStep - Ktails); + + // for non-constB, internalB is updated every time + // for constB, internalB is updated once + if (!constB || (internalBI8.capacity == 0)) { + // this dynamic quantization of weight matrix using minmax + // is time-consuming, should be used only for constB + if (!constB) { + std::cout << "\t WANING: dynamic quantization of weight matrix for non-constB is time-consuming " << std::endl; + } + // float min, max; + // functional::get_min_max(_matB, min, max); + // max = std::max(std::abs(max), std::abs(min)); + // quant_scale_B = 127 / max; + // dequant_scale_B = max / 127; + + tensor2D internalTmpB; + repackB_1x2(matB, transposeB, internalTmpB, constB); + functional::bf16_to_i8_tensor(internalBI8, internalTmpB, quant_scale_B); + } + + ppkernel.set_deq_scale(dequant_scale_B); + + if (M <= 16) { + // C:0/1 A:2 B:3/4 + // dequantize scale is moved into ppkernel + constexpr int prefetch_ahead = 64*1024; + tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); + auto * pBint = reinterpret_cast(&internalBI8[0]); + auto & B2buff = weiBuff; + B2buff.resize(32*2, 32); + auto * const pB = &B2buff[0]; + auto * pBsrc = pB + (32*32) * 0; + auto * pBdst = pB + (32*32) * 1; + functional::i8_to_bf16_Kx32<32>(pBint, pBsrc); + + auto * const pC0 = &buffC[0]; + const auto strideA = matA.stride; + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + // C:Mx32 = A:Mx32 x B:32x32 + zero_tiles<0, 1>(); + auto * pA0 = &matA[0]; + for(int k=0; k(pBint); + + functional::i8_to_bf16_Kx32<8>(pBint, pBdst); + _tile_loadd(3, pBsrc, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); + _tile_loadd(4, pBsrc + 16*32, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + if (Ktails) { + _tile_loadd(2, pA0 - Kbackoff, strideA); // backoff to prevent access beyond the end of A + prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + + functional::i8_to_bf16_Kx32<8>(pBint, pBdst); + _tile_loadd(3, pBsrc, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); + _tile_loadd(4, pBsrc + 16*32, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + //prefetch_bytes<2048, _MM_HINT_T1, prefetch_ahead>(pBint); + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //prefetch_bytes<2048, _MM_HINT_T1, prefetch_ahead>(pBint + 2048); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + return; + } + + // 4 tiles buffC is reused as decompressed bf16 weights + constexpr int prefetch_ahead = 16*1024; + ov::bfloat16 * pBa = reinterpret_cast(&buffC(0,0)); + ov::bfloat16 * pBb = pBa + (16*32)*2; + auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { + auto strideA = matA.stride; + auto * pA0 = &matA(m, 0); + auto * pA1 = &matA(m + 16, 0); + auto * pBint = reinterpret_cast(&internalBI8(n>>5, 0)); + functional::i8_to_bf16_Kx32<32>(pBint, pBb); + + zero_tiles<0, 1, 2, 3>(); + int k; + for (k = 0; k < Kbody; k += kStep) { + functional::i8_to_bf16_Kx32<16>(pBint, pBa); + + _tile_loadd(4, pA0 + k, strideA); + _tile_loadd(6, pBb, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); + + _tile_loadd(7, pBb + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + if (Ktails) { + functional::i8_to_bf16_Kx32<16>(pBint, pBa); + + _tile_loadd(4, pA0 + k - Kbackoff, strideA); + _tile_loadd(6, pBb, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k - Kbackoff, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); + + _tile_loadd(7, pBb + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + _tile_stored(0, &buffC(0,0), buffC.stride); + _tile_stored(1, &buffC(0,16), buffC.stride); + _tile_stored(2, &buffC(16,0), buffC.stride); + _tile_stored(3, &buffC(16,16), buffC.stride); + (ppkernel)(buffC, m, n + n0, valid_m, valid_n); + }; + + if (M <= 32 && M > 16) { + // 2x2 C:0/1/2/3 A:4/5 B:6/7 + tileconfig_t tfg(1, 0, {16, 16, M-16, M-16, 16, M-16, 16, 16}, 64); + loop2D_no_bM<32>(M, N, kernel_2x2); + return; + } + + // determine blocking scheme + int elesz = sizeof(uint16_t); + int L2 = 2048*1024; // 2MB + int slice_size = 32*rndup(K, 32)*elesz; + int mc = std::max(1, L2/slice_size - 1); // if 1 32xK slice cannot fit L2, use 1 slice at least + + // main loop + tileconfig_t tfg(1, 0, 8, 16, 64); + loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); + } +}; + +//https://stackoverflow.com/questions/29519222/how-to-transpose-a-16x16-matrix-using-simd-instructions +// vector multiply with matrix: +// mAvB: A(M, K) * B(K, 1) => C(M, 1) +// vAmB: A(1, K) * B(K, N) => C(1, N) +// +// in mAvB form, block of A (16x32) is transposed in register +// in unit of 2 packed bf16, and then vdpbf16ps was used +// to multiply with broadcasted B (2x1) and accumulate into C (16x1) +// +// B is pre-broadcasted in unit of 2 +// +struct GemAvB { + tensor2D Bpadded; + GemAvB() { + } + + void operator()(tensor2D & matA, + ov::bfloat16 * vecB, + float * vecC) { + int M = matA.dims[0]; + int K = matA.dims[1]; + + constexpr int kStep = 32; + + assert(K >= 32); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int Kbackoff = (kStep - Ktails); + + if (K % 32) { + if (K > Bpadded.dims[1]) + Bpadded.resize(1, rndup(K, 32)); + auto newB = &Bpadded(0, 0); + memset(newB, 0, Bpadded.stride); + memcpy(newB, vecB, K * sizeof(ov::bfloat16)); + vecB = newB; + } + + for(int m = 0; m < M; m += 16) { + auto * pA = reinterpret_cast(&matA(m, 0)); + auto * pBi32 = reinterpret_cast(vecB); + __m512 regC0 = _mm512_setzero(); + __m512 regC1 = _mm512_setzero(); + __mmask16 kmask = _cvtu32_mask16(0xFFFF); + if (M-m < 16) { + kmask = _cvtu32_mask16(0xFFFF >> (16-(M-m))); + } + for(int k = 0; k < K; k += 32, pA += 64, pBi32 += 16) { + // handle Ab: 16x32 + // transposed in register as 16x16x2 + // r0: (a0,a1)(b0,b1).... + // r1: (a2,a3)(b2,b3).... + // ... + // rf: (a30,a31),(b30,b31).... + // + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto stride = matA.stride; + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_loadu_epi32(pA + 15*stride); + + functional::transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + // vdpbf16ps + regC0 = _mm512_dpbf16_ps(regC0, r0, _mm512_set1_epi32(pBi32[0])); + regC1 = _mm512_dpbf16_ps(regC1, r1, _mm512_set1_epi32(pBi32[1])); + regC0 = _mm512_dpbf16_ps(regC0, r2, _mm512_set1_epi32(pBi32[2])); + regC1 = _mm512_dpbf16_ps(regC1, r3, _mm512_set1_epi32(pBi32[3])); + regC0 = _mm512_dpbf16_ps(regC0, r4, _mm512_set1_epi32(pBi32[4])); + regC1 = _mm512_dpbf16_ps(regC1, r5, _mm512_set1_epi32(pBi32[5])); + regC0 = _mm512_dpbf16_ps(regC0, r6, _mm512_set1_epi32(pBi32[6])); + regC1 = _mm512_dpbf16_ps(regC1, r7, _mm512_set1_epi32(pBi32[7])); + regC0 = _mm512_dpbf16_ps(regC0, r8, _mm512_set1_epi32(pBi32[8])); + regC1 = _mm512_dpbf16_ps(regC1, r9, _mm512_set1_epi32(pBi32[9])); + regC0 = _mm512_dpbf16_ps(regC0, ra, _mm512_set1_epi32(pBi32[10])); + regC1 = _mm512_dpbf16_ps(regC1, rb, _mm512_set1_epi32(pBi32[11])); + regC0 = _mm512_dpbf16_ps(regC0, rc, _mm512_set1_epi32(pBi32[12])); + regC1 = _mm512_dpbf16_ps(regC1, rd, _mm512_set1_epi32(pBi32[13])); + regC0 = _mm512_dpbf16_ps(regC0, re, _mm512_set1_epi32(pBi32[14])); + regC1 = _mm512_dpbf16_ps(regC1, rf, _mm512_set1_epi32(pBi32[15])); + } + regC0 = _mm512_add_ps(regC0, regC1); + _mm512_mask_storeu_ps (vecC + m, kmask, regC0); + //auto regOut = _mm512_cvtne2ps_pbh(regC0, regC0); // only 16 ov::bfloat16 results in lower 256bits + //_mm256_storeu_si256(reinterpret_cast<__m256i_u *>(vecC + m), _mm512_extracti64x4_epi64(regOut, 0)); + } + } +}; + +} // namespace amx + +inline std::ostream & operator<<(std::ostream & os, const amx_kernel::PP::Steps & steps) { + os << "amx_kernel::PP::Steps::"; + if (steps == amx_kernel::PP::Steps::NONE) + os << "NONE"; + if (steps & amx_kernel::PP::Steps::DEQUANT) + os << "_DEQUANT"; + if (steps & amx_kernel::PP::Steps::BIAS) + os << "_BIAS"; + if (steps & amx_kernel::PP::Steps::GELU) + os << "_GELU"; + if (steps & amx_kernel::PP::Steps::QUANT) + os << "_QUANT"; + return os; +} diff --git a/src/tensor2d.hpp b/src/tensor2d.hpp new file mode 100644 index 0000000..8ca325f --- /dev/null +++ b/src/tensor2d.hpp @@ -0,0 +1,178 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#ifdef ENABLE_NUMA +#include "numa.h" +#endif +#include "bf16.hpp" + +#define rndup(x, n) (((x + n - 1)/n)*n) + +template +struct tensor2D { + int dims[2] = {0}; + T* data = nullptr; + int64_t capacity = 0; + int stride = 0; + bool force_compact = false; + bool own; + int padded_dim1 = 0; + + tensor2D() = default; + tensor2D(const tensor2D&) = delete; + ~tensor2D() { + if (own && data) ::free(data); + } + + operator bool() { + return dims[0] * dims[1] > 0; + } + + tensor2D(int d0, int d1, bool _force_compact = false) { + capacity = 0; + resize(d0, d1, _force_compact); + } + + tensor2D(int d0, int d1, T * ext, int _stride) { + capacity = 1; + data = ext; + own = false; + dims[0] = d0; + dims[1] = d1; + stride = _stride; + padded_dim1 = stride / sizeof(T); + } + + tensor2D Tr() { + tensor2D ret(dims[1], dims[0]); + for(int c0=0; c0 < dims[0]; ++c0) { + for(int c1=0; c1 < dims[1]; ++c1) { + ret(c1, c0) = (*this)(c0, c1); + } + } + return ret; + } + tensor2D clone() { + tensor2D ret; + ret.resize(dims[0], dims[1], force_compact); + if (ret.stride == stride) { + memcpy(ret.data, data, dims[0] * stride); + }else{ + for(int i=0;i( + reinterpret_cast(numa_alloc_local(capacity)), + [need_capacity](void * p){ numa_free(p, need_capacity); }); + } else { +#else + { +#endif + if (data) ::free(data); + data = reinterpret_cast(aligned_alloc(64, capacity)); + } + if (is_const) + memset(data, 0, need_capacity); + if (reinterpret_cast(data) % 64) + std::cout << "WARNING: resize(), data is not cache-line aligned!" << std::endl; + } + // put a NaN at the end to test over-read + // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format + // #define INF 0xff80 + // #define NAN1 (INF + 1) + // if (sizeof(T) == 2) { + // *reinterpret_cast(data.get() + dims[0] * padded_dim1) = NAN1; + // } + } + + T & operator[](int i) { + return data[i]; + } + + const T & operator[](int i) const { + return data[i]; + } + + //https://stackoverflow.com/questions/1936399/c-array-operator-with-multiple-arguments + T & operator()(int i0, int i1) { + return (*this)[i0 * padded_dim1 + i1]; + } + + const T & operator()(int i0, int i1) const { + return (*this)[i0 * padded_dim1 + i1]; + } + + + void operator=(const T & v) { + for(int k = 0; k < dims[0] * padded_dim1; k++) + (*this)[k] = v; + } + + tensor2D& operator=(const tensor2D & t2) = delete; + + // move semantics + tensor2D(tensor2D && t2) { + dims[0] = t2.dims[0]; + dims[1] = t2.dims[1]; + if (own && data) ::free(data); + data = t2.data; + own = t2.own; + capacity = t2.capacity; + stride = t2.stride; + padded_dim1 = t2.padded_dim1; + force_compact = t2.force_compact; + t2.capacity = 0; + t2.data = nullptr; + } + + tensor2D& operator=(tensor2D && t2) { + dims[0] = t2.dims[0]; + dims[1] = t2.dims[1]; + if (own && data) ::free(data); + own = t2.own; + data = t2.data; + capacity = t2.capacity; + stride = t2.stride; + padded_dim1 = t2.padded_dim1; + force_compact = t2.force_compact; + t2.capacity = 0; + t2.data = nullptr; + return *this; + } +}; diff --git a/src/tensor2d_helper.hpp b/src/tensor2d_helper.hpp new file mode 100644 index 0000000..417b2a6 --- /dev/null +++ b/src/tensor2d_helper.hpp @@ -0,0 +1,200 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" +#include "tensor2d.hpp" +#include "bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#endif +#include + +// https://stackoverflow.com/questions/570669/checking-if-a-double-or-float-is-nan-in-c/57770634#57770634 +static inline uint32_t load_ieee754_rep(float a) { + uint32_t r; + static_assert(sizeof r == sizeof a, "Unexpected sizes."); + std::memcpy(&r, &a, sizeof a); // Generates movd instruction. + return r; +} +constexpr uint32_t inf_float_shl1 = UINT32_C(0xff000000); +// The shift left removes the sign bit. The exponent moves into the topmost bits, +// so that plain unsigned comparison is enough. +static inline bool isnan2(float a) { return load_ieee754_rep(a) << 1 > inf_float_shl1; } +static inline bool isinf2(float a) { return load_ieee754_rep(a) << 1 == inf_float_shl1; } +static inline bool isfinite2(float a) { return load_ieee754_rep(a) << 1 < inf_float_shl1; } + +template +void fill_rnd(tensor2D& t) { + auto * p = t.data; + int i = 0; + int total = t.dims[0] * t.padded_dim1; + // +1 -1 for integer types + // 0.5 -0.5 for float point + float scale = std::is_integral::value ? 2:1; + for(i = 0; i + 8 <= total; i+=8) { + // lower mantissa can help to avoid small errors in accuracy comparison + auto num = rand() & 0xFF; + p[i] = scale*((num & 1) - 0.5f); num>>=1; + p[i+1] = scale*((num & 1) - 0.5f); num>>=1; + p[i+2] = scale*((num & 1) - 0.5f); num>>=1; + p[i+3] = scale*((num & 1) - 0.5f); num>>=1; + p[i+4] = scale*((num & 1) - 0.5f); num>>=1; + p[i+5] = scale*((num & 1) - 0.5f); num>>=1; + p[i+6] = scale*((num & 1) - 0.5f); num>>=1; + p[i+7] = scale*((num & 1) - 0.5f); num>>=1; + } + for(; i +bool operator==(const tensor2D& lhs, const tensor2D& rhs) { + if (lhs.dims[0] != rhs.dims[0] || lhs.dims[1] != rhs.dims[1]) + return false; + for(int i0 = 0; i0 < lhs.dims[0]; i0++) + for(int i1 = 0; i1 < lhs.dims[1]; i1++) { + // with -ffast-math, std::isnan, std::isinf, x != x always return false + // so we need special logic to test nan here + if (std::is_same::value || + std::is_same::value) { + float f0 = lhs(i0,i1); + float f1 = rhs(i0,i1); + if (isnan2(f1) || isnan2(f0)) { + std::cout << " nan is found: f0=" << f0 << ", f1=" << f1 << std::endl; + return false; + } + if (std::abs(f0 - f1) <= 0.01) + continue; + } + + if (lhs(i0,i1) == rhs(i0,i1)) + continue; + std::cout << " operator== failed at (" << i0 << ", " << i1 << ") value " + << lhs(i0,i1) << "!=" << rhs(i0,i1) << std::endl; + return false; + } + return true; +} + +template +bool is_normal(const tensor2D& t) { + for (int i0 = 0; i0 < t.dims[0]; i0++) + for (int i1 = 0; i1 < t.dims[1]; i1++) { + float f0 = t(i0,i1); + if (isnan2(f0)) { + std::cout << " found nan at (" << i0 << "," << i1 << ")" << std::endl; + return false; + } + if (isinf2(f0)) { + std::cout << " found inf at (" << i0 << "," << i1 << ")" << std::endl; + return false; + } + } + return true; +} + +template +bool compare(const tensor2D& lhs, const tensor2D& rhs, float tolerance) { + float max_abs_diff = 0; + float max_rel_diff = 0; + if (lhs.dims[0] != rhs.dims[0] || lhs.dims[1] != rhs.dims[1]) + return false; + for (int i0 = 0; i0 < lhs.dims[0]; i0++) + for (int i1 = 0; i1 < lhs.dims[1]; i1++) { + float f0 = lhs(i0, i1); + float f1 = rhs(i0, i1); + auto diff = std::fabs(f0 - f1); + auto rel_diff = diff / std::fabs(f0); + max_abs_diff = std::max(max_abs_diff, diff); + if (std::fabs(lhs(i0,i1) > 0) && diff > 0) + max_rel_diff = std::max(max_rel_diff, rel_diff); + } + std::cout << "max_abs_diff=" << max_abs_diff << " max_rel_diff=" << max_rel_diff << "\n"; + return tolerance > max_abs_diff; +} + +template +std::ostream& operator<<(std::ostream& out, const tensor2D& obj) { + int i0; + auto showline = [&](int i) { + out << "[" << i << "," << 0 << "]: "; + int i1; + for(i1=0; i1 +inline void show(const T * data, int rows, int cols) { + std::ostream& out = std::cout; + out << "==============\n"; + for(int i0=0; i0 < rows; i0++) { + out << "[" << i0 << "," << 0 << "]: "; + for(int i1=0; i1 +inline void vshow(__m512i v) { + T values[512/8/sizeof(T)]; + _mm512_storeu_si512(values, v); + show(values, 1, 512/8/sizeof(T)); +} + +template +inline void vshow(__m512 v) { + T values[512/8/sizeof(T)]; + _mm512_storeu_ps(values, v); + show(values, 1, 512/8/sizeof(T)); +} + +template +inline void tshow() { + if (std::is_same::value) { + ov::bfloat16 data[16*32]; + _tile_stored(tile, data, 64); + show(data, 16, 32); + } + if (std::is_same::value) { + float data[16*16]; + _tile_stored(tile, data, 64); + show(data, 16, 16); + } + if (std::is_same::value) { + int8_t data[16*64]; + _tile_stored(tile, data, 64); + show(data, 16, 64); + } + if (std::is_same::value) { + uint8_t data[16*64]; + _tile_stored(tile, data, 64); + show(data, 16, 64); + } +} diff --git a/src/utility_amx.hpp b/src/utility_amx.hpp new file mode 100644 index 0000000..2c348d4 --- /dev/null +++ b/src/utility_amx.hpp @@ -0,0 +1,107 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +struct tileconfig_t { + uint8_t palette_id; + uint8_t startRow; + uint8_t reserved[14]; + uint16_t cols[16]; + uint8_t rows[16]; + tileconfig_t() = default; + + tileconfig_t(int palette, int _startRow, const std::initializer_list> &_rows_columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + i = 0; + for (const auto& ele : _rows_columnsBytes) { + rows[i] = ele.first; + cols[i] = ele.second; + i++; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + + tileconfig_t(int palette, int _startRow, const std::initializer_list &_rows, int columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + i = 0; + for (const auto ele : _rows) { + rows[i] = ele; + cols[i] = columnsBytes; + i++; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + tileconfig_t(int palette, int _startRow, int numTiles, int _rows, int columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + for(i = 0; i < numTiles; i++) { + rows[i] = _rows; + cols[i] = columnsBytes; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + + ~tileconfig_t() { + _tile_release(); + } + void __attribute__((noinline)) load() { + //std::cout << "\ttile load config ... " << std::flush; + _tile_loadconfig(this); + //std::cout << *this << std::flush << std::endl; + } + void store() { + _tile_storeconfig(this); + } + friend std::ostream& operator<<(std::ostream& out, const tileconfig_t& cfg) { + out << " palette_id=" << static_cast(cfg.palette_id); + out << " startRow=" << static_cast(cfg.startRow); + out << " row x colsb=("; + for (int i = 0; i < 16;i++) { + if (cfg.rows[i] == 0 && cfg.cols[i] == 0) + continue; + if (i > 0) out << ","; + out << static_cast(cfg.rows[i]) << "x" << static_cast(cfg.cols[i]); + } + out << ")"; + return out; + } +} __attribute__ ((__packed__)); diff --git a/src/utility_avx512.hpp b/src/utility_avx512.hpp new file mode 100644 index 0000000..5509218 --- /dev/null +++ b/src/utility_avx512.hpp @@ -0,0 +1,94 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +namespace llmdnn { + +/// Convert Packed BF16 Data to Packed float Data. +/// +/// \headerfile +/// +/// \param __A +/// A 256-bit vector of [16 x bfloat]. +/// \returns A 512-bit vector of [16 x float] come from convertion of __A +static __inline__ __m512 _mm512_cvtpbh_ps(__m256bh __A) { + return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32( + (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16)); +} + +// Store masks. The highest bit in each byte indicates the byte to store. +alignas(16) const unsigned char masks[16][16] = +{ + { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00 } +}; + +inline void store_n(__m128i mm, unsigned int n, void* storage) +{ + _mm_maskmoveu_si128(mm, reinterpret_cast< const __m128i& >(masks[n]), static_cast< char* >(storage)); +} + +inline void quant_i8(void* dst, void* src, size_t ele_num, float scale) { + size_t i = 0; + auto* a = reinterpret_cast(src); + int8_t* d = reinterpret_cast(dst); + auto s = _mm512_set1_ps(scale); + for (; i < ele_num / 16 * 16; i += 16) { + auto a0 = _mm256_loadu_epi16(a); + auto a0_f = _mm512_cvtpbh_ps((__m256bh)a0); + auto d_f = _mm512_mul_ps(a0_f, s); + auto d_i = _mm512_cvtps_epi32(d_f); + auto d_i8 = _mm512_cvtsepi32_epi8(d_i); + _mm_storeu_si128(reinterpret_cast<__m128i *>(d), d_i8); + a += 16; + d += 16; + } + if (i != ele_num) { + // https://stackoverflow.com/questions/40391708/convert-16-bit-mask-mmask16-to-m128i-control-byte-mask-on-knl-xeon-phi-72 + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - (ele_num % 16))); + auto a0 = _mm256_maskz_loadu_epi16(msk, a); + auto a0_f = _mm512_cvtpbh_ps((__m256bh)a0); + auto d_f = _mm512_mul_ps(a0_f, s); + auto d_i = _mm512_cvtps_epi32(d_f); + auto d_i8 = _mm512_cvtsepi32_epi8(d_i); + store_n(d_i8, ele_num % 16, d); + } +} + +// NOTE: did not handle tail because there should be enough room +inline void cvt_i32_f32(float* dst, int32_t* src, size_t ele_num) { + for (int i = 0; i < (ele_num + 15) / 16 * 16; i += 16) { + auto a0 = _mm512_load_epi32(src); + auto a_f = _mm512_cvtepi32_ps(a0); + _mm512_storeu_ps(dst, a_f); + src += 16; + dst += 16; + } +} + +} \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..09ce3a5 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,22 @@ +cmake_minimum_required(VERSION 3.18) +project(llmdnn_tests) + +include(FetchContent) + +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.11.0 + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE) +FetchContent_MakeAvailable(googletest) + +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +include_directories(${LLMDNN_HEADERS_DIR}) + +file(GLOB_RECURSE TEST_SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) + +add_executable(llmdnn_tests ${TEST_SOURCE_FILES}) +target_link_libraries(llmdnn_tests llmdnn gtest_main stdc++) +install(TARGETS llmdnn_tests DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) diff --git a/tests/src/test_common.cpp b/tests/src/test_common.cpp new file mode 100644 index 0000000..28760b4 --- /dev/null +++ b/tests/src/test_common.cpp @@ -0,0 +1,61 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "test_common.h" + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE /* See feature_test_macros(7) */ +#endif +#include +#include /* For SYS_xxx definitions */ + +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 + +using namespace std; +using namespace llmdnn; + +std::string dtype_to_str(data_type_t type) { + switch (type) { + case dnnl_data_type_undef: return "undef"; + case dnnl_f16: return "f16"; + case dnnl_bf16: return "bf16"; + case dnnl_f32: return "f32"; + case dnnl_s32: return "s32"; + case dnnl_s8: return "s8"; + case dnnl_u8: return "u8"; + case dnnl_f64: return "f64"; + default: return "unkown"; + } +} + +bool initXTILE() { + unsigned long bitmask = 0; + long status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (0 != status) return false; + if (bitmask & XFEATURE_MASK_XTILEDATA) return true; + + status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + if (0 != status) + return false; // XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed + status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + + // XFEATURE_XTILEDATA setup is failed, can't use TMUL + if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) return false; + + // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed + return true; +} \ No newline at end of file diff --git a/tests/src/test_common.h b/tests/src/test_common.h new file mode 100644 index 0000000..e557de3 --- /dev/null +++ b/tests/src/test_common.h @@ -0,0 +1,218 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" +#include "llm_fc.hpp" +#include "tensor2d.hpp" +#include "bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#endif +#include + +#define rndup(x, n) (((x + n - 1)/n)*n) + +std::string dtype_to_str(llmdnn::data_type_t type); + +using func_act = std::function; + +bool initXTILE(); + +template +void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + int k; + for (k = 0; (k + 32) <= K; k += 32) { + float psum0 = 0; + float psum1 = 0; + for(int p = 0; p < 32; p+=2) { + psum0 += static_cast(A(m,k+p)) * static_cast(B(k+p,n)); + psum1 += static_cast(A(m,k+p+1)) * static_cast(B(k+p+1,n)); + } + sum += (psum0 + psum1); + } + for(; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + //std::cout << m << "," << n << std::endl; + C(m,n) = sum; + } + } +} + +inline void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + for(int k = 0; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + C(m,n) = sum; + } + } +} + +template +void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + for(int k = 0; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (dq) { + sum *= dq[n]; + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + if (q) { + sum *= q[n]; + sum = std::min(static_cast(std::numeric_limits::max()), sum); + sum = std::max(static_cast(std::numeric_limits::min()), sum); + } + C(m,n) = sum; + } + } +} + +struct ANSIcolor { + const char * code; + ANSIcolor(const char * code = "0") : code(code){ + } + friend std::ostream& operator<<(std::ostream& out, const ANSIcolor& obj) { + out << "\033[" << obj.code << "m"; + return out; + } +}; + +struct pretty_size { + double sz; + std::string txt; + pretty_size(double sz, const char * unit = "") : sz(sz) { + std::stringstream ss; + ss << std::setprecision(5) << std::setw(5); + if (sz < 1024) + ss << sz; + else if (sz < 1024 * 1024) + ss << (sz / 1024) << " K"; + else if (sz < 1024 * 1024 * 1024) + ss << (sz / 1024/1024) << " M"; + else + ss << (sz / 1024 / 1024/1024) << " G"; + ss << unit; + txt = ss.str(); + } + friend std::ostream& operator<<(std::ostream& os, const pretty_size& ps) { + os << ps.txt; + return os; + } +}; + +inline int readenv(const char * name) { + int v = 0; + auto * p = std::getenv(name); + if (p) + v = std::atoi(p); + std::cout << ANSIcolor("32") << "ENV: " << name << " = " << v << std::endl << ANSIcolor(); + return v; +} + +template +struct TypeName { + static const char* get(){return typeid(T).name();} +}; + +// a specialization for each type of those you want to support +// and don't like the string returned by typeid +template <> +struct TypeName{ + static const char* get(){return "int32_t";} +}; +template <> +struct TypeName{ + static const char* get(){return "foat";} +}; +template <> +struct TypeName{ + static const char* get(){return "bfloat16";} +}; +template <> +struct TypeName{ + static const char* get(){return "int8_t";} +}; + +inline std::ostream & operator<<(std::ostream & os, const llmdnn::postops_types & steps) { + if (steps == llmdnn::NONE) + os << "NONE"; + if (steps & llmdnn::DEQUANT) + os << "_DEQUANT"; + if (steps & llmdnn::BIAS) + os << "_BIAS"; + if (steps & llmdnn::GELU) + os << "_GELU"; + if (steps & llmdnn::QUANT) + os << "_QUANT"; + return os; +} diff --git a/tests/src/test_fc_kernel.cpp b/tests/src/test_fc_kernel.cpp new file mode 100644 index 0000000..c4d7087 --- /dev/null +++ b/tests/src/test_fc_kernel.cpp @@ -0,0 +1,217 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_fc.hpp" +#include "tensor2d.hpp" +#include "tensor2d_helper.hpp" +#include "test_common.h" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using FCKernelTestShape = std::tuple; +using FCKernelTestDTPost = std::tuple; +using FCKernelTestParamSet = std::tuple< + FCKernelTestDTPost, // a, b, c data type, postops + bool, // b needs transpose + FCKernelTestShape // M, N, K + >; + +class FCKernelTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + FCKernelTestDTPost types; + bool is_transpose; + postops_types postops_type; + data_type_t dt_a, dt_b, dt_c; + FCKernelTestShape shape; + int M, N, K; + std::tie(types, is_transpose, shape) = obj.param; + std::tie(M, N, K) = shape; + std::tie(dt_a, dt_b, dt_c, postops_type) = types; + + std::ostringstream result; + result << "A_" << dtype_to_str(dt_a) << "_B_" << dtype_to_str(dt_b) + << "_C_" << dtype_to_str(dt_c) << (is_transpose ? "_transpose" : "") + << "_postops_" << postops_type << "_M_" << M << "_N_" << N << "_K_" << K; + return result.str(); + } + +protected: + virtual void SetUp() { + initXTILE(); + + FCKernelTestShape shape; + FCKernelTestDTPost types; + std::tie(types, _is_transpose, shape) = GetParam(); + std::tie(_M, _N, _K) = shape; + std::tie(_dt_a, _dt_b, _dt_c, _postops_type) = types; + }; + + template + void do_test() { + fc_kernel* fc; + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + ASSERT_TRUE(fc_kernel_create(&fc, ¶m)); + auto gemm = std::shared_ptr(fc, [](fc_kernel* p) { fc_kernel_destroy(p); }); + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D q(1, _N); + tensor2D bias(1, _N); + + fill_rnd(A); + fill_rnd(B); + dq = 2; + q = 2; + fill_rnd(bias); + bias = 1; + + tensor2D BT = B.Tr(); + TB* ptr_B; + size_t ldb; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + } else { + ptr_B = B.data; + ldb = B.stride; + } + fc_kernel_execute(gemm.get(), A.data, ptr_B, C.data, A.stride, ldb, + C.stride, _M, _N, _K, 0, _N, dq.data, q.data, bias.data); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + if (_postops_type & DEQUANT) { + ptr_dq = dq.data; + } + if (_postops_type & QUANT) { + ptr_q = q.data; + } + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + + matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + + int _M, _N, _K; + bool _is_transpose; + postops_types _postops_type; + data_type_t _dt_a, _dt_b, _dt_c; +}; + +TEST_P(FCKernelTest, Func) { + if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_s8) { + do_test(); + } else if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_bf16) { + do_test(); + } else if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_f32) { + do_test(); + } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_bf16 && _dt_c == dnnl_bf16) { + do_test(); + } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_bf16 && _dt_c == dnnl_f32) { + do_test(); + } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_s8 && _dt_c == dnnl_f32) { + do_test(); + } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_s8 && _dt_c == dnnl_bf16) { + do_test(); + } else { + ASSERT_TRUE(false); + } +} + +// supported: +// (s8,s8,s8),dq,[bias],[gelu],q +// (s8,s8,bf16),dq,[bias],[gelu] +// (s8,s8,f32),dq,[bias],[gelu] +// (bf16,bf16,bf16),[bias],[gelu] +// (bf16,bf16,f32),[bias],[gelu] +// (bf16,s8,f32),dq,[bias],[gelu] +// (bf16,s8,bf16),dq,[bias],[gelu] +const std::vector types = { + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_GELU_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_GELU_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_GELU }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_GELU }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, NONE }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS_GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, NONE }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS_GELU }, + // TODO: support weight compression + // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT }, + // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_BIAS }, + // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_GELU }, + // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU }, + // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT }, + // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_BIAS }, + // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_GELU }, + // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU }, +}; + +// M, N, K +const std::vector shapes = { + // normal + {256, 48, 448}, + // k tail + {256, 48, 449}, + // M tail == unroll 8 + {256 + 8, 48, 449}, + // M tail == unroll 8 + 2 + {256 + 10, 48, 449}, + // N tail + {256, 40, 448}, + // all tail + {256 + 9, 47, 449}, + // gemv, K <= 64(32)*6 + {256, 1, 80}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_FCKernel, FCKernelTest, + ::testing::Combine(ValuesIn(types), + Values(true, false), + ValuesIn(shapes)), + FCKernelTest::getTestCaseName); diff --git a/tests/src/test_mm_kernel.cpp b/tests/src/test_mm_kernel.cpp new file mode 100644 index 0000000..1cecf01 --- /dev/null +++ b/tests/src/test_mm_kernel.cpp @@ -0,0 +1,137 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "tensor2d.hpp" +#include "tensor2d_helper.hpp" +#include "test_common.h" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using MMKernelTestShape = std::tuple; +using MMKernelTestParamSet = std::tuple< + std::pair, // a, b data type + bool, // b needs transpose + MMKernelTestShape // M, N, K + >; + +class GemmKernelTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + std::pair types; + bool is_transpose; + MMKernelTestShape shape; + int M, N, K; + std::tie(types, is_transpose, shape) = obj.param; + std::tie(M, N, K) = shape; + + std::ostringstream result; + result << "A_" << dtype_to_str(types.first) << "_B_" << dtype_to_str(types.second) << "_" + << (is_transpose ? "transpose_" : "") + << "M_" << M << "_N_" << N << "_K_" << K; + return result.str(); + } + +protected: + virtual void SetUp() { + initXTILE(); + + MMKernelTestShape shape; + std::tie(_types, _is_transpose, shape) = GetParam(); + std::tie(_M, _N, _K) = shape; + }; + + template + void test() { + if (_N == 1 && (_is_transpose || _types.first == dnnl_u8)) { + GTEST_SKIP() << "gemv does not support transpose or u8s8."; + } + mm_kernel* mm; + mm_create_param param = { + _types.first, _types.second, + _N == 1, _is_transpose + }; + ASSERT_TRUE(mm_kernel_create(&mm, ¶m)); + auto gemm = std::shared_ptr(mm, [](mm_kernel* p) { mm_kernel_destroy(p); }); + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + + fill_rnd(A); + fill_rnd(B); + tensor2D BT = B.Tr(); + TB* ptr_B; + size_t ldb; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + } else { + ptr_B = B.data; + ldb = B.stride; + } + mm_kernel_execute(gemm.get(), A.data, ptr_B, C.data, A.stride, ldb, + C.stride, _M, _N, _K); + C_Ref = 0; + matmul(A, B, C_Ref); + ASSERT_TRUE(C_Ref == C); + } + + int _M, _N, _K; + std::pair _types; + bool _is_transpose; +}; + +TEST_P(GemmKernelTest, Func) { + if (_types.first == dnnl_u8 && _types.second == dnnl_s8) { + test(); + } else if (_types.first == dnnl_s8 && _types.second == dnnl_s8) { + test(); + } else { + test(); + } +} + +const std::vector> types = { + { dnnl_u8, dnnl_s8 }, + { dnnl_s8, dnnl_s8 }, + { dnnl_bf16, dnnl_bf16 }, +}; + +// M, N, K +const std::vector shapes = { + // normal + {256, 48, 448}, + // k tail + {256, 48, 449}, + // M tail == unroll 8 + {256 + 8, 48, 449}, + // M tail == unroll 8 + 2 + {256 + 10, 48, 449}, + // N tail + {256, 40, 448}, + // all tail + {256 + 9, 47, 449}, + // gemv, K <= 64(32)*6, TODO: 160 will fail + {256, 1, 80}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_GemmKernel, GemmKernelTest, + ::testing::Combine(ValuesIn(types), + Values(true, false), + ValuesIn(shapes)), + GemmKernelTest::getTestCaseName); From 2f8cd883ff065fc20675fbbecaa019d43d700c2d Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 02/54] strict N tail handling --- src/mm_kernel_amx.hpp | 872 ++++++++++++++++++++++++++- src/tensor2d.hpp | 4 +- tests/src/test_mm_kernel.cpp | 4 +- tests/src/test_utility_repack1x2.cpp | 149 +++++ 4 files changed, 998 insertions(+), 31 deletions(-) create mode 100644 tests/src/test_utility_repack1x2.cpp diff --git a/src/mm_kernel_amx.hpp b/src/mm_kernel_amx.hpp index 8f41915..611972b 100644 --- a/src/mm_kernel_amx.hpp +++ b/src/mm_kernel_amx.hpp @@ -180,6 +180,303 @@ namespace functional { _mm512_storeu_epi32(dst + 15*16, rf); } + inline void transpose_epi32_Mx16(void * _dst, const void * src, int stride, int valid_m) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + switch (valid_m) { + case 15: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_setzero(); + break; + case 14: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 13: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 12: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 11: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 10: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 9: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 8: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 7: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 6: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 5: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 4: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_setzero(); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 3: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_setzero(); + r4 = _mm512_setzero(); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 2: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_setzero(); + r3 = _mm512_setzero(); + r4 = _mm512_setzero(); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 1: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_setzero(); + r2 = _mm512_setzero(); + r3 = _mm512_setzero(); + r4 = _mm512_setzero(); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + } + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + // 16xN, N<=16, non-valid part is on the left, filled with zeros inline void transpose_epi32_16xN_right_align(void * _dst, const void * src, int stride, int valid_bytes) { auto * dst = reinterpret_cast(_dst); @@ -223,6 +520,304 @@ namespace functional { _mm512_storeu_epi32(dst + 15*16, rf); } + inline void transpose_epi32_MxN_right_align(void * _dst, const void * src, int stride, int valid_bytes, int valid_m) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + int invalid_bytes = 64 - valid_bytes; + auto * pA = reinterpret_cast(src) - invalid_bytes; + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; + __mmask64 mask = _cvtu64_mask64(mask_value); + switch (valid_m) { + case 15: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_setzero(); + break; + case 14: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 13: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 12: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 11: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 10: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 9: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 8: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 7: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 6: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 5: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 4: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_setzero(); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 3: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_setzero(); + r4 = _mm512_setzero(); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 2: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_setzero(); + r3 = _mm512_setzero(); + r4 = _mm512_setzero(); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + case 1: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_setzero(); + r2 = _mm512_setzero(); + r3 = _mm512_setzero(); + r4 = _mm512_setzero(); + r5 = _mm512_setzero(); + r6 = _mm512_setzero(); + r7 = _mm512_setzero(); + r8 = _mm512_setzero(); + r9 = _mm512_setzero(); + ra = _mm512_setzero(); + rb = _mm512_setzero(); + rc = _mm512_setzero(); + rd = _mm512_setzero(); + re = _mm512_setzero(); + rf = _mm512_setzero(); + break; + } + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + // gelu_erf_minimax_approx_compute_vector_fwd in oneDNN // x*0.5*(1+erf(x/sqrt(2))) = x*0.5*(1 + x*Polynomial(x^2)) inline __m512 gelu_erf_minmax_approx(__m512 & x) { @@ -437,6 +1032,171 @@ namespace functional { } } + inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows, int valid_n) { + #define FROM_B(i) ((1<<4)|(i)) + static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), + 1,5,FROM_B(1),FROM_B(5), + 2,6,FROM_B(2),FROM_B(6), + 3,7,FROM_B(3),FROM_B(7)}; + auto midx = _mm512_loadu_epi64(idx); + __mmask16 mask = _cvtu32_mask16(0xFFFFu); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __mmask32 mask_n = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); + if (src_rows == 64) { + for (int row = 0; row < 16; row++) { + // each element (a? or b?) is 32-bits, two lanes in each ymm register + auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 + auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 + auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 + auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } else { + // less than 64 source lines, + int allzero_dst_rows = (64-src_rows)/4; + int allnonzero_dst_rows = src_rows/4; + // padding zeros at the top + auto rowB0 = _mm512_setzero_si512(); + auto rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + int tails_nz = (src_rows & 3); + if (tails_nz) { + __mmask32 kmask1 = _cvtu32_mask32(0xFFFFFFFF); + auto a256 = _mm256_setzero_si256(); // must be zero + auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 + auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 + auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) + if (tails_nz > 2) { + b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + } + if (tails_nz > 1) { + c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + } + d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non zeros + for (int i = 0; i < allnonzero_dst_rows; i++) { + auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows, int valid_n) { + static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; + auto midx = _mm512_loadu_epi64(idx); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __mmask32 mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); + __m512i a,b,rowB0, rowB1; + if (src_rows == 32) { + for (int row = 0; row < 16; row++) { + a = _mm512_maskz_loadu_epi16(mask, src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + b = _mm512_maskz_loadu_epi16(mask, src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } else { + int allzero_dst_rows = (32-src_rows)/2; + int allnonzero_dst_rows = src_rows/2; + + rowB0 = _mm512_setzero_si512(); + rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + if (src_rows & 1) { + a = _mm512_setzero_si512(); + b = _mm512_maskz_loadu_epi16(mask, src); src += stride; + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + auto rowB0 = _mm512_unpacklo_epi16(a, b); + auto rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non-zero rows + for (int i = 0; i < allnonzero_dst_rows; i++) { + a = _mm512_maskz_loadu_epi16(mask, src); + b = _mm512_maskz_loadu_epi16(mask, src + stride); + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + rowB0 = _mm512_unpacklo_epi16(a, b); + rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } + } + // prepare B matrix for C matrix 2x2 blocking (B matrix // will be accessed in 1x2) // given 2x2 blocking scheme, Kx32 blocks are always @@ -756,55 +1516,113 @@ void loop2D_opt_Mtail(int M, int N, int mc, F f) { // template void repackB_1x2(const tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { - int K = Bi.dims[transpose?1:0]; - int N = Bi.dims[transpose?0:1]; + int K = Bi.dims[transpose ? 1 : 0]; + int N = Bi.dims[transpose ? 0 : 1]; // K_padded : round up to multiple of 32/64 int kStep = 64 / sizeof(T); - int K_padded = (K + kStep - 1)/kStep * kStep; + int K_padded = (K + kStep - 1) / kStep * kStep; int Ktails = K % kStep; int Kbody = K - Ktails; // N_padded : round up to multiple of (2*16) - int N_unit = 2*16; - int N_padded = (N + N_unit - 1)/N_unit * N_unit; + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] - Bo.resize(N_padded/N_unit, K_padded * N_unit, false, is_const); + Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); + int n = 0; + int n_tail = N % N_unit; if (transpose) { - for(int n = 0; n < N; n += N_unit) { + for(; n < N - n_tail; n += N_unit) { // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially - auto * dst = reinterpret_cast(&Bo(n/N_unit, 0)); - auto * src0 = reinterpret_cast(&Bi(n, 0)); + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + auto* src0 = reinterpret_cast(&Bi(n, 0)); int k; for(k = 0; k < Kbody; k += kStep) { // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) - functional::transpose_epi32_16x16(dst, src0 + 0*16*Bi.stride + k*sizeof(T), Bi.stride); + functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); dst += 1024; - functional::transpose_epi32_16x16(dst, src0 + 1*16*Bi.stride + k*sizeof(T), Bi.stride); + functional::transpose_epi32_16x16(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride); dst += 1024; } if (Ktails) { // Ktails part is loaded into A tile right-aligned, so B tile must also load // Ktails part to bottom-aligned, and fill upper padding with zero - functional::transpose_epi32_16xN_right_align(dst, src0 + 0*16*Bi.stride + k*sizeof(T), Bi.stride, (K-k)*sizeof(T)); + functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); dst += 1024; - functional::transpose_epi32_16xN_right_align(dst, src0 + 1*16*Bi.stride + k*sizeof(T), Bi.stride, (K-k)*sizeof(T)); + functional::transpose_epi32_16xN_right_align(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); dst += 1024; } } + // n_tail: [16, 32) + if (N - n >= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + auto* src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T)); + } + n += 16; + } + // n_tail: (0, 16) + if (N - n > 0) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 1024 : 0); + auto* src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_Mx16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, N - n); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_MxN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T), N - n); + } + n = N; + } + // second B tile is untouched, need to set to zero + if (n_tail > 0 && n_tail <= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for (int k = 0; k < K_padded; k += kStep) { + memset(dst + 1024, 0, 1024); + dst += 1024 * 2; + } + } } else { // pack & layout sequentially - for(int n = 0; n < N; n += N_unit) { - auto * dst = reinterpret_cast(&Bo(n/N_unit, 0)); - for(int k = 0; k < K; k+=kStep) { + int n = 0; + int n_tail = N % N_unit; + for(; n < N - n_tail; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + // int8: B0 B1 64x(16+16) => repack as two 16x16x4 + int src_rows = std::min(K - k, kStep); + functional::kpack_tile_B0B1(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows); + dst += 2048; + } + } + // n_tail: (0, 32) + if (N - n > 0) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 // int8: B0 B1 64x(16+16) => repack as two 16x16x4 int src_rows = std::min(K - k, kStep); - functional::kpack_tile_B0B1(dst, dst + (1024), &Bi(k, n), Bi.stride, src_rows); + functional::kpack_tile_B0B1_ntail(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows, N - n); dst += 2048; } + n += 16; } } } @@ -930,14 +1748,14 @@ struct MatmulVector { _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 5); } - if (tmmN == 6) { + if (tmmN == 5) { _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 6); } - if (tmmN == 7) { + if (tmmN == 6) { _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); @@ -1095,7 +1913,7 @@ struct Matmul { if (is_u8u8) _tile_dpbuud(dst, a, b); template - void kernel_slimB(int M, int N, int K, + void kernel_slimB(int M, int N, int K, int n0, tensor2D & A, void * B, tensor2D & buffC, @@ -1155,7 +1973,7 @@ struct Matmul { _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 7); } _tile_stored(0, pC0, buffC.stride); - (ppkernel)(buffC, m, 0, 16, N); + (ppkernel)(buffC, m, n0, 16, N); pA0 += 16*A.stride; } } @@ -1196,12 +2014,12 @@ struct Matmul { auto * pB0 = reinterpret_cast(&internalB[0]); tileconfig_t tfg(1, 0, 8, 16, 64); switch((K + kStep - 1)/kStep) { - case 1: kernel_slimB<1>(M, N, K, matA, pB0, buffC, ppkernel); break; - case 2: kernel_slimB<2>(M, N, K, matA, pB0, buffC, ppkernel); break; - case 3: kernel_slimB<3>(M, N, K, matA, pB0, buffC, ppkernel); break; - case 4: kernel_slimB<4>(M, N, K, matA, pB0, buffC, ppkernel); break; - case 5: kernel_slimB<5>(M, N, K, matA, pB0, buffC, ppkernel); break; - case 6: kernel_slimB<6>(M, N, K, matA, pB0, buffC, ppkernel); break; + case 1: kernel_slimB<1>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 2: kernel_slimB<2>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 3: kernel_slimB<3>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 4: kernel_slimB<4>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 5: kernel_slimB<5>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 6: kernel_slimB<6>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; default: assert(false); // impossible since (K <= 6*kStep) } diff --git a/src/tensor2d.hpp b/src/tensor2d.hpp index 8ca325f..e24c8c7 100644 --- a/src/tensor2d.hpp +++ b/src/tensor2d.hpp @@ -51,8 +51,8 @@ struct tensor2D { padded_dim1 = stride / sizeof(T); } - tensor2D Tr() { - tensor2D ret(dims[1], dims[0]); + tensor2D Tr(bool _force_compact = false) { + tensor2D ret(dims[1], dims[0], _force_compact); for(int c0=0; c0 < dims[0]; ++c0) { for(int c1=0; c1 < dims[1]; ++c1) { ret(c1, c0) = (*this)(c0, c1); diff --git a/tests/src/test_mm_kernel.cpp b/tests/src/test_mm_kernel.cpp index 1cecf01..91863c7 100644 --- a/tests/src/test_mm_kernel.cpp +++ b/tests/src/test_mm_kernel.cpp @@ -126,8 +126,8 @@ const std::vector shapes = { {256, 40, 448}, // all tail {256 + 9, 47, 449}, - // gemv, K <= 64(32)*6, TODO: 160 will fail - {256, 1, 80}, + // gemv, K <= 64(32)*6 + {256, 1, 160}, }; INSTANTIATE_TEST_SUITE_P(smoke_GemmKernel, GemmKernelTest, diff --git a/tests/src/test_utility_repack1x2.cpp b/tests/src/test_utility_repack1x2.cpp new file mode 100644 index 0000000..0e063d9 --- /dev/null +++ b/tests/src/test_utility_repack1x2.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "tensor2d.hpp" +#include "tensor2d_helper.hpp" +#include "mm_kernel_amx.hpp" +#include "test_common.h" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using RepackTestParamSet = std::tuple< + data_type_t // data type + >; + +class RepackTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + int K, N; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(tensor2D& in_t, tensor2D& out) { + int K = in_t.dims[1]; + int N = in_t.dims[0]; + int kStep = 64 / sizeof(T); + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + out.resize(N_padded / N_unit, K_padded * N_unit, true, true); + + tensor2D in_padded; + in_padded.resize(N_padded, K_padded, true, true); + for (int i = 0; i < N; i++) { + // full range + memcpy(&in_padded(i, 0), &in_t(i, 0), (K - Ktails) * sizeof(T)); + // k tail needs to right aligned + memcpy(&in_padded(i, K_padded - Ktails), &in_t(i, K - Ktails), Ktails * sizeof(T)); + } + for (int n = 0; n < N_padded; n += N_unit) { + for (int k = 0; k < K_padded; k += kStep) { + // bf16 as example: + // [N, K], 2*[16(N), 32(K)] => + // [K, N], 2*[16(K), 16(N)*2(K)] + for (int m = 0; m < 2; m++) { + auto* src = reinterpret_cast(&in_padded(n + m * 16, k)); + auto* dst = reinterpret_cast(&out(n / N_unit, k * N_unit)); + dst += m * (1024 / sizeof(int)); + for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + dst[i * 16 + j] = src[j * in_padded.stride / sizeof(int) + i]; + } + } + } + } + } + } + + template + void test() { + auto testone = [] (int k, int n, std::string prefix) { + tensor2D A(k, n, true); + + fill_rnd(A); + tensor2D AT = A.Tr(true); + tensor2D A_out, AT_out, A_ref; + amx_kernel::repackB_1x2(A, false, A_out, false); + amx_kernel::repackB_1x2(AT, true, AT_out, false); + gen_ref(AT, A_ref); + ASSERT_TRUE(A_out == A_ref) << " " << prefix << " without transform K: " << k << " N: " << n; + ASSERT_TRUE(AT_out == A_ref) << " " << prefix << " with transform K: " << k << " N: " << n; + }; + // n tail: transpose case needs from 1 to 31, without transpose needs one + int k = 32; + int n; + for (n = 1; n < 32; n++) { + testone(k, n, "ntail"); + } + for (n = 32 + 1; n < 32 + 32; n++) { + testone(k, n, "ntail"); + } + // k tail: transpose case needs 1, without transpose needs from 1 to 31 + n = 32; + for (k = 1; k < 32; k++) { + testone(k, n, "ktail"); + } + for (k = 32 + 1; k < 32 + 32; k++) { + testone(k, n, "ktail"); + } + // k, n normal + testone(32, 32, "normal"); + testone(64, 128, "normal"); + // k, n tail + testone(64, 128 + 5, "ntail"); + testone(64 + 3, 128, "ktail"); + testone(64 + 3, 128 + 5, "alltail"); + testone(64, 128 + 16 + 5, "ntail"); + testone(64 + 16 + 3, 128, "ktail"); + testone(64 + 16 + 3, 128 + 16 + 5, "alltail"); + } + + data_type_t _types; +}; + +TEST_P(RepackTest, Func) { + if (_types == dnnl_s8) { + test(); + } else { + test(); + } +} + +const std::vector types = { + dnnl_s8, dnnl_bf16 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Repack, RepackTest, + ::testing::Combine(ValuesIn(types)), + RepackTest::getTestCaseName); From 2cacaab33d05b4d8ea86f122f8f22f550da4014d Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 03/54] apply review comments --- CMakeLists.txt | 8 ++++++-- src/CMakeLists.txt | 6 +++++- tests/CMakeLists.txt | 6 +++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cd1f2e..8185604 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,8 @@ -cmake_minimum_required(VERSION 3.18) +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) project(root) @@ -9,7 +13,7 @@ message(INFO "--------------------------------") message(STATUS "Build with tests: ${BUILD_TESTS}") message(INFO "--------------------------------") -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 11) if(MSVC) # TODO message(FATAL_ERROR "Not support yet. Use intel compiler 2023.0+.") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5d118ab..2ece614 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,8 @@ -cmake_minimum_required(VERSION 3.18) +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) project(llmdnn) file(GLOB_RECURSE SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09ce3a5..81eb524 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,4 +1,8 @@ -cmake_minimum_required(VERSION 3.18) +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) project(llmdnn_tests) include(FetchContent) From a54403187c8eac781323a4a1810dd548b5a078fd Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 04/54] add mha support --- CMakeLists.txt | 1 + src/mha_gpt.cpp | 358 +++++++++++++++++++ src/mha_gpt.hpp | 85 +++++ src/simple_parallel.hpp | 182 ++++++++++ src/softmax_kernel_avx512.hpp | 218 +++++++++++ src/transpose_kernel_avx512.hpp | 111 ++++++ src/utility.hpp | 56 +++ src/utility_avx512.hpp | 20 ++ tests/CMakeLists.txt | 8 +- tests/script/README.md | 19 + tests/script/ext/CMakeLists.txt | 29 ++ tests/script/ext/mha_gpt.cpp | 100 ++++++ tests/script/ext/module.cpp | 18 + tests/script/ext/module.hpp | 9 + tests/script/ext/setup.py | 35 ++ tests/script/requirements.txt | 2 + tests/script/test_mha_gpt.py | 151 ++++++++ tests/src/test_common.cpp | 21 +- tests/src/{test_common.h => test_common.hpp} | 0 tests/src/test_fc_kernel.cpp | 2 +- tests/src/test_mm_kernel.cpp | 2 +- tests/src/test_softmax_kernel_avx512.cpp | 129 +++++++ tests/src/test_transpose_kernel_avx512.cpp | 131 +++++++ tests/src/test_utility.cpp | 37 ++ tests/src/test_utility_repack1x2.cpp | 3 +- 25 files changed, 1721 insertions(+), 6 deletions(-) create mode 100644 src/mha_gpt.cpp create mode 100644 src/mha_gpt.hpp create mode 100644 src/simple_parallel.hpp create mode 100644 src/softmax_kernel_avx512.hpp create mode 100644 src/transpose_kernel_avx512.hpp create mode 100644 src/utility.hpp create mode 100644 tests/script/README.md create mode 100644 tests/script/ext/CMakeLists.txt create mode 100644 tests/script/ext/mha_gpt.cpp create mode 100644 tests/script/ext/module.cpp create mode 100644 tests/script/ext/module.hpp create mode 100644 tests/script/ext/setup.py create mode 100644 tests/script/requirements.txt create mode 100644 tests/script/test_mha_gpt.py rename tests/src/{test_common.h => test_common.hpp} (100%) create mode 100644 tests/src/test_softmax_kernel_avx512.cpp create mode 100644 tests/src/test_transpose_kernel_avx512.cpp create mode 100644 tests/src/test_utility.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 8185604..0061e15 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,7 @@ project(root) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) option(BUILD_TESTS "Build with tests" ON) +option(BUILD_PYTHON_TESTS "Build with tests need python extension" OFF) message(INFO "--------------------------------") message(STATUS "Build with tests: ${BUILD_TESTS}") diff --git a/src/mha_gpt.cpp b/src/mha_gpt.cpp new file mode 100644 index 0000000..916c2f8 --- /dev/null +++ b/src/mha_gpt.cpp @@ -0,0 +1,358 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "simple_parallel.hpp" +#include "utility.hpp" +#include "utility_avx512.hpp" +#include "mm_kernel_amx.hpp" +#include "softmax_kernel_avx512.hpp" +#include "transpose_kernel_avx512.hpp" +#include "mha_gpt.hpp" + +using namespace ov::cpu; + +namespace llmdnn { + +struct mha_gpt::Impl { + void create(const create_param& param); + void exec(const exec_param& param); + + create_param _create_param; + + void mha_bf16(const exec_param ¶m); + void mha_i8(const exec_param ¶m); + + size_t bufferMatMul0OutSize; + size_t bufferMatMul1OutSize; + + std::shared_ptr bufferMatMul0Out; + std::shared_ptr bufferMatMul1Out; + + std::vector>> gemAvB_BF16xBF16; + std::vector>> qKtrGemm_BF16xBF16; + std::vector>> qKVGemm_BF16xBF16; + + std::vector>> qKtrGemm_i8xi8; + std::vector>> qKVGemm_u8xi8; + std::vector>> gemAvB_i8xi8; +}; + +void mha_gpt::Impl::create(const create_param& param) { + _create_param = param; + + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, maxSeqLen(valid: key_seq_len), head_size] + // v: [batch, num_heads, maxSeqLen(valid: value_seq_len), head_size] + // attention_mask: [batch, 1, 1, maxSeqLen(valid: key_seq_len)] + // matmul1: [batch, num_heads, query_seq_len, head_size] + // attn_output: [batch, query_seq_len, num_heads * head_size] + size_t numThreads = getTotalThreads(); + if (_create_param.qkv_precision == dnnl_s8) { + qKtrGemm_i8xi8.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + qKtrGemm_i8xi8[i] = std::make_shared>(false, true); + } + qKVGemm_u8xi8.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + qKVGemm_u8xi8[i] = std::make_shared>(false, false); + } + gemAvB_i8xi8.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + gemAvB_i8xi8[i] = std::make_shared>(); + } + } else { + gemAvB_BF16xBF16.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + gemAvB_BF16xBF16[i] = std::make_shared>(); + } + qKtrGemm_BF16xBF16.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + qKtrGemm_BF16xBF16[i] = std::make_shared>(false, true); + } + qKVGemm_BF16xBF16.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + qKVGemm_BF16xBF16[i] = std::make_shared>(false, false); + } + } + + bufferMatMul0OutSize = _create_param.max_seq_len * rndup(_create_param.max_seq_len * sizeof(float), 64); + bufferMatMul1OutSize = _create_param.max_seq_len * _create_param.head_size_aligned * sizeof(float); + + bufferMatMul0Out = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul0OutSize)), + [](void * p) { ::free(p); }); + bufferMatMul1Out = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul1OutSize)), + [](void * p) { ::free(p); }); +} + +void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { + uint8_t* pQIn0 = param.q; + auto& pKIn0 = param.k; + auto& attn_masks = param.attention_mask; + auto& pVIn0 = param.v; + uint8_t* pout = param.attn_output; + + auto outPrcSize = get_precision_size(_create_param.qkv_precision); + auto& gemAvB_ops = gemAvB_BF16xBF16; + auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; + auto& qKVGemm_ops = qKVGemm_BF16xBF16; + bool is_vector = param.query_seq_len == 1; + size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; + size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; + size_t head_stride_in_attn = _create_param.head_size; + size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; + + if (is_vector) { + parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { + auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q) * get_precision_size(_create_param.qkv_precision); + auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + + auto pAddIn1_aux = attn_masks[i0]; + + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + // N: key_seq_len, K: head_size + // q[1, K] * transpose(k[N, K]) ==> + // k[N, K] * transpose(q[1, K]) ==> + // k[N, K] * q[K, 1] + (*gemAvB_ops[threadNum])(matK, reinterpret_cast(pQIn0_aux), reinterpret_cast(bufferMatMul0Out_local)); + + float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); + mul_add_f32(pMatMul0Out, pMatMul0Out, _create_param.normal_factor, pAddIn1_aux, param.key_seq_len); + softmax(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr, nullptr, nullptr); + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; + tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp(matQKV); + (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); + memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, + _create_param.head_size, _create_param.head_size_aligned, _create_param.num_heads * _create_param.head_size, nullptr); + }); + } else { + auto numThreads = getTotalThreads(); + int seq_cout_all = rndup(param.query_seq_len, 32) / 32; + int work_amount = param.batch * _create_param.num_heads * seq_cout_all; + parallel_for(numThreads, [&](int threadNum) { + int i0; + int i1; + int seq; + int start {0}, end {0}; + splitter(work_amount, static_cast(numThreads), threadNum, start, end); + + parallel_it_init(start, i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + uint8_t* prev_k = nullptr; + uint8_t* prev_v = nullptr; + for (int iwork = start; iwork < end; ++iwork) { + int seq_start = seq * 32; + int seq_end = std::min(static_cast(seq_start) + 32, param.query_seq_len); + int seq_cout = seq_end - seq_start; + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, key_seq_len, head_size] + // v: [batch, num_heads, value_seq_len, head_size] + auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q + seq_start * _create_param.head_size_aligned) * get_precision_size(_create_param.qkv_precision); + auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + + auto pAddIn1_aux = attn_masks[i0]; + + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + + tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); + amx_kernel::PP::BiasGeluStore pp(matQK); + (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); + prev_k = pKIn0_aux; + + auto pMatMul0Out = bufferMatMul0Out_local; + // loop along K dimension + size_t valid_softmax_items = seq_start + 1; + for (size_t m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + mul_add_f32(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); + softmax(dst, src, valid_softmax_items, nullptr, nullptr, nullptr); + // attn_scores = torch.where(causal_mask, attn_scores, mask_value) + if (param.key_seq_len > valid_softmax_items) { + auto *invalidPtr = dst + valid_softmax_items; + memset(invalidPtr, 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); + valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + } + } + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; + tensor2D matQKBF16(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp2(matQKV); + (*qKVGemm_ops[threadNum])(matQKBF16, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); + prev_v = pVIn0_aux; + memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, + _create_param.head_size, _create_param.head_size_aligned, _create_param.num_heads * _create_param.head_size, nullptr); + parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + } + }); + } +} + +void mha_gpt::Impl::mha_i8(const exec_param ¶m) { + uint8_t* pQIn0 = param.q; + auto& pKIn0 = param.k; + auto& attn_masks = param.attention_mask; + auto& pVIn0 = param.v; + uint8_t* pout = param.attn_output; + + auto outPrcSize = get_precision_size(_create_param.dst_precision); + auto& gemAvB_ops = gemAvB_i8xi8; + auto& qKtrGemm_ops = qKtrGemm_i8xi8; + auto& qKVGemm_ops = qKVGemm_u8xi8; + bool is_vector = param.query_seq_len == 1; + // dequant param + auto mul_scales = _create_param.normal_factor * param.q_dequant * param.k_dequant; + // prepare for per channel + auto qkv_quant = param.qkv_quant; + std::vector qk_quant_vec(_create_param.head_size, param.qk_quant); + for (size_t i = 0; i < param.qkv_quant.size(); i++) { + qkv_quant[i] *= param.v_dequant / param.qk_quant; + } + size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; + size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; + size_t head_stride_in_attn = _create_param.head_size; + size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; + + if (is_vector) { + parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { + auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q) * get_precision_size(_create_param.qkv_precision); + auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + + auto pAddIn1_aux = attn_masks[i0]; + + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + // N: key_seq_len, K: head_size + // q[1, K] * transpose(k[N, K]) ==> + // k[N, K] * transpose(q[1, K]) ==> + // k[N, K] * q[K, 1] + (*gemAvB_ops[threadNum])(matK, reinterpret_cast(pQIn0_aux), reinterpret_cast(bufferMatMul0Out_local)); + cvt_i32_f32(reinterpret_cast(bufferMatMul0Out_local), reinterpret_cast(bufferMatMul0Out_local), param.key_seq_len); + + float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); + mul_add_f32(pMatMul0Out, pMatMul0Out, mul_scales, pAddIn1_aux, param.key_seq_len); + softmax(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr, nullptr, qk_quant_vec.data()); + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; + tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp(matQKV); + (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); + memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, + _create_param.head_size, _create_param.head_size_aligned, _create_param.num_heads * _create_param.head_size, qkv_quant.data()); + }); + } else { + auto numThreads = getTotalThreads(); + int seq_cout_all = rndup(param.query_seq_len, 32) / 32; + int work_amount = param.batch * _create_param.num_heads * seq_cout_all; + parallel_for(numThreads, [&](int threadNum) { + int i0; + int i1; + int seq; + int start {0}, end {0}; + splitter(work_amount, static_cast(numThreads), threadNum, start, end); + + parallel_it_init(start, i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + uint8_t* prev_k = nullptr; + uint8_t* prev_v = nullptr; + for (int iwork = start; iwork < end; ++iwork) { + int seq_start = seq * 32; + int seq_end = std::min(static_cast(seq_start) + 32, param.query_seq_len); + int seq_cout = seq_end - seq_start; + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, key_seq_len, head_size] + // v: [batch, num_heads, value_seq_len, head_size] + auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q + seq_start * _create_param.head_size_aligned); + auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv; + auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv; + + auto pAddIn1_aux = attn_masks[i0]; + + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + + tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); + amx_kernel::PP::BiasGeluStore pp(matQK); + (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); + prev_k = pKIn0_aux; + + auto pMatMul0Out = bufferMatMul0Out_local; + // loop along K dimension + size_t valid_softmax_items = seq_start + 1; + for (size_t m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); + mul_add_f32(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); + softmax(dst, src, valid_softmax_items, nullptr, nullptr, qk_quant_vec.data()); + // attn_scores = torch.where(causal_mask, attn_scores, mask_value) + if (param.key_seq_len > valid_softmax_items) { + auto *invalidPtr = dst + valid_softmax_items; + memset(invalidPtr, 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); + valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + } + } + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; + tensor2D matQKI8(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp2(matQKV); + (*qKVGemm_ops[threadNum])(matQKI8, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); + prev_v = pVIn0_aux; + // matmul1: [batch, num_heads, query_seq_len, head_size] + // attn_output: [batch, query_seq_len, num_heads * head_size] + memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, + _create_param.head_size, _create_param.head_size_aligned, _create_param.num_heads * _create_param.head_size, qkv_quant.data()); + parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + } + }); + } +} + +void mha_gpt::Impl::exec(const exec_param& param) { + if (_create_param.qkv_precision == dnnl_f32) { + assert(false); + } else if (_create_param.qkv_precision == dnnl_bf16) { + mha_bf16(param); + } else if (_create_param.qkv_precision == dnnl_s8) { + mha_i8(param); + } else { + assert(false && "doesn't support provided input precisions"); + } +} + +// interface +mha_gpt::mha_gpt(): _impl(std::make_shared()) { +} + +void mha_gpt::create(const create_param& param) { + _impl->create(param); +} + +void mha_gpt::exec(const exec_param& param) { + _impl->exec(param); +} + +} \ No newline at end of file diff --git a/src/mha_gpt.hpp b/src/mha_gpt.hpp new file mode 100644 index 0000000..0605f87 --- /dev/null +++ b/src/mha_gpt.hpp @@ -0,0 +1,85 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" + +namespace llmdnn { + +// pattern is: +// query:[batch, num_heads, query_seq_len, head_size] key:[batch, num_heads, key_seq_len, head_size] +// \ | +// \ Transpose0: [batch, num_heads, head_size, key_seq_len] +// \ / +// \ / +// \ / +// MatMul0: [batch, num_heads, query_seq_len, key_seq_len] +// | +// | norm_factor(const): [1] +// | / +// Multiply: [batch, num_heads, query_seq_len, key_seq_len] +// | +// | causal_mask: [1, 1, query_seq_len, key_seq_len] +// | / +// Select(only for 1x300): [batch, num_heads, query_seq_len, key_seq_len] +// | +// | attention_mask:[batch, 1, 1, key_seq_len] +// | / +// Add: [batch, num_heads, query_seq_len, key_seq_len] +// | +// SoftMax: [batch, num_heads, query_seq_len, key_seq_len] +// | +// \ value:[batch, num_heads, key_seq_len, head_size] +// \ / +// MatMul1: [batch, num_heads, query_seq_len, head_size] +// | +// Transpose1(only for 1x300): [batch, query_seq_len, num_heads * head_size] +class mha_gpt { +public: + struct create_param { + size_t num_heads; + size_t head_size; + size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv + size_t max_seq_len; // max seq length for computing the size of matmul tmp result + float normal_factor; + data_type_t qkv_precision; + data_type_t dst_precision; + }; + struct exec_param { + size_t batch; + size_t query_seq_len; + size_t key_seq_len; + uint8_t* q; // q buffer, compact, shape: [batch, num_heads, query_seq_len, head_size] + uint8_t** k; // k buffer, k[N] stands different batch which may be discreted + // k[0] shape: [batch, num_heads, key_seq_len, head_size] + uint8_t** v; // v buffer, v[N] stands different batch which may be discreted + // v[0] shape: [batch, num_heads, value_seq_len, head_size] + float** attention_mask; // attention mask, attention_mask[N] is the batch + // attention_mask[0] shape: [1, max_seq_len] + uint8_t* attn_output; // output, compact, shape: [batch, query_seq_len, num_heads * head_size] + size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer + float q_dequant; + float k_dequant; + float v_dequant; + float qk_quant; + std::vector qkv_quant; // per channel + // float* qk_normal_dq; // per channel, each item = normal_factor * q_dequant * k_dequant, used for softmax input + // float* qk_quant; // per channel, used for softmax output + // float* qkv_dq_q; // per channel, each item = 1 / qk_quant * v_dequant * qkv_quant, used for matmul2 output + }; + + mha_gpt(); + void create(const create_param& param); + void exec(const exec_param& param); + +private: + struct Impl; + std::shared_ptr _impl; +}; + +} diff --git a/src/simple_parallel.hpp b/src/simple_parallel.hpp new file mode 100644 index 0000000..b0548f4 --- /dev/null +++ b/src/simple_parallel.hpp @@ -0,0 +1,182 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace ov { +namespace cpu { + +size_t getTotalThreads(); +void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function& fn); + +// copy from openvino/core/parallel.hpp +template +inline T parallel_it_init(T start) { + return start; +} +template +inline T parallel_it_init(T start, Q& x, const R& X, Args&&... tuple) { + start = parallel_it_init(start, static_cast(tuple)...); + x = start % X; + return start / X; +} + +inline bool parallel_it_step() { + return true; +} +template +inline bool parallel_it_step(Q& x, const R& X, Args&&... tuple) { + if (parallel_it_step(static_cast(tuple)...)) { + if (++x - X == 0) { + x = 0; + return true; + } + } + return false; +} + +template +inline void splitter(const T& n, const Q& team, const Q& tid, T& n_start, T& n_end) { + if (team <= 1 || n == 0) { + n_start = 0; + n_end = n; + } else { + T n1 = (n + (T)team - 1) / (T)team; + T n2 = n1 - 1; + T T1 = n - n2 * (T)team; + n_end = (T)tid < T1 ? n1 : n2; + n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; + } + + n_end += n_start; +} + +namespace helpers { +template +struct NumOfLambdaArgs : public NumOfLambdaArgs {}; + +template +struct NumOfLambdaArgs { + constexpr static int value = sizeof...(Args); +}; + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(g_id, iwork, arg...); +} + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(g_id, arg...); +} + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(arg...); +} +} // namespace helpers + +template +void for_1d(const int& ithr, const int& nthr, const T0& D0, const F& func) { + T0 d0{0}, end{0}; + splitter(D0, nthr, ithr, d0, end); + for (; d0 < end; ++d0) + helpers::call_with_args(func, ithr, d0, d0); +} + +template +void parallel_for(const T0& D0, const F& func) { + auto work_amount = static_cast(D0); + int nthr = static_cast(getTotalThreads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_1d(0, 1, D0, func); + } else { + TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + for_1d(static_cast(ithr), nthr, D0, func); + }); + } +} + +template +void for_2d(const int& ithr, const int& nthr, const T0& D0, const T1& D1, const F& func) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) + return; + size_t start{0}, end{0}; + splitter(work_amount, nthr, ithr, start, end); + + T0 d0{0}; + T1 d1{0}; + parallel_it_init(start, d0, D0, d1, D1); + for (size_t iwork = start; iwork < end; ++iwork) { + helpers::call_with_args(func, ithr, iwork, d0, d1); + parallel_it_step(d0, D0, d1, D1); + } +} + +template +void parallel_for2d(const T0& D0, const T1& D1, const F& func) { + auto work_amount = static_cast(D0 * D1); + int nthr = static_cast(getTotalThreads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_2d(0, 1, D0, D1, func); + } else { + TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + for_2d(static_cast(ithr), nthr, D0, D1, func); + }); + } +} + +template +void for_3d(const int& ithr, const int& nthr, const T0& D0, const T1& D1, const T2& D2, const F& func) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) + return; + size_t start{0}, end{0}; + splitter(work_amount, nthr, ithr, start, end); + + T0 d0{0}; + T1 d1{0}; + T2 d2{0}; + parallel_it_init(start, d0, D0, d1, D1, d2, D2); + for (size_t iwork = start; iwork < end; ++iwork) { + helpers::call_with_args(func, ithr, iwork, d0, d1, d2); + parallel_it_step(d0, D0, d1, D1, d2, D2); + } +} + +template +void parallel_for3d(const T0& D0, const T1& D1, const T2& D2, const F& func) { + auto work_amount = static_cast(D0 * D1 * D2); + int nthr = static_cast(getTotalThreads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_3d(0, 1, D0, D1, D2, func); + } else { + TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + for_3d(static_cast(ithr), nthr, D0, D1, D2, func); + }); + } +} + +}; // namespace cpu +}; // namespace ov \ No newline at end of file diff --git a/src/softmax_kernel_avx512.hpp b/src/softmax_kernel_avx512.hpp new file mode 100644 index 0000000..f9638d2 --- /dev/null +++ b/src/softmax_kernel_avx512.hpp @@ -0,0 +1,218 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "bf16.hpp" +#include "llm_types.hpp" +#include "utility_avx512.hpp" + +namespace llmdnn { + inline void exp_ps(__m512 & src) { + static __m512 exp_ln_flt_min_f = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); // log(FLT_MIN) + static __m512 exp_ln_flt_max_f = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); // log(FLT_MAX) + static __m512 exp_log2ef = _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) + static __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000)); // 0.5f + static __m512 ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) + static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f + static __m512i exponent_bias = _mm512_set1_epi32(0x0000007f); // 127 + static constexpr int n_mantissa_bits = 23; + static __m512 exp_pol1 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f7ffffb)); // p1 = 0.999999701f + static __m512 exp_pol2 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3efffee3)); // p2 = 0.499991506f + static __m512 exp_pol3 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3e2aad40)); // p3 = 0.166676521f + static __m512 exp_pol4 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3d2b9d0d)); // p4 = 0.0418978221f + static __m512 exp_pol5 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3c07cfce)); // p5 = 0.00828929059f + static __m512 two = _mm512_castsi512_ps(_mm512_set1_epi32(0x40000000)); // 2 + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + // get mask of values lower than log(FLT_MIN) to zero them in the output + auto zero_mask = _mm512_cmp_ps_mask(src, exp_ln_flt_min_f, _CMP_LT_OS); + + // clip src + src = _mm512_min_ps(src, exp_ln_flt_max_f); + src = _mm512_max_ps(src, exp_ln_flt_min_f); + + // aux1 : r + auto aux1 = src; + + // calculate exp(x) + // fx = x * log2(e) + 0.5 + src = _mm512_mul_ps(src, exp_log2ef); + src = _mm512_add_ps(src, half); + + // tmp = floorf(fx) + src = _mm512_floor_ps(src); + + // aux1 = x - fx * ln2 + aux1 = _mm512_fnmadd_ps(src, ln2f, aux1); + + // We do not count 2^n here, because n can reach 128 and 2^128 is not + // representable by fp32, so to get around this problem, instead of computing + // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 + // and 2 are numbers representable in fp32. + + // compute 2^(n-1) + src = _mm512_sub_ps(src, one); + auto aux2_i = _mm512_cvtps_epi32(src); + aux2_i = _mm512_add_epi32(aux2_i, exponent_bias); + aux2_i = _mm512_slli_epi32 (aux2_i, n_mantissa_bits); + + // set zeroes at those points which were < log(FLT_MIN) + auto zero = _mm512_setzero_ps(); + auto aux2 = _mm512_mask_blend_ps(zero_mask, _mm512_castsi512_ps(aux2_i), zero); + + // compute polynomial + src = exp_pol5; + src = _mm512_fmadd_ps(src, aux1, exp_pol4); + src = _mm512_fmadd_ps(src, aux1, exp_pol3); + src = _mm512_fmadd_ps(src, aux1, exp_pol2); + src = _mm512_fmadd_ps(src, aux1, exp_pol1); + src = _mm512_fmadd_ps(src, aux1, one); + + // y = y * 2^n + src = _mm512_mul_ps(src, aux2); + src = _mm512_mul_ps(src, two); + } + + template + void softmax(D* dst, float* src, int N, float* s_max=nullptr, float* s_sum=nullptr, float* quant=nullptr) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + "softmax only support output data types ov::bfloat16/uint8_t/int8_t/float"); + + static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f + auto tail = N % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + + // get max + auto x_max = _mm512_set1_ps(std::numeric_limits::lowest()); + int i; + for (i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x_max = _mm512_max_ps(x_max, x); + } + // tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x_max = _mm512_mask_max_ps(x_max, x_mask, x_max, x); + } + auto max = _mm512_reduce_max_ps(x_max); + if (s_max) *s_max = max; + x_max = _mm512_set1_ps(max); + + // softmax + auto sum_exp = _mm512_setzero_ps(); + for(i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_sub_ps(x, x_max); + exp_ps(x); // exp(x-x_max) + sum_exp = _mm512_add_ps(sum_exp, x); // sum(exp(x-x_max)) + _mm512_storeu_ps(src + i, x); // save exp(x-x_max) + } + + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_sub_ps(x, x_max); + exp_ps(x); + x = _mm512_mask_blend_ps(x_mask, _mm512_setzero_ps(), x); + sum_exp = _mm512_add_ps(sum_exp, x); + _mm512_mask_storeu_ps(src + i, x_mask, x); + } + + auto sum = _mm512_reduce_add_ps(sum_exp); + if (s_sum) *s_sum = sum; + sum_exp = _mm512_set1_ps(sum); + auto reciprocal_sum_exp = _mm512_div_ps(one, sum_exp); // 1/sum_exp + + // divide + if (std::is_same::value) { + for(i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + _mm512_storeu_ps(dst + i, x); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + _mm512_mask_storeu_ps(dst + i, x_mask, x); + } + } + if (std::is_same::value) { + for(i = 0; i < N / 32 * 32; i += 32) { + auto x0 = _mm512_loadu_ps(src + i); + auto x1 = _mm512_loadu_ps(src + i + 16); + x0 = _mm512_mul_ps(x0, reciprocal_sum_exp); + x1 = _mm512_mul_ps(x1, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x1, x0); + _mm512_storeu_epi32(dst + i, (__m512i)out); + } + if (i < N - tail) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(dst + i), _mm512_extracti64x4_epi64(out, 0)); + i += 16; + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_mask_storeu_epi16(dst + i, x_mask, _mm512_extracti64x4_epi64(out, 0)); + } + } + if (std::is_same::value) { + for(i = 0; i < N - tail; i += 16) { + auto q = _mm512_loadu_ps(quant + i); + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(dst + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(dst + i, x_mask, x_i); + } + } + if (std::is_same::value) { + auto zero = _mm512_setzero_epi32(); + for(i = 0; i < N - tail; i += 16) { + auto q = _mm512_loadu_ps(quant + i); + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(dst + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(dst + i, x_mask, x_i); + } + } + } +} \ No newline at end of file diff --git a/src/transpose_kernel_avx512.hpp b/src/transpose_kernel_avx512.hpp new file mode 100644 index 0000000..3965d80 --- /dev/null +++ b/src/transpose_kernel_avx512.hpp @@ -0,0 +1,111 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "bf16.hpp" +#include "llm_types.hpp" +#include "utility_avx512.hpp" + +namespace llmdnn { + template + void memcpy2d_stride(D* dst, S* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr); + + template + void memcpy2d_stride(D* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + "softmax only support output data types ov::bfloat16/uint8_t/int8_t/float"); + + auto tail = width % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + + for (size_t j = 0; j < height; j++) { + int i; + if (std::is_same::value) { + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + _mm512_storeu_ps(reinterpret_cast(dst) + i, x); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + _mm512_mask_storeu_ps(reinterpret_cast(dst) + i, x_mask, x); + } + } + + if (std::is_same::value) { + for(i = 0; i < width / 32 * 32; i += 32) { + auto x0 = _mm512_loadu_ps(src + i); + auto x1 = _mm512_loadu_ps(src + i + 16); + auto out = _mm512_cvtne2ps_pbh(x1, x0); + _mm512_storeu_epi32(reinterpret_cast(dst) + i, (__m512i)out); + } + if (i < width - tail) { + auto x = _mm512_loadu_ps(src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(reinterpret_cast(dst) + i), + _mm512_extracti64x4_epi64(out, 0)); + i += 16; + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_mask_storeu_epi16(reinterpret_cast<__m256i*>(reinterpret_cast(dst) + i), + x_mask, _mm512_extracti64x4_epi64(out, 0)); + } + } + + if (std::is_same::value) { + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + auto q = _mm512_loadu_ps(quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(reinterpret_cast(dst) + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(reinterpret_cast(dst) + i, x_mask, x_i); + } + } + + if (std::is_same::value) { + auto zero = _mm512_setzero_epi32(); + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + auto q = _mm512_loadu_ps(quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(reinterpret_cast(dst) + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(reinterpret_cast(dst) + i, x_mask, x_i); + } + } + + src = reinterpret_cast(reinterpret_cast(src) + src_stride); + dst = reinterpret_cast(reinterpret_cast(dst) + dst_stride); + } + } +} \ No newline at end of file diff --git a/src/utility.hpp b/src/utility.hpp new file mode 100644 index 0000000..5eee4fb --- /dev/null +++ b/src/utility.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" + +namespace llmdnn { + +inline size_t get_precision_size(data_type_t type) { + switch(type) { + case dnnl_f16: + case dnnl_bf16: + return 2; + case dnnl_f32: + case dnnl_s32: + return 4; + case dnnl_s8: + case dnnl_u8: + return 1; + case dnnl_f64: + return 8; + default: + assert(false && "unknown data type"); + return 0; + } +} + +inline data_type_t get_dt_from_str(const std::string& name) { + static std::pair name2type[] = { + { "f16", dnnl_f16 }, + { "bf16", dnnl_bf16 }, + { "f32", dnnl_f32 }, + { "s32", dnnl_s32 }, + { "i32", dnnl_s32 }, + { "s8", dnnl_s8 }, + { "i8", dnnl_s8 }, + { "u8", dnnl_u8 }, + { "f64", dnnl_f64 }, + }; + for (size_t i = 0; i < sizeof(name2type) / sizeof(name2type[0]); i++) { + if (name == name2type[i].first) + return name2type[i].second; + } + + return dnnl_data_type_undef; +} + +} \ No newline at end of file diff --git a/src/utility_avx512.hpp b/src/utility_avx512.hpp index 5509218..8a5507d 100644 --- a/src/utility_avx512.hpp +++ b/src/utility_avx512.hpp @@ -91,4 +91,24 @@ inline void cvt_i32_f32(float* dst, int32_t* src, size_t ele_num) { } } +inline void mul_add_f32(float* dst, float* src, float mul, float* add, int ele_num) { + auto mul_f = _mm512_set1_ps(mul); + int i; + auto tail = ele_num % 16; + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + for (i = 0; i < ele_num - tail; i += 16) { + auto a_f = _mm512_loadu_ps(src); + auto add_f = _mm512_loadu_ps(add); + _mm512_storeu_ps(dst, _mm512_fmadd_ps(a_f, mul_f, add_f)); + src += 16; + dst += 16; + add += 16; + } + if (tail) { + auto a_f = _mm512_maskz_loadu_ps(msk, src); + auto add_f = _mm512_maskz_loadu_ps(msk, add); + _mm512_mask_storeu_ps(dst, msk, _mm512_fmadd_ps(a_f, mul_f, add_f)); + } +} + } \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 81eb524..f4cf825 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -21,6 +21,12 @@ include_directories(${LLMDNN_HEADERS_DIR}) file(GLOB_RECURSE TEST_SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) +find_package(OpenMP REQUIRED) + add_executable(llmdnn_tests ${TEST_SOURCE_FILES}) -target_link_libraries(llmdnn_tests llmdnn gtest_main stdc++) +target_link_libraries(llmdnn_tests llmdnn gtest_main stdc++ OpenMP::OpenMP_CXX) install(TARGETS llmdnn_tests DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + +if (BUILD_PYTHON_TESTS) + add_subdirectory(script/ext) +endif() \ No newline at end of file diff --git a/tests/script/README.md b/tests/script/README.md new file mode 100644 index 0000000..aaf82c5 --- /dev/null +++ b/tests/script/README.md @@ -0,0 +1,19 @@ +# Torch extension to help test + +## usage +prepare python enviroment +``` +python3 -m venv .env +source .env/bin/activate +pip3 install -r requirements.txt +``` + +compile extension +``` +cmake . -DBUILD_PYTHON_TESTS=ON +``` + +run test +``` +python test_mha_gpt.py +``` \ No newline at end of file diff --git a/tests/script/ext/CMakeLists.txt b/tests/script/ext/CMakeLists.txt new file mode 100644 index 0000000..d2eed61 --- /dev/null +++ b/tests/script/ext/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required (VERSION 3.13) + +project(llmdnn_ext LANGUAGES CXX) + +# ref:https://stackoverflow.com/questions/68401650/how-can-i-make-a-pytorch-extension-with-cmake +execute_process(COMMAND python3 -c "import torch;print(torch.utils.cmake_prefix_path+'/Torch/',end='')" OUTPUT_VARIABLE TORCH_CMAKE_PATH) +set(CMAKE_PREFIX_PATH ${TORCH_CMAKE_PATH}) + +find_package(Python REQUIRED COMPONENTS Development) +find_package(Torch REQUIRED) +find_package(OpenMP REQUIRED) + +add_library(llmdnn_ext SHARED + mha_gpt.cpp + module.cpp + ../../src/test_common.cpp +) +set_target_properties(llmdnn_ext PROPERTIES + OUTPUT_NAME "llmdnn_ext" + POSITION_INDEPENDENT_CODE ON + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../) +install(TARGETS llmdnn_ext DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) +target_compile_features(llmdnn_ext PRIVATE cxx_std_14) +target_include_directories(llmdnn_ext PRIVATE ../../src) +target_link_libraries(llmdnn_ext PRIVATE ${TORCH_LIBRARIES} Python::Python llmdnn stdc++ OpenMP::OpenMP_CXX) diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp new file mode 100644 index 0000000..776e1b7 --- /dev/null +++ b/tests/script/ext/mha_gpt.cpp @@ -0,0 +1,100 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "alloca.h" +#include "module.hpp" +#include "utility.hpp" +#include "utility_amx.hpp" +#include "mha_gpt.hpp" +#include "test_common.hpp" + +void regclass_mha_gpt(pybind11::module m) { + py::class_> cls(m, "mha_gpt"); + cls.def(py::init<>()); + cls.def("create", [] (llmdnn::mha_gpt& self, + const size_t num_heads, + const size_t head_size, + const size_t head_size_aligned, + const float normal_factor, + const std::string qkv_precision_name, + const std::string dst_precision_name, + const size_t max_seq_len) { + llmdnn::mha_gpt::create_param param; + param.num_heads = num_heads; + param.head_size = head_size; + param.head_size_aligned = head_size_aligned; + param.normal_factor = normal_factor; + param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); + param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); + param.max_seq_len = max_seq_len; + if (param.qkv_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); + if (param.dst_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect dst type " + dst_precision_name); + self.create(param); + }, + py::arg("num_heads"), + py::arg("head_size"), + py::arg("head_size_aligned"), + py::arg("normal_factor"), + py::arg("qkv_precision_name"), + py::arg("dst_precision_name"), + py::arg("max_seq_len"), + R"( + Create mha + + :param num_heads: heads number. + :type num_heads: int + )"); + cls.def("exec", [] (llmdnn::mha_gpt& self, torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor attn_mask) { + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, key_seq_len, head_size] + // v: [batch, num_heads, value_seq_len, head_size] + // attn_mask: [batch, MAX_POSITION_EMBEDDINGS] + // out: [batch, query_seq_len, num_heads * head_size] + AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 2); + auto batch = q.size(0); + auto num_heads = q.size(1); + auto query_seq_len = q.size(2); + auto head_size = q.size(3); + auto key_seq_len = k.size(2); + auto attn_len = attn_mask.size(1); + AT_ASSERT(key_seq_len == v.size(2) && + batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && + num_heads == k.size(1) && num_heads == v.size(1) && + head_size == k.size(3) && head_size == v.size(3)); + + auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}); + llmdnn::mha_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.key_seq_len = key_seq_len; + param.q = q.data_ptr(); + param.attn_output = out.data_ptr(); + param.head_stride_in_kv = key_seq_len * head_size; + param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.attention_mask = reinterpret_cast(alloca(batch * sizeof(float*))); + for (size_t i = 0; i < batch; i++) { + param.k[i] = k.data_ptr() + i * num_heads * key_seq_len * head_size; + param.v[i] = v.data_ptr() + i * num_heads * key_seq_len * head_size; + param.attention_mask[i] = attn_mask.data_ptr() + i * attn_len; + } + + self.exec(param); + return out; + }, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("attn_mask"), + R"( + exec mha + + :param num_heads: heads number. + :type num_heads: int + )"); +} \ No newline at end of file diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp new file mode 100644 index 0000000..a84f216 --- /dev/null +++ b/tests/script/ext/module.cpp @@ -0,0 +1,18 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "module.hpp" +#include "utility_amx.hpp" +#include "mha_gpt.hpp" +#include "test_common.hpp" + +PYBIND11_MODULE(libllmdnn_ext, m) { + static bool initAMX = initXTILE(); + if (!initAMX) { + std::cout << "init amx failed.\n"; + } + regclass_mha_gpt(m); +} \ No newline at end of file diff --git a/tests/script/ext/module.hpp b/tests/script/ext/module.hpp new file mode 100644 index 0000000..8f8f720 --- /dev/null +++ b/tests/script/ext/module.hpp @@ -0,0 +1,9 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +void regclass_mha_gpt(pybind11::module m); \ No newline at end of file diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py new file mode 100644 index 0000000..f27924f --- /dev/null +++ b/tests/script/ext/setup.py @@ -0,0 +1,35 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from setuptools import setup, Extension +from torch.utils import cpp_extension +import sys + +''' +using intel compiler: +source ~/intel/oneapi/setvars.sh +export CXX=icx +export CC=icx +''' +setup(name='llmdnn', + ext_modules=[ + cpp_extension.CppExtension( + 'llmdnn', + ['module.cpp', 'mha_gpt.cpp', '../../src/test_common.cpp'], + extra_compile_args=[ '-fopenmp', + '-march=native', + #'-g' + ], + #extra_link_args=['-lgomp'], + include_dirs=['../../src', + '../../../include', + '../../../src'], + library_dirs=[f'{sys.prefix}/lib', + '../../../../../../../../bin/intel64/Debug'], + #runtime_library_dirs=[ f'{sys.prefix}/lib', ], + libraries=['llmdnn', + 'stdc++']), + ], + cmdclass={'build_ext': cpp_extension.BuildExtension.with_options(use_ninja=False)} + ) \ No newline at end of file diff --git a/tests/script/requirements.txt b/tests/script/requirements.txt new file mode 100644 index 0000000..79b1ea8 --- /dev/null +++ b/tests/script/requirements.txt @@ -0,0 +1,2 @@ +numpy==1.24.2 +torch==2.0.0+cpu diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py new file mode 100644 index 0000000..8ea1514 --- /dev/null +++ b/tests/script/test_mha_gpt.py @@ -0,0 +1,151 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import libllmdnn_ext as ld +from torch import nn + +# copy from transformers/models/gpt_neox/modeling_gpt_neox.py +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + ) + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + + def forward(self, query, key, value, attention_mask=None): + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + + return attn_output + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor), + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + mask_value = torch.finfo(attn_scores.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) + attn_scores = torch.where(causal_mask, attn_scores, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, max_position_embeddings): + self.mha = ld.mha_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + head_size_aligned = head_size + normal_factor = math.sqrt(head_size) + qkv_precision_name = 'bf16' + dst_precision_name = 'bf16' + self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + dst_precision_name, max_seq_len) + + def forward(self, query, key, value, attention_mask=None): + return self.mha.exec(query, key, value, attention_mask) + +HEAD_NUM = 12 #32 +SIZE_PER_HEAD = 80 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +def test(inputs): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS) + with torch.cpu.amp.autocast(): + for (input, i) in enumerate(inputs): + (q, k, v, attn_mask) = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + ref_output = ref_net.forward(q, k, v, attn_mask) + output = net.forward(q, k, v, attn_mask) + if not torch.allclose(ref_output, output): + print(f"error at index {i}") + + return + +if __name__ == "__main__": + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [1, MAX_POSITION_EMBEDDINGS] + (np.random.random(size=[1, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[1, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[1, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([1, MAX_POSITION_EMBEDDINGS], dtype=np.float32)) + ] + test(inputs) \ No newline at end of file diff --git a/tests/src/test_common.cpp b/tests/src/test_common.cpp index 28760b4..595b243 100644 --- a/tests/src/test_common.cpp +++ b/tests/src/test_common.cpp @@ -9,7 +9,9 @@ #include #include #include -#include "test_common.h" +#include +#include "test_common.hpp" +#include "simple_parallel.hpp" #ifndef _GNU_SOURCE #define _GNU_SOURCE /* See feature_test_macros(7) */ @@ -58,4 +60,21 @@ bool initXTILE() { // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed return true; +} + +namespace ov { +namespace cpu { + +size_t getTotalThreads() { + return omp_get_max_threads(); +} + +void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function& fn) { + #pragma omp parallel for + for(std::ptrdiff_t i = 0; i < total; i++) { + fn(i); + } +} + +} } \ No newline at end of file diff --git a/tests/src/test_common.h b/tests/src/test_common.hpp similarity index 100% rename from tests/src/test_common.h rename to tests/src/test_common.hpp diff --git a/tests/src/test_fc_kernel.cpp b/tests/src/test_fc_kernel.cpp index c4d7087..04583f7 100644 --- a/tests/src/test_fc_kernel.cpp +++ b/tests/src/test_fc_kernel.cpp @@ -13,7 +13,7 @@ #include "llm_fc.hpp" #include "tensor2d.hpp" #include "tensor2d_helper.hpp" -#include "test_common.h" +#include "test_common.hpp" using namespace std; using namespace llmdnn; diff --git a/tests/src/test_mm_kernel.cpp b/tests/src/test_mm_kernel.cpp index 91863c7..9a081d3 100644 --- a/tests/src/test_mm_kernel.cpp +++ b/tests/src/test_mm_kernel.cpp @@ -13,7 +13,7 @@ #include "llm_mm.hpp" #include "tensor2d.hpp" #include "tensor2d_helper.hpp" -#include "test_common.h" +#include "test_common.hpp" using namespace std; using namespace llmdnn; diff --git a/tests/src/test_softmax_kernel_avx512.cpp b/tests/src/test_softmax_kernel_avx512.cpp new file mode 100644 index 0000000..fef8577 --- /dev/null +++ b/tests/src/test_softmax_kernel_avx512.cpp @@ -0,0 +1,129 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "tensor2d.hpp" +#include "tensor2d_helper.hpp" +#include "softmax_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using SoftmaxTestParamSet = std::tuple< + data_type_t // data type + >; + +class SoftmaxTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(tensor2D& x, tensor2D& out, tensor2D& quant) { + tensor2D y = x.clone(); + float x_max = std::numeric_limits::lowest(); + for(int i = 0; i < x.dims[1]; i++) { + x_max = std::max(x_max, x[i]); + } + float sum = 0; + for(int i = 0; i < x.dims[1]; i++) { + y[i] = expf(x[i] - x_max); + sum += y[i]; + } + for(int i = 0; i < x.dims[1]; i++) { + y[i] = y[i] / sum; + } + out.resize(x.dims[0], x.dims[1], true); + if (std::is_same::value) { + memcpy(out.data, y.data, x.dims[0] * x.dims[1] * sizeof(float)); + } + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + out[i] = y[i]; + } + } +#define CLIP(x, low, high) \ + (x < low ? low : (x > high ? high : x)) + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + auto tmp = y[i] * quant[i]; + out[i] = static_cast(CLIP(tmp, -128, 127)); + } + } + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + auto tmp = y[i] * quant[i]; + out[i] = static_cast(CLIP(tmp, 0, 255)); + } + } +#undef CLIP + } + + template + void test(float thresh) { + for (int n = 1; n < 129; n++) { + tensor2D A(1, n, true); + tensor2D quant(1, n, true); + tensor2D out(1, n, true), out_ref; + for (int i = 0; i < n; i++) { + A[i] = static_cast(i) - n / 2; + } + quant = 128.f; + gen_ref(A, out_ref, quant); + llmdnn::softmax(out.data, A.data, n, nullptr, nullptr, quant.data); + for (int i = 0; i < n; i++) { + float a = out[i]; + float b = out_ref[i]; + if (std::abs(a - b) > thresh) { + FAIL() << " N: " << n << " pos: " << i << " opt: " << a << " ref: " << b; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(SoftmaxTest, Func) { + if (_types == dnnl_s8) { + test(1.1f); + } else if (_types == dnnl_u8) { + test(1.1f); + } else if (_types == dnnl_f32) { + test(0.00001f); + } else { + test(0.01f); + } +} + +const std::vector types = { + dnnl_s8, dnnl_bf16, dnnl_u8, dnnl_f32 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Softmax, SoftmaxTest, + ::testing::Combine(ValuesIn(types)), + SoftmaxTest::getTestCaseName); diff --git a/tests/src/test_transpose_kernel_avx512.cpp b/tests/src/test_transpose_kernel_avx512.cpp new file mode 100644 index 0000000..2dfe988 --- /dev/null +++ b/tests/src/test_transpose_kernel_avx512.cpp @@ -0,0 +1,131 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "tensor2d.hpp" +#include "tensor2d_helper.hpp" +#include "transpose_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using TransposeTestParamSet = std::tuple< + data_type_t // data type + >; + +class TransposeTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(T* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant) { + for (int j = 0; j < height; j++) { + if (std::is_same::value) { + memcpy(dst, src, width * sizeof(float)); + } + if (std::is_same::value) { + for(int i = 0; i < width; i++) { + dst[i] = src[i]; + } + } + #define CLIP(x, low, high) \ + (x < low ? low : (x > high ? high : x)) + if (std::is_same::value) { + for(int i = 0; i < width; i++) { + auto tmp = src[i] * quant[i]; + dst[i] = static_cast(CLIP(tmp, -128, 127)); + } + } + if (std::is_same::value) { + for(int i = 0; i < width; i++) { + auto tmp = src[i] * quant[i]; + dst[i] = static_cast(CLIP(tmp, 0, 255)); + } + } + #undef CLIP + src = reinterpret_cast(reinterpret_cast(src) + src_stride); + dst = reinterpret_cast(reinterpret_cast(dst) + dst_stride); + } + } + + template + void test(float thresh) { + // [num_heads, query_seq_len, head_size] => [query_seq_len, num_heads * head_size] + int num_heads = 2, query_seq_len = 10, head_size; + for (int head_size = 1; head_size < 129; head_size++) { + tensor2D src(num_heads, head_size * query_seq_len, true); + tensor2D quant(1, head_size, true); + tensor2D dst(query_seq_len, num_heads * head_size, true); + tensor2D dst_ref(query_seq_len, num_heads * head_size, true); + for (int i = 0; i < num_heads * head_size * query_seq_len; i++) { + src[i] = i % 253 - 127; + } + quant = 1.28f; + auto* dst_p = dst.data; + auto* dst_p_ref = dst_ref.data; + for (int i = 0; i < num_heads; i++) { + auto* src_p = &src(i, 0); + llmdnn::memcpy2d_stride(dst_p, src_p, query_seq_len, head_size, + head_size * sizeof(float), num_heads * head_size * sizeof(T), quant.data); + gen_ref(dst_p_ref, src_p, query_seq_len, head_size, + head_size * sizeof(float), num_heads * head_size * sizeof(T), quant.data); + dst_p += head_size; + dst_p_ref += head_size; + } + for (int i = 0; i < num_heads * head_size * query_seq_len; i++) { + float a = dst[i]; + float b = dst_ref[i]; + if (a - b > thresh) { + FAIL() << " N: " << head_size << " pos: " << i << " opt: " << a << " ref: " << b; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(TransposeTest, memcpy2d) { + if (_types == dnnl_s8) { + test(1.1f); + } else if (_types == dnnl_u8) { + test(1.1f); + } else if (_types == dnnl_f32) { + test(0.00001f); + } else { + test(0.01f); + } +} + +const std::vector types = { + dnnl_s8, dnnl_bf16, dnnl_u8, dnnl_f32 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Transpose, TransposeTest, + ::testing::Combine(ValuesIn(types)), + TransposeTest::getTestCaseName); diff --git a/tests/src/test_utility.cpp b/tests/src/test_utility.cpp new file mode 100644 index 0000000..e3b91e0 --- /dev/null +++ b/tests/src/test_utility.cpp @@ -0,0 +1,37 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "tensor2d.hpp" +#include "tensor2d_helper.hpp" +#include "utility_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::Values; +using ::testing::ValuesIn; + +TEST(smoke_Utility, muladd) { + float normal_factor = 1.2f; + for (size_t len = 1; len < 129; len++) { + std::vector x(len), x_out(len), bias(len), ref(len); + for (size_t i = 0; i < x.size(); i++) { + x[i] = -10.0f + i; + bias[i] = -100.0f + i; + ref[i] = x[i] * normal_factor + bias[i]; + } + mul_add_f32(x_out.data(), x.data(), normal_factor, bias.data(), len); + for (size_t i = 0; i < x.size(); i++) { + ASSERT_TRUE(std::abs(x_out[i] - ref[i]) < 0.0001f) << " length: " << len << " pos: " << i << " cur: " << x[i] << " ref: " << ref[i]; + } + } +} \ No newline at end of file diff --git a/tests/src/test_utility_repack1x2.cpp b/tests/src/test_utility_repack1x2.cpp index 0e063d9..4597b5b 100644 --- a/tests/src/test_utility_repack1x2.cpp +++ b/tests/src/test_utility_repack1x2.cpp @@ -14,7 +14,7 @@ #include "tensor2d.hpp" #include "tensor2d_helper.hpp" #include "mm_kernel_amx.hpp" -#include "test_common.h" +#include "test_common.hpp" using namespace std; using namespace llmdnn; @@ -30,7 +30,6 @@ class RepackTest : public TestWithParam { public: static std::string getTestCaseName(const testing::TestParamInfo& obj) { data_type_t types; - int K, N; std::tie(types) = obj.param; std::ostringstream result; From 3f18a486d857f0895d76da4ad00bfad43c18d783 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 05/54] move to gcc11; add mha_gpt bf16 support --- CMakeLists.txt | 4 +- src/mha_gpt.cpp | 16 +- src/mm_kernel_amx.hpp | 536 +++++++++++++++++----------------- tests/CMakeLists.txt | 4 - tests/script/README.md | 2 +- tests/script/build.sh | 7 + tests/script/ext/mha_gpt.cpp | 16 +- tests/script/ext/module.cpp | 2 +- tests/script/ext/setup.py | 14 +- tests/script/requirements.txt | 3 +- tests/script/test_mha_gpt.py | 34 ++- 11 files changed, 332 insertions(+), 306 deletions(-) create mode 100755 tests/script/build.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 0061e15..84387af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,9 +25,7 @@ if(MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") endif() elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) - # TODO - message(FATAL_ERROR "Not support yet. Use intel compiler 2023.0+.") - if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "12.0") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.0") message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 12.0.") endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") diff --git a/src/mha_gpt.cpp b/src/mha_gpt.cpp index 916c2f8..34cc870 100644 --- a/src/mha_gpt.cpp +++ b/src/mha_gpt.cpp @@ -106,6 +106,7 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; size_t head_stride_in_attn = _create_param.head_size; size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; + size_t causal_mask_offset_start = param.key_seq_len - param.query_seq_len; if (is_vector) { parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { @@ -135,7 +136,7 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { amx_kernel::PP::BiasGeluStore pp(matQKV); (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, - _create_param.head_size, _create_param.head_size_aligned, _create_param.num_heads * _create_param.head_size, nullptr); + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); }); } else { auto numThreads = getTotalThreads(); @@ -147,6 +148,7 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { int seq; int start {0}, end {0}; splitter(work_amount, static_cast(numThreads), threadNum, start, end); + if (start >= work_amount) return; parallel_it_init(start, i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); uint8_t* prev_k = nullptr; @@ -176,7 +178,7 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { auto pMatMul0Out = bufferMatMul0Out_local; // loop along K dimension - size_t valid_softmax_items = seq_start + 1; + size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; for (size_t m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); @@ -198,7 +200,7 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { (*qKVGemm_ops[threadNum])(matQKBF16, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); prev_v = pVIn0_aux; memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, - _create_param.head_size, _create_param.head_size_aligned, _create_param.num_heads * _create_param.head_size, nullptr); + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); } }); @@ -229,6 +231,7 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; size_t head_stride_in_attn = _create_param.head_size; size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; + size_t causal_mask_offset_start = param.key_seq_len - param.query_seq_len; if (is_vector) { parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { @@ -259,7 +262,7 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { amx_kernel::PP::BiasGeluStore pp(matQKV); (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, - _create_param.head_size, _create_param.head_size_aligned, _create_param.num_heads * _create_param.head_size, qkv_quant.data()); + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkv_quant.data()); }); } else { auto numThreads = getTotalThreads(); @@ -271,6 +274,7 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { int seq; int start {0}, end {0}; splitter(work_amount, static_cast(numThreads), threadNum, start, end); + if (start >= work_amount) return; parallel_it_init(start, i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); uint8_t* prev_k = nullptr; @@ -300,7 +304,7 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { auto pMatMul0Out = bufferMatMul0Out_local; // loop along K dimension - size_t valid_softmax_items = seq_start + 1; + size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; for (size_t m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); @@ -324,7 +328,7 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { // matmul1: [batch, num_heads, query_seq_len, head_size] // attn_output: [batch, query_seq_len, num_heads * head_size] memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, - _create_param.head_size, _create_param.head_size_aligned, _create_param.num_heads * _create_param.head_size, qkv_quant.data()); + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkv_quant.data()); parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); } }); diff --git a/src/mm_kernel_amx.hpp b/src/mm_kernel_amx.hpp index 611972b..b047999 100644 --- a/src/mm_kernel_amx.hpp +++ b/src/mm_kernel_amx.hpp @@ -201,7 +201,7 @@ namespace functional { rc = _mm512_loadu_epi32(pA + 12*stride); rd = _mm512_loadu_epi32(pA + 13*stride); re = _mm512_loadu_epi32(pA + 14*stride); - rf = _mm512_setzero(); + rf = _mm512_setzero_epi32(); break; case 14: r0 = _mm512_loadu_epi32(pA); @@ -218,8 +218,8 @@ namespace functional { rb = _mm512_loadu_epi32(pA + 11*stride); rc = _mm512_loadu_epi32(pA + 12*stride); rd = _mm512_loadu_epi32(pA + 13*stride); - re = _mm512_setzero(); - rf = _mm512_setzero(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 13: r0 = _mm512_loadu_epi32(pA); @@ -235,9 +235,9 @@ namespace functional { ra = _mm512_loadu_epi32(pA + 10*stride); rb = _mm512_loadu_epi32(pA + 11*stride); rc = _mm512_loadu_epi32(pA + 12*stride); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 12: r0 = _mm512_loadu_epi32(pA); @@ -252,10 +252,10 @@ namespace functional { r9 = _mm512_loadu_epi32(pA + 9*stride); ra = _mm512_loadu_epi32(pA + 10*stride); rb = _mm512_loadu_epi32(pA + 11*stride); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 11: r0 = _mm512_loadu_epi32(pA); @@ -269,11 +269,11 @@ namespace functional { r8 = _mm512_loadu_epi32(pA + 8*stride); r9 = _mm512_loadu_epi32(pA + 9*stride); ra = _mm512_loadu_epi32(pA + 10*stride); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 10: r0 = _mm512_loadu_epi32(pA); @@ -286,12 +286,12 @@ namespace functional { r7 = _mm512_loadu_epi32(pA + 7*stride); r8 = _mm512_loadu_epi32(pA + 8*stride); r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 9: r0 = _mm512_loadu_epi32(pA); @@ -303,13 +303,13 @@ namespace functional { r6 = _mm512_loadu_epi32(pA + 6*stride); r7 = _mm512_loadu_epi32(pA + 7*stride); r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 8: r0 = _mm512_loadu_epi32(pA); @@ -320,14 +320,14 @@ namespace functional { r5 = _mm512_loadu_epi32(pA + 5*stride); r6 = _mm512_loadu_epi32(pA + 6*stride); r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 7: r0 = _mm512_loadu_epi32(pA); @@ -337,15 +337,15 @@ namespace functional { r4 = _mm512_loadu_epi32(pA + 4*stride); r5 = _mm512_loadu_epi32(pA + 5*stride); r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 6: r0 = _mm512_loadu_epi32(pA); @@ -354,16 +354,16 @@ namespace functional { r3 = _mm512_loadu_epi32(pA + 3*stride); r4 = _mm512_loadu_epi32(pA + 4*stride); r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 5: r0 = _mm512_loadu_epi32(pA); @@ -371,89 +371,89 @@ namespace functional { r2 = _mm512_loadu_epi32(pA + 2*stride); r3 = _mm512_loadu_epi32(pA + 3*stride); r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 4: r0 = _mm512_loadu_epi32(pA); r1 = _mm512_loadu_epi32(pA + stride); r2 = _mm512_loadu_epi32(pA + 2*stride); r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_setzero(); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 3: r0 = _mm512_loadu_epi32(pA); r1 = _mm512_loadu_epi32(pA + stride); r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_setzero(); - r4 = _mm512_setzero(); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 2: r0 = _mm512_loadu_epi32(pA); r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_setzero(); - r3 = _mm512_setzero(); - r4 = _mm512_setzero(); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 1: r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_setzero(); - r2 = _mm512_setzero(); - r3 = _mm512_setzero(); - r4 = _mm512_setzero(); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r1 = _mm512_setzero_epi32(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; } @@ -544,7 +544,7 @@ namespace functional { rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); - rf = _mm512_setzero(); + rf = _mm512_setzero_epi32(); break; case 14: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -561,8 +561,8 @@ namespace functional { rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); - re = _mm512_setzero(); - rf = _mm512_setzero(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 13: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -578,9 +578,9 @@ namespace functional { ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 12: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -595,10 +595,10 @@ namespace functional { r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 11: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -612,11 +612,11 @@ namespace functional { r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 10: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -629,12 +629,12 @@ namespace functional { r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 9: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -646,13 +646,13 @@ namespace functional { r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 8: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -663,14 +663,14 @@ namespace functional { r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 7: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -680,15 +680,15 @@ namespace functional { r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 6: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -697,16 +697,16 @@ namespace functional { r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 5: r0 = _mm512_maskz_loadu_epi8 (mask, pA); @@ -714,89 +714,89 @@ namespace functional { r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 4: r0 = _mm512_maskz_loadu_epi8 (mask, pA); r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_setzero(); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 3: r0 = _mm512_maskz_loadu_epi8 (mask, pA); r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_setzero(); - r4 = _mm512_setzero(); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 2: r0 = _mm512_maskz_loadu_epi8 (mask, pA); r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_setzero(); - r3 = _mm512_setzero(); - r4 = _mm512_setzero(); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; case 1: r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_setzero(); - r2 = _mm512_setzero(); - r3 = _mm512_setzero(); - r4 = _mm512_setzero(); - r5 = _mm512_setzero(); - r6 = _mm512_setzero(); - r7 = _mm512_setzero(); - r8 = _mm512_setzero(); - r9 = _mm512_setzero(); - ra = _mm512_setzero(); - rb = _mm512_setzero(); - rc = _mm512_setzero(); - rd = _mm512_setzero(); - re = _mm512_setzero(); - rf = _mm512_setzero(); + r1 = _mm512_setzero_epi32(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); break; } transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); @@ -1416,7 +1416,7 @@ namespace PP { } if (std::is_same::value) { auto c = _mm512_cvtne2ps_pbh(r1, r0); // convert to bf16 - _mm512_mask_storeu_epi16(pdst, kall, c); // store bf16 + _mm512_mask_storeu_epi16(pdst, kall, (__m512i)c); // store bf16 } if (std::is_same::value) { auto d0 = _mm512_cvtps_epi32(r0); // convert to dword(i32) @@ -1444,8 +1444,6 @@ void prefetch_bytes(void *src) for (int i = 0; i < bytes; i+=64) _mm_prefetch(p + i + advance, sel); } -template -void zero_tiles() { int dummy[sizeof...(tmm)] = {(_tile_zero(tmm), 0)...}; } // matmul (FC) // @@ -1729,7 +1727,7 @@ struct MatmulVector { } //asm("int3"); for(int m = 0; m < M; m+=16) { - zero_tiles<0>(); + _tile_zero(0); if (tmmN == 1) { _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); } @@ -1938,7 +1936,7 @@ struct Matmul { pA0 -= (16 - (M - m0))*A.stride; m = M - 16; } - zero_tiles<0>(); + _tile_zero(0); if (tmmN == 1) { _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); } @@ -2037,7 +2035,8 @@ struct Matmul { int k; const auto strideA = matA.stride; loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { - zero_tiles<0, 1>(); + _tile_zero(0); + _tile_zero(1); int8_t * pA0 = reinterpret_cast(&matA[0]); for(k=0; k(&matA(m + 16, 0)); auto strideA = matA.stride; auto * pB = reinterpret_cast(&internalB(n>>5, 0)); - zero_tiles<0, 1, 2, 3>(); + _tile_zero(0); + _tile_zero(1); + _tile_zero(2); + _tile_zero(3); // 2x2 for (int k = 0; k < Kbody; k += kStep) { _tile_loadd(4, pA0, strideA); pA0 += 64; @@ -2209,7 +2211,8 @@ struct Matmul { const auto strideA = matA.stride; loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { // C:Mx32 = A:Mx32 x B:32x32 - zero_tiles<0, 1>(); + _tile_zero(0); + _tile_zero(1); auto * pA0 = &matA[0]; for(int k=0; k { auto * pBint = reinterpret_cast(&internalBI8(n>>5, 0)); functional::i8_to_bf16_Kx32<32>(pBint, pBb); - zero_tiles<0, 1, 2, 3>(); + _tile_zero(0); + _tile_zero(1); + _tile_zero(2); + _tile_zero(3); int k; for (k = 0; k < Kbody; k += kStep) { functional::i8_to_bf16_Kx32<16>(pBint, pBa); @@ -2406,22 +2412,22 @@ struct GemAvB { functional::transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); // vdpbf16ps - regC0 = _mm512_dpbf16_ps(regC0, r0, _mm512_set1_epi32(pBi32[0])); - regC1 = _mm512_dpbf16_ps(regC1, r1, _mm512_set1_epi32(pBi32[1])); - regC0 = _mm512_dpbf16_ps(regC0, r2, _mm512_set1_epi32(pBi32[2])); - regC1 = _mm512_dpbf16_ps(regC1, r3, _mm512_set1_epi32(pBi32[3])); - regC0 = _mm512_dpbf16_ps(regC0, r4, _mm512_set1_epi32(pBi32[4])); - regC1 = _mm512_dpbf16_ps(regC1, r5, _mm512_set1_epi32(pBi32[5])); - regC0 = _mm512_dpbf16_ps(regC0, r6, _mm512_set1_epi32(pBi32[6])); - regC1 = _mm512_dpbf16_ps(regC1, r7, _mm512_set1_epi32(pBi32[7])); - regC0 = _mm512_dpbf16_ps(regC0, r8, _mm512_set1_epi32(pBi32[8])); - regC1 = _mm512_dpbf16_ps(regC1, r9, _mm512_set1_epi32(pBi32[9])); - regC0 = _mm512_dpbf16_ps(regC0, ra, _mm512_set1_epi32(pBi32[10])); - regC1 = _mm512_dpbf16_ps(regC1, rb, _mm512_set1_epi32(pBi32[11])); - regC0 = _mm512_dpbf16_ps(regC0, rc, _mm512_set1_epi32(pBi32[12])); - regC1 = _mm512_dpbf16_ps(regC1, rd, _mm512_set1_epi32(pBi32[13])); - regC0 = _mm512_dpbf16_ps(regC0, re, _mm512_set1_epi32(pBi32[14])); - regC1 = _mm512_dpbf16_ps(regC1, rf, _mm512_set1_epi32(pBi32[15])); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r0), (__m512bh)(_mm512_set1_epi32(pBi32[0]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r1), (__m512bh)(_mm512_set1_epi32(pBi32[1]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r2), (__m512bh)(_mm512_set1_epi32(pBi32[2]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r3), (__m512bh)(_mm512_set1_epi32(pBi32[3]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r4), (__m512bh)(_mm512_set1_epi32(pBi32[4]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r5), (__m512bh)(_mm512_set1_epi32(pBi32[5]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r6), (__m512bh)(_mm512_set1_epi32(pBi32[6]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r7), (__m512bh)(_mm512_set1_epi32(pBi32[7]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r8), (__m512bh)(_mm512_set1_epi32(pBi32[8]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r9), (__m512bh)(_mm512_set1_epi32(pBi32[9]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(ra), (__m512bh)(_mm512_set1_epi32(pBi32[10]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rb), (__m512bh)(_mm512_set1_epi32(pBi32[11]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(rc), (__m512bh)(_mm512_set1_epi32(pBi32[12]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rd), (__m512bh)(_mm512_set1_epi32(pBi32[13]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(re), (__m512bh)(_mm512_set1_epi32(pBi32[14]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rf), (__m512bh)(_mm512_set1_epi32(pBi32[15]))); } regC0 = _mm512_add_ps(regC0, regC1); _mm512_mask_storeu_ps (vecC + m, kmask, regC0); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f4cf825..09b7295 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,7 +26,3 @@ find_package(OpenMP REQUIRED) add_executable(llmdnn_tests ${TEST_SOURCE_FILES}) target_link_libraries(llmdnn_tests llmdnn gtest_main stdc++ OpenMP::OpenMP_CXX) install(TARGETS llmdnn_tests DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) - -if (BUILD_PYTHON_TESTS) - add_subdirectory(script/ext) -endif() \ No newline at end of file diff --git a/tests/script/README.md b/tests/script/README.md index aaf82c5..547a508 100644 --- a/tests/script/README.md +++ b/tests/script/README.md @@ -10,7 +10,7 @@ pip3 install -r requirements.txt compile extension ``` -cmake . -DBUILD_PYTHON_TESTS=ON +./build.sh ``` run test diff --git a/tests/script/build.sh b/tests/script/build.sh new file mode 100755 index 0000000..b466a7c --- /dev/null +++ b/tests/script/build.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +pip uninstall -y llmdnn +cd ../../../../../../../build/ && make -j 20 llmdnnlib +cd - +cd ext +python setup.py clean --all install diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index 776e1b7..03531f6 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -62,8 +62,8 @@ void regclass_mha_gpt(pybind11::module m) { auto head_size = q.size(3); auto key_seq_len = k.size(2); auto attn_len = attn_mask.size(1); - AT_ASSERT(key_seq_len == v.size(2) && - batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && + AT_ASSERT(key_seq_len == v.size(2) && key_seq_len == attn_len && + batch == k.size(0) && batch == v.size(0) && 1 == attn_mask.size(0) && num_heads == k.size(1) && num_heads == v.size(1) && head_size == k.size(3) && head_size == v.size(3)); @@ -72,16 +72,16 @@ void regclass_mha_gpt(pybind11::module m) { param.batch = batch; param.query_seq_len = query_seq_len; param.key_seq_len = key_seq_len; - param.q = q.data_ptr(); - param.attn_output = out.data_ptr(); + param.q = reinterpret_cast(q.data_ptr()); + param.attn_output = reinterpret_cast(out.data_ptr()); param.head_stride_in_kv = key_seq_len * head_size; param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.attention_mask = reinterpret_cast(alloca(batch * sizeof(float*))); - for (size_t i = 0; i < batch; i++) { - param.k[i] = k.data_ptr() + i * num_heads * key_seq_len * head_size; - param.v[i] = v.data_ptr() + i * num_heads * key_seq_len * head_size; - param.attention_mask[i] = attn_mask.data_ptr() + i * attn_len; + for (int i = 0; i < batch; i++) { + param.k[i] = reinterpret_cast(k.data_ptr()) + i * num_heads * key_seq_len * head_size * sizeof(ov::bfloat16); + param.v[i] = reinterpret_cast(v.data_ptr()) + i * num_heads * key_seq_len * head_size * sizeof(ov::bfloat16); + param.attention_mask[i] = attn_mask.data_ptr(); } self.exec(param); diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp index a84f216..19419cd 100644 --- a/tests/script/ext/module.cpp +++ b/tests/script/ext/module.cpp @@ -9,7 +9,7 @@ #include "mha_gpt.hpp" #include "test_common.hpp" -PYBIND11_MODULE(libllmdnn_ext, m) { +PYBIND11_MODULE(llmdnn, m) { static bool initAMX = initXTILE(); if (!initAMX) { std::cout << "init amx failed.\n"; diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py index f27924f..aa58078 100644 --- a/tests/script/ext/setup.py +++ b/tests/script/ext/setup.py @@ -12,21 +12,25 @@ export CXX=icx export CC=icx ''' +debug = True +extra_args = ['-fopenmp', + '-march=native'] +llmdnn_lib_dir = '../../../../../../../../bin/intel64/Release' +if debug: + llmdnn_lib_dir = '../../../../../../../../bin/intel64/Debug' + extra_args += ['-g', '-O0'] setup(name='llmdnn', ext_modules=[ cpp_extension.CppExtension( 'llmdnn', ['module.cpp', 'mha_gpt.cpp', '../../src/test_common.cpp'], - extra_compile_args=[ '-fopenmp', - '-march=native', - #'-g' - ], + extra_compile_args=extra_args, #extra_link_args=['-lgomp'], include_dirs=['../../src', '../../../include', '../../../src'], library_dirs=[f'{sys.prefix}/lib', - '../../../../../../../../bin/intel64/Debug'], + llmdnn_lib_dir], #runtime_library_dirs=[ f'{sys.prefix}/lib', ], libraries=['llmdnn', 'stdc++']), diff --git a/tests/script/requirements.txt b/tests/script/requirements.txt index 79b1ea8..15ff051 100644 --- a/tests/script/requirements.txt +++ b/tests/script/requirements.txt @@ -1,2 +1,3 @@ +-f https://download.pytorch.org/whl/torch_stable.html numpy==1.24.2 -torch==2.0.0+cpu +torch==2.0.1+cpu diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py index 8ea1514..bb2edf9 100644 --- a/tests/script/test_mha_gpt.py +++ b/tests/script/test_mha_gpt.py @@ -6,7 +6,7 @@ import sys import torch import numpy as np -import libllmdnn_ext as ld +import llmdnn as ld from torch import nn # copy from transformers/models/gpt_neox/modeling_gpt_neox.py @@ -100,7 +100,7 @@ def __init__(self, num_attention_heads, hidden_size, max_position_embeddings): max_seq_len = max_position_embeddings head_size_aligned = head_size - normal_factor = math.sqrt(head_size) + normal_factor = 1.0 / math.sqrt(head_size) qkv_precision_name = 'bf16' dst_precision_name = 'bf16' self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, @@ -109,7 +109,7 @@ def __init__(self, num_attention_heads, hidden_size, max_position_embeddings): def forward(self, query, key, value, attention_mask=None): return self.mha.exec(query, key, value, attention_mask) -HEAD_NUM = 12 #32 +HEAD_NUM = 32 SIZE_PER_HEAD = 80 HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD MAX_POSITION_EMBEDDINGS = 1024 #2048 @@ -124,16 +124,18 @@ def __init__(self): ref_net = ref_net.to(dtype=torch.bfloat16) net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS) with torch.cpu.amp.autocast(): - for (input, i) in enumerate(inputs): - (q, k, v, attn_mask) = input + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input q = torch.from_numpy(q).to(torch.bfloat16) k = torch.from_numpy(k).to(torch.bfloat16) v = torch.from_numpy(v).to(torch.bfloat16) + attn_mask = torch.from_numpy(attn_mask) ref_output = ref_net.forward(q, k, v, attn_mask) output = net.forward(q, k, v, attn_mask) - if not torch.allclose(ref_output, output): - print(f"error at index {i}") - + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + + print('done.') return if __name__ == "__main__": @@ -143,9 +145,17 @@ def __init__(self): # k: [batch, num_heads, key_seq_len, head_size] # v: [batch, num_heads, value_seq_len, head_size] # attn: [1, MAX_POSITION_EMBEDDINGS] - (np.random.random(size=[1, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[1, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[1, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([1, MAX_POSITION_EMBEDDINGS], dtype=np.float32)) + (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([1, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([1, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([1, 200], dtype=np.float32)), ] test(inputs) \ No newline at end of file From d2ea5981a441fe52b51b1dce9831f4ef6cb56cc6 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 06/54] restructure directory layout --- src/mha_gpt.hpp => include/llm_mha_gpt.hpp | 9 +- src/{ => common}/bf16.hpp | 0 src/{ => common}/simple_parallel.hpp | 0 src/{ => common}/tensor2d.hpp | 0 src/{ => common}/tensor2d_helper.hpp | 0 src/{ => common}/utility.hpp | 0 src/{fc_interface.cpp => fc_kernel_amx.cpp} | 14 +- src/fc_kernel_amx.hpp | 19 + src/fc_kernel_api.cpp | 53 + src/{mha_gpt.cpp => mha_gpt_amx.cpp} | 66 +- src/mha_gpt_amx.hpp | 17 + src/mha_gpt_api.cpp | 24 + src/{mm_interface.cpp => mm_kernel_amx.cpp} | 11 +- src/mm_kernel_amx.hpp | 2465 +------------------ src/mm_kernel_api.cpp | 28 + src/mm_kernel_common_amx.hpp | 2455 ++++++++++++++++++ src/softmax_kernel_avx512.hpp | 12 +- src/transpose_kernel_avx512.hpp | 8 +- src/utility_amx.hpp | 6 +- src/utility_avx512.hpp | 8 +- tests/script/README.md | 6 +- tests/script/ext/mha_gpt.cpp | 4 +- tests/script/ext/module.cpp | 2 +- tests/script/ext/setup.py | 25 +- tests/script/requirements.txt | 2 + tests/script/test_mha_gpt.py | 43 +- tests/src/test_common.cpp | 2 +- tests/src/test_common.hpp | 4 +- tests/src/test_fc_kernel.cpp | 4 +- tests/src/test_mm_kernel.cpp | 4 +- tests/src/test_softmax_kernel_avx512.cpp | 6 +- tests/src/test_transpose_kernel_avx512.cpp | 6 +- tests/src/test_utility.cpp | 6 +- tests/src/test_utility_repack1x2.cpp | 6 +- 34 files changed, 2747 insertions(+), 2568 deletions(-) rename src/mha_gpt.hpp => include/llm_mha_gpt.hpp (95%) rename src/{ => common}/bf16.hpp (100%) rename src/{ => common}/simple_parallel.hpp (100%) rename src/{ => common}/tensor2d.hpp (100%) rename src/{ => common}/tensor2d_helper.hpp (100%) rename src/{ => common}/utility.hpp (100%) rename src/{fc_interface.cpp => fc_kernel_amx.cpp} (96%) create mode 100644 src/fc_kernel_amx.hpp create mode 100644 src/fc_kernel_api.cpp rename src/{mha_gpt.cpp => mha_gpt_amx.cpp} (89%) create mode 100644 src/mha_gpt_amx.hpp create mode 100644 src/mha_gpt_api.cpp rename src/{mm_interface.cpp => mm_kernel_amx.cpp} (91%) create mode 100644 src/mm_kernel_api.cpp create mode 100644 src/mm_kernel_common_amx.hpp diff --git a/src/mha_gpt.hpp b/include/llm_mha_gpt.hpp similarity index 95% rename from src/mha_gpt.hpp rename to include/llm_mha_gpt.hpp index 0605f87..0822120 100644 --- a/src/mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -77,9 +77,12 @@ class mha_gpt { void create(const create_param& param); void exec(const exec_param& param); -private: - struct Impl; - std::shared_ptr _impl; + struct impl { + virtual void create(const create_param& param) = 0; + virtual void exec(const exec_param& param) = 0; + }; +protected: + std::shared_ptr _impl; }; } diff --git a/src/bf16.hpp b/src/common/bf16.hpp similarity index 100% rename from src/bf16.hpp rename to src/common/bf16.hpp diff --git a/src/simple_parallel.hpp b/src/common/simple_parallel.hpp similarity index 100% rename from src/simple_parallel.hpp rename to src/common/simple_parallel.hpp diff --git a/src/tensor2d.hpp b/src/common/tensor2d.hpp similarity index 100% rename from src/tensor2d.hpp rename to src/common/tensor2d.hpp diff --git a/src/tensor2d_helper.hpp b/src/common/tensor2d_helper.hpp similarity index 100% rename from src/tensor2d_helper.hpp rename to src/common/tensor2d_helper.hpp diff --git a/src/utility.hpp b/src/common/utility.hpp similarity index 100% rename from src/utility.hpp rename to src/common/utility.hpp diff --git a/src/fc_interface.cpp b/src/fc_kernel_amx.cpp similarity index 96% rename from src/fc_interface.cpp rename to src/fc_kernel_amx.cpp index 6a34a4e..0843fc0 100644 --- a/src/fc_interface.cpp +++ b/src/fc_kernel_amx.cpp @@ -15,8 +15,9 @@ #include #include "llm_fc.hpp" -#include "mm_kernel_amx.hpp" +#include "mm_kernel_common_amx.hpp" #include "utility_avx512.hpp" +#include "fc_kernel_amx.hpp" namespace llmdnn { @@ -67,7 +68,7 @@ static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b } // interface -bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { +bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { fc_kernel* m = nullptr; if (param == nullptr || mm == nullptr) { std::cout << "fc_kernel_create: invalid input parameter.\n"; @@ -107,13 +108,13 @@ bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { return false; } -void fc_kernel_destroy(const fc_kernel* mm) { +void fc_kernel_destroy_amx(const fc_kernel* mm) { if (mm) { delete mm; } } -void fc_kernel_execute(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { size_t b_d0 = K, b_d1 = N; if (mm->b_is_transpose) { @@ -293,7 +294,7 @@ void fc_kernel_execute(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_ } } -void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { +void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { float min, max; tensor2D B(K, N, reinterpret_cast(ptr), stride); amx_kernel::functional::get_min_max(B, min, max); @@ -303,7 +304,7 @@ void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, flo } /// set q, dq for each fc_kernel instance -void fc_kernel_bf16w8_set_q_dq(const fc_kernel* mm, float q, float dq) { +void fc_kernel_bf16w8_set_q_dq_amx(const fc_kernel* mm, float q, float dq) { if (!mm || !mm->bf16xi8) { std::cout << "fc_kernel_bf16w8_set_q_dq: created kernel is not int8 weight.\n"; return; @@ -312,5 +313,4 @@ void fc_kernel_bf16w8_set_q_dq(const fc_kernel* mm, float q, float dq) { mm->bf16xi8->dequant_scale_B = dq; } - } \ No newline at end of file diff --git a/src/fc_kernel_amx.hpp b/src/fc_kernel_amx.hpp new file mode 100644 index 0000000..5c5467f --- /dev/null +++ b/src/fc_kernel_amx.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "llm_fc.hpp" + +namespace llmdnn { + +bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param); + +void fc_kernel_destroy_amx(const fc_kernel* mm); + +void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias); + +void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); +void fc_kernel_bf16w8_set_q_dq_amx(const fc_kernel* mm, float q, float dq); + +} \ No newline at end of file diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp new file mode 100644 index 0000000..12fe913 --- /dev/null +++ b/src/fc_kernel_api.cpp @@ -0,0 +1,53 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_fc.hpp" +#include "fc_kernel_amx.hpp" +#include "mm_kernel_common_amx.hpp" +#include "utility_avx512.hpp" + +namespace llmdnn { + +static decltype(&fc_kernel_create) fc_kernel_create_ptr = fc_kernel_create_amx; +static decltype(&fc_kernel_destroy) fc_kernel_destroy_ptr = fc_kernel_destroy_amx; +static decltype(&fc_kernel_execute) fc_kernel_execute_ptr = fc_kernel_execute_amx; +static decltype(&fc_kernel_bf16w8_get_q_dq) fc_kernel_bf16w8_get_q_dq_ptr = fc_kernel_bf16w8_get_q_dq_amx; +static decltype(&fc_kernel_bf16w8_set_q_dq) fc_kernel_bf16w8_set_q_dq_ptr = fc_kernel_bf16w8_set_q_dq_amx; + +// interface +bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { + return fc_kernel_create_ptr(mm, param); +} + +void fc_kernel_destroy(const fc_kernel* mm) { + fc_kernel_destroy_ptr(mm); +} + +void fc_kernel_execute(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { + fc_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, lda, ldb, ldc, M, N, K, n_start, n_end, dq, q, bias); +} + +void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { + fc_kernel_bf16w8_get_q_dq_ptr(K, N, stride, ptr, q, dq); +} + +/// set q, dq for each fc_kernel instance +void fc_kernel_bf16w8_set_q_dq(const fc_kernel* mm, float q, float dq) { + fc_kernel_bf16w8_set_q_dq_ptr(mm, q, dq); +} + +} \ No newline at end of file diff --git a/src/mha_gpt.cpp b/src/mha_gpt_amx.cpp similarity index 89% rename from src/mha_gpt.cpp rename to src/mha_gpt_amx.cpp index 34cc870..dd84591 100644 --- a/src/mha_gpt.cpp +++ b/src/mha_gpt_amx.cpp @@ -5,26 +5,26 @@ #include #include -#include "simple_parallel.hpp" -#include "utility.hpp" +#include "common/simple_parallel.hpp" +#include "common/utility.hpp" #include "utility_avx512.hpp" -#include "mm_kernel_amx.hpp" +#include "mm_kernel_common_amx.hpp" #include "softmax_kernel_avx512.hpp" #include "transpose_kernel_avx512.hpp" -#include "mha_gpt.hpp" +#include "llm_mha_gpt.hpp" using namespace ov::cpu; namespace llmdnn { -struct mha_gpt::Impl { - void create(const create_param& param); - void exec(const exec_param& param); +struct mha_gpt_impl_amx : public mha_gpt::impl { + void create(const mha_gpt::create_param& param) override; + void exec(const mha_gpt::exec_param& param) override; - create_param _create_param; + mha_gpt::create_param _create_param; - void mha_bf16(const exec_param ¶m); - void mha_i8(const exec_param ¶m); + void mha_bf16(const mha_gpt::exec_param ¶m); + void mha_i8(const mha_gpt::exec_param ¶m); size_t bufferMatMul0OutSize; size_t bufferMatMul1OutSize; @@ -41,7 +41,7 @@ struct mha_gpt::Impl { std::vector>> gemAvB_i8xi8; }; -void mha_gpt::Impl::create(const create_param& param) { +void mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { _create_param = param; // q: [batch, num_heads, query_seq_len, head_size] @@ -90,7 +90,7 @@ void mha_gpt::Impl::create(const create_param& param) { [](void * p) { ::free(p); }); } -void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { +void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { uint8_t* pQIn0 = param.q; auto& pKIn0 = param.k; auto& attn_masks = param.attention_mask; @@ -127,15 +127,15 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { (*gemAvB_ops[threadNum])(matK, reinterpret_cast(pQIn0_aux), reinterpret_cast(bufferMatMul0Out_local)); float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); - mul_add_f32(pMatMul0Out, pMatMul0Out, _create_param.normal_factor, pAddIn1_aux, param.key_seq_len); - softmax(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr, nullptr, nullptr); + mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, _create_param.normal_factor, pAddIn1_aux, param.key_seq_len); + softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr, nullptr, nullptr); auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp(matQKV); (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); - memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, + memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); }); } else { @@ -182,8 +182,8 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { for (size_t m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - mul_add_f32(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); - softmax(dst, src, valid_softmax_items, nullptr, nullptr, nullptr); + mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); + softmax_avx512(dst, src, valid_softmax_items, nullptr, nullptr, nullptr); // attn_scores = torch.where(causal_mask, attn_scores, mask_value) if (param.key_seq_len > valid_softmax_items) { auto *invalidPtr = dst + valid_softmax_items; @@ -199,7 +199,7 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { amx_kernel::PP::BiasGeluStore pp2(matQKV); (*qKVGemm_ops[threadNum])(matQKBF16, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); prev_v = pVIn0_aux; - memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, + memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); } @@ -207,7 +207,7 @@ void mha_gpt::Impl::mha_bf16(const exec_param ¶m) { } } -void mha_gpt::Impl::mha_i8(const exec_param ¶m) { +void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { uint8_t* pQIn0 = param.q; auto& pKIn0 = param.k; auto& attn_masks = param.attention_mask; @@ -250,18 +250,18 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { // k[N, K] * transpose(q[1, K]) ==> // k[N, K] * q[K, 1] (*gemAvB_ops[threadNum])(matK, reinterpret_cast(pQIn0_aux), reinterpret_cast(bufferMatMul0Out_local)); - cvt_i32_f32(reinterpret_cast(bufferMatMul0Out_local), reinterpret_cast(bufferMatMul0Out_local), param.key_seq_len); + cvt_i32_f32_avx512(reinterpret_cast(bufferMatMul0Out_local), reinterpret_cast(bufferMatMul0Out_local), param.key_seq_len); float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); - mul_add_f32(pMatMul0Out, pMatMul0Out, mul_scales, pAddIn1_aux, param.key_seq_len); - softmax(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr, nullptr, qk_quant_vec.data()); + mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, mul_scales, pAddIn1_aux, param.key_seq_len); + softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr, nullptr, qk_quant_vec.data()); auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp(matQKV); (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); - memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, + memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkv_quant.data()); }); } else { @@ -308,8 +308,8 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { for (size_t m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); - mul_add_f32(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); - softmax(dst, src, valid_softmax_items, nullptr, nullptr, qk_quant_vec.data()); + mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); + softmax_avx512(dst, src, valid_softmax_items, nullptr, nullptr, qk_quant_vec.data()); // attn_scores = torch.where(causal_mask, attn_scores, mask_value) if (param.key_seq_len > valid_softmax_items) { auto *invalidPtr = dst + valid_softmax_items; @@ -327,7 +327,7 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { prev_v = pVIn0_aux; // matmul1: [batch, num_heads, query_seq_len, head_size] // attn_output: [batch, query_seq_len, num_heads * head_size] - memcpy2d_stride(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, + memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkv_quant.data()); parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); } @@ -335,7 +335,7 @@ void mha_gpt::Impl::mha_i8(const exec_param ¶m) { } } -void mha_gpt::Impl::exec(const exec_param& param) { +void mha_gpt_impl_amx::exec(const mha_gpt::exec_param& param) { if (_create_param.qkv_precision == dnnl_f32) { assert(false); } else if (_create_param.qkv_precision == dnnl_bf16) { @@ -347,16 +347,8 @@ void mha_gpt::Impl::exec(const exec_param& param) { } } -// interface -mha_gpt::mha_gpt(): _impl(std::make_shared()) { -} - -void mha_gpt::create(const create_param& param) { - _impl->create(param); -} - -void mha_gpt::exec(const exec_param& param) { - _impl->exec(param); +std::shared_ptr new_impl_amx() { + return std::make_shared(); } } \ No newline at end of file diff --git a/src/mha_gpt_amx.hpp b/src/mha_gpt_amx.hpp new file mode 100644 index 0000000..9409af9 --- /dev/null +++ b/src/mha_gpt_amx.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_mha_gpt.hpp" + +namespace llmdnn { + +std::shared_ptr new_impl_amx(); + +} diff --git a/src/mha_gpt_api.cpp b/src/mha_gpt_api.cpp new file mode 100644 index 0000000..a979a5b --- /dev/null +++ b/src/mha_gpt_api.cpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "mha_gpt_amx.hpp" + +namespace llmdnn { + +// interface +mha_gpt::mha_gpt(): _impl(new_impl_amx()) { +} + +void mha_gpt::create(const create_param& param) { + _impl->create(param); +} + +void mha_gpt::exec(const exec_param& param) { + _impl->exec(param); +} + +} \ No newline at end of file diff --git a/src/mm_interface.cpp b/src/mm_kernel_amx.cpp similarity index 91% rename from src/mm_interface.cpp rename to src/mm_kernel_amx.cpp index 76c48aa..3c0a810 100644 --- a/src/mm_interface.cpp +++ b/src/mm_kernel_amx.cpp @@ -15,8 +15,9 @@ #include #include "llm_mm.hpp" -#include "mm_kernel_amx.hpp" +#include "mm_kernel_common_amx.hpp" #include "utility_avx512.hpp" +#include "mm_kernel_amx.hpp" namespace llmdnn { @@ -35,7 +36,7 @@ struct mm_kernel { }; // interface -bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param) { +bool mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param) { mm_kernel* m = nullptr; if (param == nullptr || mm == nullptr) { std::cout << "mm_kernel_create: invalid input parameter.\n"; @@ -75,13 +76,13 @@ bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param) { return false; } -void mm_kernel_destroy(const mm_kernel* mm) { +void mm_kernel_destroy_amx(const mm_kernel* mm) { if (mm) { delete mm; } } -void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +void mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, size_t M, size_t N, size_t K) { size_t b_d0 = K, b_d1 = N; if (mm->b_is_transpose) { @@ -91,7 +92,7 @@ void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_ if (mm->i8xi8_gemv) { tensor2D a(M, K, reinterpret_cast(ptr_a), lda); (*mm->i8xi8_gemv)(a, reinterpret_cast(ptr_b), reinterpret_cast(ptr_c)); - cvt_i32_f32(reinterpret_cast(ptr_c), reinterpret_cast(ptr_c), M); + cvt_i32_f32_avx512(reinterpret_cast(ptr_c), reinterpret_cast(ptr_c), M); } else if (mm->i8xi8) { tensor2D a(M, K, reinterpret_cast(ptr_a), lda); tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); diff --git a/src/mm_kernel_amx.hpp b/src/mm_kernel_amx.hpp index b047999..1d03818 100644 --- a/src/mm_kernel_amx.hpp +++ b/src/mm_kernel_amx.hpp @@ -1,2455 +1,28 @@ // Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -#pragma once +#include "llm_mm.hpp" -#include "utility_amx.hpp" -#include "tensor2d.hpp" +namespace llmdnn { -#ifdef _WIN32 -#include -#else -#include -#endif +bool mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param); -#include "bf16.hpp" -#ifdef ENABLE_NUMA -#include "numa.h" -#endif +void mm_kernel_destroy_amx(const mm_kernel* mm); -namespace amx_kernel { +void mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K); -namespace functional { - - inline void transpose_m512i_16x16(__m512i &r0, __m512i &r1, __m512i &r2, __m512i &r3, - __m512i &r4, __m512i &r5, __m512i &r6, __m512i &r7, - __m512i &r8, __m512i &r9, __m512i &ra, __m512i &rb, - __m512i &rc, __m512i &rd, __m512i &re, __m512i &rf) { - __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; - - t0 = _mm512_unpacklo_epi32(r0,r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 - t1 = _mm512_unpackhi_epi32(r0,r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 - t2 = _mm512_unpacklo_epi32(r2,r3); // 32 48 33 49 ... - t3 = _mm512_unpackhi_epi32(r2,r3); // 34 50 35 51 ... - t4 = _mm512_unpacklo_epi32(r4,r5); // 64 80 65 81 ... - t5 = _mm512_unpackhi_epi32(r4,r5); // 66 82 67 83 ... - t6 = _mm512_unpacklo_epi32(r6,r7); // 96 112 97 113 ... - t7 = _mm512_unpackhi_epi32(r6,r7); // 98 114 99 115 ... - t8 = _mm512_unpacklo_epi32(r8,r9); // 128 ... - t9 = _mm512_unpackhi_epi32(r8,r9); // 130 ... - ta = _mm512_unpacklo_epi32(ra,rb); // 160 ... - tb = _mm512_unpackhi_epi32(ra,rb); // 162 ... - tc = _mm512_unpacklo_epi32(rc,rd); // 196 ... - td = _mm512_unpackhi_epi32(rc,rd); // 198 ... - te = _mm512_unpacklo_epi32(re,rf); // 228 ... - tf = _mm512_unpackhi_epi32(re,rf); // 230 ... - - r0 = _mm512_unpacklo_epi64(t0,t2); // 0 16 32 48 ... - r1 = _mm512_unpackhi_epi64(t0,t2); // 1 17 33 49 ... - r2 = _mm512_unpacklo_epi64(t1,t3); // 2 18 34 49 ... - r3 = _mm512_unpackhi_epi64(t1,t3); // 3 19 35 51 ... - r4 = _mm512_unpacklo_epi64(t4,t6); // 64 80 96 112 ... - r5 = _mm512_unpackhi_epi64(t4,t6); // 65 81 97 114 ... - r6 = _mm512_unpacklo_epi64(t5,t7); // 66 82 98 113 ... - r7 = _mm512_unpackhi_epi64(t5,t7); // 67 83 99 115 ... - r8 = _mm512_unpacklo_epi64(t8,ta); // 128 144 160 176 ... - r9 = _mm512_unpackhi_epi64(t8,ta); // 129 145 161 178 ... - ra = _mm512_unpacklo_epi64(t9,tb); // 130 146 162 177 ... - rb = _mm512_unpackhi_epi64(t9,tb); // 131 147 163 179 ... - rc = _mm512_unpacklo_epi64(tc,te); // 192 208 228 240 ... - rd = _mm512_unpackhi_epi64(tc,te); // 193 209 229 241 ... - re = _mm512_unpacklo_epi64(td,tf); // 194 210 230 242 ... - rf = _mm512_unpackhi_epi64(td,tf); // 195 211 231 243 ... - - t0 = _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... - t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... - t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... - t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... - t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... - t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... - t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... - t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... - t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... - t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... - ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... - tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... - tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... - td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... - te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... - tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... - - r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 - r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 - r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 - r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 - r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... - r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... - r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... - r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... - r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... - r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... - ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... - rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... - rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... - rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... - re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... - rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 - } - - inline void transpose_epi32_16x16(void * _dst, const void * src, int stride) { - auto * dst = reinterpret_cast(_dst); - __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; - auto * pA = reinterpret_cast(src); - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_loadu_epi32(pA + 10*stride); - rb = _mm512_loadu_epi32(pA + 11*stride); - rc = _mm512_loadu_epi32(pA + 12*stride); - rd = _mm512_loadu_epi32(pA + 13*stride); - re = _mm512_loadu_epi32(pA + 14*stride); - rf = _mm512_loadu_epi32(pA + 15*stride); - - transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); - - _mm512_storeu_epi32(dst, r0); - _mm512_storeu_epi32(dst + 16, r1); - _mm512_storeu_epi32(dst + 2*16, r2); - _mm512_storeu_epi32(dst + 3*16, r3); - _mm512_storeu_epi32(dst + 4*16, r4); - _mm512_storeu_epi32(dst + 5*16, r5); - _mm512_storeu_epi32(dst + 6*16, r6); - _mm512_storeu_epi32(dst + 7*16, r7); - _mm512_storeu_epi32(dst + 8*16, r8); - _mm512_storeu_epi32(dst + 9*16, r9); - _mm512_storeu_epi32(dst + 10*16, ra); - _mm512_storeu_epi32(dst + 11*16, rb); - _mm512_storeu_epi32(dst + 12*16, rc); - _mm512_storeu_epi32(dst + 13*16, rd); - _mm512_storeu_epi32(dst + 14*16, re); - _mm512_storeu_epi32(dst + 15*16, rf); - } - - // 16xN, N<=16, non-valid part is filled with zeros - inline void transpose_epi32_16xN(void * _dst, const void * src, int stride, int valid_bytes) { - auto * dst = reinterpret_cast(_dst); - __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; - auto * pA = reinterpret_cast(src); - uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull >> (64 - valid_bytes); - __mmask64 mask = _cvtu64_mask64(mask_value); - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); - rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); - rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); - rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); - re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); - rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); - transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); - _mm512_storeu_epi32(dst, r0); - _mm512_storeu_epi32(dst + 16, r1); - _mm512_storeu_epi32(dst + 2*16, r2); - _mm512_storeu_epi32(dst + 3*16, r3); - _mm512_storeu_epi32(dst + 4*16, r4); - _mm512_storeu_epi32(dst + 5*16, r5); - _mm512_storeu_epi32(dst + 6*16, r6); - _mm512_storeu_epi32(dst + 7*16, r7); - _mm512_storeu_epi32(dst + 8*16, r8); - _mm512_storeu_epi32(dst + 9*16, r9); - _mm512_storeu_epi32(dst + 10*16, ra); - _mm512_storeu_epi32(dst + 11*16, rb); - _mm512_storeu_epi32(dst + 12*16, rc); - _mm512_storeu_epi32(dst + 13*16, rd); - _mm512_storeu_epi32(dst + 14*16, re); - _mm512_storeu_epi32(dst + 15*16, rf); - } - - inline void transpose_epi32_Mx16(void * _dst, const void * src, int stride, int valid_m) { - auto * dst = reinterpret_cast(_dst); - __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; - auto * pA = reinterpret_cast(src); - switch (valid_m) { - case 15: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_loadu_epi32(pA + 10*stride); - rb = _mm512_loadu_epi32(pA + 11*stride); - rc = _mm512_loadu_epi32(pA + 12*stride); - rd = _mm512_loadu_epi32(pA + 13*stride); - re = _mm512_loadu_epi32(pA + 14*stride); - rf = _mm512_setzero_epi32(); - break; - case 14: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_loadu_epi32(pA + 10*stride); - rb = _mm512_loadu_epi32(pA + 11*stride); - rc = _mm512_loadu_epi32(pA + 12*stride); - rd = _mm512_loadu_epi32(pA + 13*stride); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 13: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_loadu_epi32(pA + 10*stride); - rb = _mm512_loadu_epi32(pA + 11*stride); - rc = _mm512_loadu_epi32(pA + 12*stride); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 12: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_loadu_epi32(pA + 10*stride); - rb = _mm512_loadu_epi32(pA + 11*stride); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 11: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_loadu_epi32(pA + 10*stride); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 10: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 9: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 8: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 7: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 6: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 5: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 4: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_setzero_epi32(); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 3: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_setzero_epi32(); - r4 = _mm512_setzero_epi32(); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 2: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_setzero_epi32(); - r3 = _mm512_setzero_epi32(); - r4 = _mm512_setzero_epi32(); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 1: - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_setzero_epi32(); - r2 = _mm512_setzero_epi32(); - r3 = _mm512_setzero_epi32(); - r4 = _mm512_setzero_epi32(); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - } - - transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); - - _mm512_storeu_epi32(dst, r0); - _mm512_storeu_epi32(dst + 16, r1); - _mm512_storeu_epi32(dst + 2*16, r2); - _mm512_storeu_epi32(dst + 3*16, r3); - _mm512_storeu_epi32(dst + 4*16, r4); - _mm512_storeu_epi32(dst + 5*16, r5); - _mm512_storeu_epi32(dst + 6*16, r6); - _mm512_storeu_epi32(dst + 7*16, r7); - _mm512_storeu_epi32(dst + 8*16, r8); - _mm512_storeu_epi32(dst + 9*16, r9); - _mm512_storeu_epi32(dst + 10*16, ra); - _mm512_storeu_epi32(dst + 11*16, rb); - _mm512_storeu_epi32(dst + 12*16, rc); - _mm512_storeu_epi32(dst + 13*16, rd); - _mm512_storeu_epi32(dst + 14*16, re); - _mm512_storeu_epi32(dst + 15*16, rf); - } - - // 16xN, N<=16, non-valid part is on the left, filled with zeros - inline void transpose_epi32_16xN_right_align(void * _dst, const void * src, int stride, int valid_bytes) { - auto * dst = reinterpret_cast(_dst); - __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; - int invalid_bytes = 64 - valid_bytes; - auto * pA = reinterpret_cast(src) - invalid_bytes; - uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; - __mmask64 mask = _cvtu64_mask64(mask_value); - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); - rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); - rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); - rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); - re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); - rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); - transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); - _mm512_storeu_epi32(dst, r0); - _mm512_storeu_epi32(dst + 16, r1); - _mm512_storeu_epi32(dst + 2*16, r2); - _mm512_storeu_epi32(dst + 3*16, r3); - _mm512_storeu_epi32(dst + 4*16, r4); - _mm512_storeu_epi32(dst + 5*16, r5); - _mm512_storeu_epi32(dst + 6*16, r6); - _mm512_storeu_epi32(dst + 7*16, r7); - _mm512_storeu_epi32(dst + 8*16, r8); - _mm512_storeu_epi32(dst + 9*16, r9); - _mm512_storeu_epi32(dst + 10*16, ra); - _mm512_storeu_epi32(dst + 11*16, rb); - _mm512_storeu_epi32(dst + 12*16, rc); - _mm512_storeu_epi32(dst + 13*16, rd); - _mm512_storeu_epi32(dst + 14*16, re); - _mm512_storeu_epi32(dst + 15*16, rf); - } - - inline void transpose_epi32_MxN_right_align(void * _dst, const void * src, int stride, int valid_bytes, int valid_m) { - auto * dst = reinterpret_cast(_dst); - __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; - int invalid_bytes = 64 - valid_bytes; - auto * pA = reinterpret_cast(src) - invalid_bytes; - uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; - __mmask64 mask = _cvtu64_mask64(mask_value); - switch (valid_m) { - case 15: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); - rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); - rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); - rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); - re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); - rf = _mm512_setzero_epi32(); - break; - case 14: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); - rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); - rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); - rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 13: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); - rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); - rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 12: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); - rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 11: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 10: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 9: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 8: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 7: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 6: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 5: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 4: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); - r4 = _mm512_setzero_epi32(); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 3: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); - r3 = _mm512_setzero_epi32(); - r4 = _mm512_setzero_epi32(); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 2: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); - r2 = _mm512_setzero_epi32(); - r3 = _mm512_setzero_epi32(); - r4 = _mm512_setzero_epi32(); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - case 1: - r0 = _mm512_maskz_loadu_epi8 (mask, pA); - r1 = _mm512_setzero_epi32(); - r2 = _mm512_setzero_epi32(); - r3 = _mm512_setzero_epi32(); - r4 = _mm512_setzero_epi32(); - r5 = _mm512_setzero_epi32(); - r6 = _mm512_setzero_epi32(); - r7 = _mm512_setzero_epi32(); - r8 = _mm512_setzero_epi32(); - r9 = _mm512_setzero_epi32(); - ra = _mm512_setzero_epi32(); - rb = _mm512_setzero_epi32(); - rc = _mm512_setzero_epi32(); - rd = _mm512_setzero_epi32(); - re = _mm512_setzero_epi32(); - rf = _mm512_setzero_epi32(); - break; - } - transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); - _mm512_storeu_epi32(dst, r0); - _mm512_storeu_epi32(dst + 16, r1); - _mm512_storeu_epi32(dst + 2*16, r2); - _mm512_storeu_epi32(dst + 3*16, r3); - _mm512_storeu_epi32(dst + 4*16, r4); - _mm512_storeu_epi32(dst + 5*16, r5); - _mm512_storeu_epi32(dst + 6*16, r6); - _mm512_storeu_epi32(dst + 7*16, r7); - _mm512_storeu_epi32(dst + 8*16, r8); - _mm512_storeu_epi32(dst + 9*16, r9); - _mm512_storeu_epi32(dst + 10*16, ra); - _mm512_storeu_epi32(dst + 11*16, rb); - _mm512_storeu_epi32(dst + 12*16, rc); - _mm512_storeu_epi32(dst + 13*16, rd); - _mm512_storeu_epi32(dst + 14*16, re); - _mm512_storeu_epi32(dst + 15*16, rf); - } - - // gelu_erf_minimax_approx_compute_vector_fwd in oneDNN - // x*0.5*(1+erf(x/sqrt(2))) = x*0.5*(1 + x*Polynomial(x^2)) - inline __m512 gelu_erf_minmax_approx(__m512 & x) { - auto x2 = _mm512_mul_ps(x, x); // x^2 - - auto x_positive = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(x), _mm512_set1_epi32(0x7FFFFFFF))); // clear sign mask - auto x_half = _mm512_mul_ps(x, _mm512_set1_ps(0.5f)); - - auto poly = _mm512_castsi512_ps(_mm512_set1_epi32(0x1f1c83fd)); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa3198977))); // poly * x^2 + xxx - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x268a7927))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa998c963))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x2c67ddb2))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xaf013b2c))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x315d4a4f))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb3969b11))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x35a776e9))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb79b0914))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3970b255))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbb1b7399))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3ca3621f))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbe082bc7))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3f4c4228))); - - // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2) - poly = _mm512_fmadd_ps(poly, x, _mm512_set1_ps(1.0f)); - // x*0.5*(1 + x*Polynomial(x^2)) - poly = _mm512_mul_ps(poly, x_half); - - // combine: - // zone_id - // 1 -inf; -saturation_lbound : 0.0f - // 2 -saturation_lbound; -linear_ubound : x*0.5*(1 + x*Polynomial(x^2)) - // 3 -linear_ubound, linear_ubound : x*0.5 - // 4 linear_ubound : saturation_lbound : x*0.5*(1 + x*Polynomial(x^2)) - // 5 saturation_lbound: +inf : x - constexpr int neg_saturation_lbound = 0xc0a00000; - constexpr int linear_ubound = 0x33800000; - constexpr int saturation_lbound = 0x40a00000; - - auto mask_x_not_zone1 = _mm512_cmpnlt_ps_mask(x, _mm512_castsi512_ps(_mm512_set1_epi32(neg_saturation_lbound))); - x = _mm512_maskz_mov_ps(mask_x_not_zone1, x); - - auto mask_x_in_zone5 = _mm512_cmpnle_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(saturation_lbound))); - poly = _mm512_mask_mov_ps(poly, mask_x_in_zone5, x); - - auto mask_x_in_zone3 = _mm512_cmple_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(linear_ubound))); - poly = _mm512_mask_mov_ps(poly, mask_x_in_zone3, x_half); - return poly; - } - - inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows) { - #define FROM_B(i) ((1<<4)|(i)) - static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), - 1,5,FROM_B(1),FROM_B(5), - 2,6,FROM_B(2),FROM_B(6), - 3,7,FROM_B(3),FROM_B(7)}; - auto midx = _mm512_loadu_epi64(idx); - __mmask16 mask = _cvtu32_mask16(0xFFFFu); - const auto * src = reinterpret_cast(_src); - auto * dst0 = reinterpret_cast(_dst0); - auto * dst1 = reinterpret_cast(_dst1); - if (src_rows == 64) { - for (int row = 0; row < 16; row++) { - // each element (a? or b?) is 32-bits, two lanes in each ymm register - auto a256 = _mm256_loadu_epi8(src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 - auto b256 = _mm256_loadu_epi8(src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 - auto c256 = _mm256_loadu_epi8(src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 - auto d256 = _mm256_loadu_epi8(src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 - auto a = _mm512_castsi256_si512(a256); - auto b = _mm512_castsi256_si512(b256); - auto c = _mm512_castsi256_si512(c256); - auto d = _mm512_castsi256_si512(d256); - auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] - auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] - auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] - auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] - auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 - auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - } else { - // less than 64 source lines, - int allzero_dst_rows = (64-src_rows)/4; - int allnonzero_dst_rows = src_rows/4; - // padding zeros at the top - auto rowB0 = _mm512_setzero_si512(); - auto rowB1 = _mm512_setzero_si512(); - for(int i = 0; i < allzero_dst_rows ; i++) { - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - // mixed row - int tails_nz = (src_rows & 3); - if (tails_nz) { - __mmask32 kmask1 = _cvtu32_mask32(0xFFFFFFFF); - auto a256 = _mm256_setzero_si256(); // must be zero - auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 - auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 - auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) - if (tails_nz > 2) { - b256 = _mm256_loadu_epi8 (src); src += stride; - } - if (tails_nz > 1) { - c256 = _mm256_loadu_epi8 (src); src += stride; - } - d256 = _mm256_loadu_epi8 (src); src += stride; - auto a = _mm512_castsi256_si512(a256); - auto b = _mm512_castsi256_si512(b256); - auto c = _mm512_castsi256_si512(c256); - auto d = _mm512_castsi256_si512(d256); - auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] - auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] - auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] - auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] - auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 - auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - // all non zeros - for (int i = 0; i < allnonzero_dst_rows; i++) { - auto a256 = _mm256_loadu_epi8 (src); src += stride; - auto b256 = _mm256_loadu_epi8 (src); src += stride; - auto c256 = _mm256_loadu_epi8 (src); src += stride; - auto d256 = _mm256_loadu_epi8 (src); src += stride; - auto a = _mm512_castsi256_si512(a256); - auto b = _mm512_castsi256_si512(b256); - auto c = _mm512_castsi256_si512(c256); - auto d = _mm512_castsi256_si512(d256); - auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] - auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] - auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] - auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] - auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 - auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - } - } - - inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows) { - static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; - auto midx = _mm512_loadu_epi64(idx); - const auto * src = reinterpret_cast(_src); - auto * dst0 = reinterpret_cast(_dst0); - auto * dst1 = reinterpret_cast(_dst1); - __m512i a,b,rowB0, rowB1; - if (src_rows == 32) { - for (int row = 0; row < 16; row++) { - a = _mm512_loadu_epi16(src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit - b = _mm512_loadu_epi16(src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits - a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] - b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] - rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits - rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - src += 2*stride; - dst0 += 64; - dst1 += 64; - } - } else { - int allzero_dst_rows = (32-src_rows)/2; - int allnonzero_dst_rows = src_rows/2; - - rowB0 = _mm512_setzero_si512(); - rowB1 = _mm512_setzero_si512(); - for(int i = 0; i < allzero_dst_rows ; i++) { - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - // mixed row - if (src_rows & 1) { - a = _mm512_setzero_si512(); - b = _mm512_loadu_epi16(src); src += stride; - a = _mm512_permutexvar_epi64(midx, a); - b = _mm512_permutexvar_epi64(midx, b); - auto rowB0 = _mm512_unpacklo_epi16(a, b); - auto rowB1 = _mm512_unpackhi_epi16(a, b); - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - // all non-zero rows - for (int i = 0; i < allnonzero_dst_rows; i++) { - a = _mm512_loadu_epi16(src); - b = _mm512_loadu_epi16(src + stride); - a = _mm512_permutexvar_epi64(midx, a); - b = _mm512_permutexvar_epi64(midx, b); - rowB0 = _mm512_unpacklo_epi16(a, b); - rowB1 = _mm512_unpackhi_epi16(a, b); - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - src += 2*stride; - dst0 += 64; - dst1 += 64; - } - } - } - - inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows, int valid_n) { - #define FROM_B(i) ((1<<4)|(i)) - static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), - 1,5,FROM_B(1),FROM_B(5), - 2,6,FROM_B(2),FROM_B(6), - 3,7,FROM_B(3),FROM_B(7)}; - auto midx = _mm512_loadu_epi64(idx); - __mmask16 mask = _cvtu32_mask16(0xFFFFu); - const auto * src = reinterpret_cast(_src); - auto * dst0 = reinterpret_cast(_dst0); - auto * dst1 = reinterpret_cast(_dst1); - __mmask32 mask_n = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); - if (src_rows == 64) { - for (int row = 0; row < 16; row++) { - // each element (a? or b?) is 32-bits, two lanes in each ymm register - auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 - auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 - auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 - auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 - auto a = _mm512_castsi256_si512(a256); - auto b = _mm512_castsi256_si512(b256); - auto c = _mm512_castsi256_si512(c256); - auto d = _mm512_castsi256_si512(d256); - auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] - auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] - auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] - auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] - auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 - auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - } else { - // less than 64 source lines, - int allzero_dst_rows = (64-src_rows)/4; - int allnonzero_dst_rows = src_rows/4; - // padding zeros at the top - auto rowB0 = _mm512_setzero_si512(); - auto rowB1 = _mm512_setzero_si512(); - for(int i = 0; i < allzero_dst_rows ; i++) { - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - // mixed row - int tails_nz = (src_rows & 3); - if (tails_nz) { - __mmask32 kmask1 = _cvtu32_mask32(0xFFFFFFFF); - auto a256 = _mm256_setzero_si256(); // must be zero - auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 - auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 - auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) - if (tails_nz > 2) { - b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; - } - if (tails_nz > 1) { - c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; - } - d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; - auto a = _mm512_castsi256_si512(a256); - auto b = _mm512_castsi256_si512(b256); - auto c = _mm512_castsi256_si512(c256); - auto d = _mm512_castsi256_si512(d256); - auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] - auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] - auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] - auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] - auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 - auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - // all non zeros - for (int i = 0; i < allnonzero_dst_rows; i++) { - auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; - auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; - auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; - auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; - auto a = _mm512_castsi256_si512(a256); - auto b = _mm512_castsi256_si512(b256); - auto c = _mm512_castsi256_si512(c256); - auto d = _mm512_castsi256_si512(d256); - auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] - auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] - auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] - auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] - auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 - auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - } - } - - inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows, int valid_n) { - static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; - auto midx = _mm512_loadu_epi64(idx); - const auto * src = reinterpret_cast(_src); - auto * dst0 = reinterpret_cast(_dst0); - auto * dst1 = reinterpret_cast(_dst1); - __mmask32 mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); - __m512i a,b,rowB0, rowB1; - if (src_rows == 32) { - for (int row = 0; row < 16; row++) { - a = _mm512_maskz_loadu_epi16(mask, src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit - b = _mm512_maskz_loadu_epi16(mask, src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits - a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] - b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] - rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits - rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - src += 2*stride; - dst0 += 64; - dst1 += 64; - } - } else { - int allzero_dst_rows = (32-src_rows)/2; - int allnonzero_dst_rows = src_rows/2; - - rowB0 = _mm512_setzero_si512(); - rowB1 = _mm512_setzero_si512(); - for(int i = 0; i < allzero_dst_rows ; i++) { - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - // mixed row - if (src_rows & 1) { - a = _mm512_setzero_si512(); - b = _mm512_maskz_loadu_epi16(mask, src); src += stride; - a = _mm512_permutexvar_epi64(midx, a); - b = _mm512_permutexvar_epi64(midx, b); - auto rowB0 = _mm512_unpacklo_epi16(a, b); - auto rowB1 = _mm512_unpackhi_epi16(a, b); - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - dst0 += 64; - dst1 += 64; - } - // all non-zero rows - for (int i = 0; i < allnonzero_dst_rows; i++) { - a = _mm512_maskz_loadu_epi16(mask, src); - b = _mm512_maskz_loadu_epi16(mask, src + stride); - a = _mm512_permutexvar_epi64(midx, a); - b = _mm512_permutexvar_epi64(midx, b); - rowB0 = _mm512_unpacklo_epi16(a, b); - rowB1 = _mm512_unpackhi_epi16(a, b); - _mm512_storeu_epi16(dst0, rowB0); - _mm512_storeu_epi16(dst1, rowB1); - src += 2*stride; - dst0 += 64; - dst1 += 64; - } - } - } - - // prepare B matrix for C matrix 2x2 blocking (B matrix - // will be accessed in 1x2) - // given 2x2 blocking scheme, Kx32 blocks are always - // accessed sequentially: - // transpose/repack each 32xK ov::bfloat16 submatrix - // into Kx32 slices (each number is a 16x32 bf16-block): - // 0 2 4 6 ... ... - // 1 3 5 7 ... ... - - inline void get_min_max(tensor2D & matB, float& min, float& max) { - int K = matB.dims[0]; - int N = matB.dims[1]; - auto m_max = _mm512_set1_ps(-__FLT_MAX__); - auto m_min = _mm512_set1_ps(__FLT_MAX__); - for (int k = 0; k < K; k++) { - int n = 0; - for (; n < N / 16 * 16; n += 16) { - auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(&matB(k, n))); - a = _mm512_slli_epi32(a, 16); - m_max = _mm512_max_ps((__m512)a, m_max); - m_min = _mm512_min_ps((__m512)a, m_min); - } - if (n != N) { - __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - (N - n))); - auto a = _mm512_cvtepi16_epi32(_mm256_maskz_loadu_epi16(msk, &matB(k, n))); - a = _mm512_slli_epi32(a, 16); - m_max = _mm512_mask_max_ps(m_max, msk, (__m512)a, m_max); - m_min = _mm512_mask_min_ps(m_min, msk, (__m512)a, m_min); - } - } - max = _mm512_reduce_max_ps(m_max); - min = _mm512_reduce_min_ps(m_min); - } - - template - void i8_to_bf16_Kx32(int8_t *&src, ov::bfloat16 *dst) - { - for (int k = 0; k < K; k++) - { - auto a = _mm_load_si128((__m128i *)src); // 16 int8 - auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 - auto a_512 = _mm512_cvtepi8_epi32(a); // 16 int32 - auto b_512 = _mm512_cvtepi8_epi32(b); // 16 int32 - auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps - auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps - auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 - _mm512_store_epi32(dst, (__m512i)reg_out); // - src += 32; // 32 int8_t dequantized into 32 bf16 - dst += 32; - } - } - - inline void bf16_to_i8_tensor(tensor2D& dst, tensor2D& src, float quant_scale) { - dst.resize(src.dims[0], src.dims[1]); - auto scale = _mm512_set1_ps(quant_scale); - for (int k = 0; k < src.dims[0]; k++) { - auto p_src = &src(k, 0); - auto p_dst = &dst(k, 0); - for (int n = 0; n < src.dims[1]; n += 16, p_src += 16, p_dst += 16) { - auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(p_src)); // load packed 16 x bf16 - a = _mm512_slli_epi32(a, 16); // bf16 zero-extend to f32 - auto a_f = _mm512_mul_ps((__m512)a, scale); // scale - a = _mm512_cvtps_epi32(a_f); // ps to dw - auto a_128 = _mm512_cvtsepi32_epi8(a); // saturate convert into int8 - _mm_store_si128((__m128i*)(p_dst), a_128); - } - } - } -}; - -// 2x2 tiles post process kernels - -// 4 tiles located at C matrix (m,n) of size (valid_m, valid_n) -// tC00/tC01 -// tC10/tC11 - -namespace PP { - template - struct is_f32i32 : std::false_type {}; - template<> - struct is_f32i32 : std::true_type {}; - template<> - struct is_f32i32 : std::true_type {}; - - enum Steps { - NONE = 0, - DEQUANT = 1<<0, - BIAS = 1<<1, - GELU = 1<<2, - QUANT = 1<<3, - - BIAS_GELU = BIAS | GELU, - DEQUANT_BIAS_GELU = DEQUANT | BIAS_GELU, - DEQUANT_BIAS_GELU_QUANT = DEQUANT_BIAS_GELU | QUANT, - DEQUANT_BIAS_QUANT = DEQUANT | BIAS | QUANT, - DEQUANT_GELU_QUANT = DEQUANT | GELU | QUANT, - DEQUANT_QUANT = DEQUANT | QUANT, - - DEQUANT_GELU = DEQUANT | GELU, - DEQUANT_BIAS = DEQUANT | BIAS - }; - - template - struct BiasGeluStore { - static_assert(std::is_same::value || std::is_same::value || std::is_same::value, - "BiasGeluStore only support output data types ov::bfloat16/int8_t/float"); - - BiasGeluStore(tensor2D & C, float * bias = nullptr) : C(C), bias(bias) {} - - tensor2D & C; - float * bias; - void set_bias(float * _bias) { - assert (steps & BIAS); - bias = _bias; - } - - float deq_scale_common = 1.0f; - float * deq_scale_per_oc = nullptr; - void set_deq_scale(float scale = 1.0f) { - assert (steps & DEQUANT); - deq_scale_common = scale; - deq_scale_per_oc = nullptr; - } - void set_deq_scale(float * scale_per_oc) { - assert (steps & DEQUANT); - deq_scale_common = 0; - deq_scale_per_oc = scale_per_oc; - } - - float q_scale_common = 0.0f; - float * q_scale_per_oc = nullptr; - void set_q_scale(float scale) { - assert (steps & QUANT); - q_scale_common = scale; - q_scale_per_oc = nullptr; - } - void set_q_scale(float * scale_per_oc) { - assert (steps & QUANT); - q_scale_common = 0; - q_scale_per_oc = scale_per_oc; - } - - // source buffC can be i32 or f32 - template::value, bool>::type = true> - void operator()(tensor2D & buffC, int m, int n, int valid_m, int valid_n) { - auto * psrc = &buffC(0,0); - int8_t * pdst = reinterpret_cast(&(C(m, n))); - int stride = C.stride; - - __m512 bias0, bias1; - if (steps & BIAS) { - bias0 = _mm512_loadu_ps(bias + n); - bias1 = _mm512_loadu_ps(bias + n + 16); - } - - __m512 m512_q_scale0; - __m512 m512_q_scale1; - __m512 m512_deq_scale0; - __m512 m512_deq_scale1; - if (steps & DEQUANT) { - if (deq_scale_per_oc) { - m512_deq_scale0 = _mm512_loadu_ps(deq_scale_per_oc + n); - m512_deq_scale1 = _mm512_loadu_ps(deq_scale_per_oc + n + 16); - } else { - m512_deq_scale0 = _mm512_set1_ps(deq_scale_common); - m512_deq_scale1 = _mm512_set1_ps(deq_scale_common); - } - } - if (steps & QUANT) { - if (q_scale_per_oc) { - m512_q_scale0 = _mm512_loadu_ps(q_scale_per_oc + n); - m512_q_scale1 = _mm512_loadu_ps(q_scale_per_oc + n + 16); - } else { - m512_q_scale0 = _mm512_set1_ps(q_scale_common); - m512_q_scale1 = _mm512_set1_ps(q_scale_common); - } - } - - __mmask32 kall; - __mmask16 k0, k1; - if (std::is_same::value) { - if (valid_n >= 16) { - k0 = _cvtu32_mask16(0xFFFF); - k1 = _cvtu32_mask16(0xFFFF >> (32-valid_n)); - } else { - k0 = _cvtu32_mask16(0xFFFF >> (16-valid_n)); - k1 = _cvtu32_mask16(0); - } - } else { - kall = _cvtu32_mask32(0xFFFFFFFF >> (32-valid_n)); - } - - for(int i = 0; i < valid_m; i ++) { - auto r0 = _mm512_loadu_ps(psrc); - auto r1 = _mm512_loadu_ps(psrc + 16); - if (std::is_same::value) { - r0 = _mm512_cvtepi32_ps(_mm512_castps_si512(r0)); // cvt i32=>f32 - r1 = _mm512_cvtepi32_ps(_mm512_castps_si512(r1)); // cvt i32=>f32 - } - if (steps & DEQUANT) { - r0 = _mm512_mul_ps(r0, m512_deq_scale0); // dequantize - r1 = _mm512_mul_ps(r1, m512_deq_scale1); // dequantize - } - if (steps & BIAS) { - r0 = _mm512_add_ps(r0, bias0); - r1 = _mm512_add_ps(r1, bias1); - } - if (steps & GELU) { - r0 = functional::gelu_erf_minmax_approx(r0); - r1 = functional::gelu_erf_minmax_approx(r1); - } - - // quantize & store - if (steps & QUANT) { - r0 = _mm512_mul_ps(r0, m512_q_scale0); - r1 = _mm512_mul_ps(r1, m512_q_scale1); - } - if (std::is_same::value) { - auto c = _mm512_cvtne2ps_pbh(r1, r0); // convert to bf16 - _mm512_mask_storeu_epi16(pdst, kall, (__m512i)c); // store bf16 - } - if (std::is_same::value) { - auto d0 = _mm512_cvtps_epi32(r0); // convert to dword(i32) - auto d1 = _mm512_cvtps_epi32(r1); // convert to dword(i32) - auto b0 = _mm512_cvtsepi32_epi8 (d0); // dword => int8 with Saturate8 - auto b1 = _mm512_cvtsepi32_epi8 (d1); // dword => int8 with Saturate8 - auto b0b1 = _mm256_inserti32x4(_mm256_castsi128_si256(b0), b1, 1); // combine two int8 xmm into a ymm - _mm256_mask_storeu_epi8(pdst, kall, b0b1); // masked store - } - if (std::is_same::value) { - _mm512_mask_storeu_ps(pdst, k0, r0); // store float - _mm512_mask_storeu_ps(pdst + 64, k1, r1); // store float - } - pdst += stride; - psrc += 32; - } - } - }; -} - -template -void prefetch_bytes(void *src) -{ - int8_t *p = reinterpret_cast(src); - for (int i = 0; i < bytes; i+=64) - _mm_prefetch(p + i + advance, sel); -} - -// matmul (FC) -// -// constB constrols whether it's FC or not -// store precision for weight compression, only for BF16 AMX - -template -tensor2D getSubMatB(tensor2D & _matB, int n0, int n1, bool transposeB) { - int Bd0 = transposeB ? (n1-n0) : _matB.dims[0]; - int Bd1 = transposeB ? _matB.dims[1] : (n1-n0); - T * pbase = transposeB ? (&_matB(n0, 0)):(&_matB(0, n0)); - return tensor2D(Bd0, Bd1, pbase, _matB.stride); -} - -template -void loop2D_no_bM(int M, int N, F f) { - for(int n=0; n -void loop2D(int M, int N, int mc, F f) { - for(int m0=0; m0= bM) -template -void loop2D_opt_Mtail(int M, int N, int mc, F f) { - int tailM = (M % (mc*bM)) % bM; - assert(M > bM); - for(int m0=0; m0 -void repackB_1x2(const tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { - int K = Bi.dims[transpose ? 1 : 0]; - int N = Bi.dims[transpose ? 0 : 1]; - - // K_padded : round up to multiple of 32/64 - int kStep = 64 / sizeof(T); - int K_padded = (K + kStep - 1) / kStep * kStep; - int Ktails = K % kStep; - int Kbody = K - Ktails; - - // N_padded : round up to multiple of (2*16) - int N_unit = 2 * 16; - int N_padded = (N + N_unit - 1) / N_unit * N_unit; - - // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] - Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); - - int n = 0; - int n_tail = N % N_unit; - if (transpose) { - for(; n < N - n_tail; n += N_unit) { - // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially - auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); - auto* src0 = reinterpret_cast(&Bi(n, 0)); - int k; - for(k = 0; k < Kbody; k += kStep) { - // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) - functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); - dst += 1024; - functional::transpose_epi32_16x16(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride); - dst += 1024; - } - if (Ktails) { - // Ktails part is loaded into A tile right-aligned, so B tile must also load - // Ktails part to bottom-aligned, and fill upper padding with zero - functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); - dst += 1024; - functional::transpose_epi32_16xN_right_align(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); - dst += 1024; - } - } - // n_tail: [16, 32) - if (N - n >= 16) { - auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); - auto* src0 = reinterpret_cast(&Bi(n, 0)); - int k; - for(k = 0; k < Kbody; k += kStep) { - // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) - functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); - dst += 1024 * 2; - } - if (Ktails) { - // Ktails part is loaded into A tile right-aligned, so B tile must also load - // Ktails part to bottom-aligned, and fill upper padding with zero - functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T)); - } - n += 16; - } - // n_tail: (0, 16) - if (N - n > 0) { - auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 1024 : 0); - auto* src0 = reinterpret_cast(&Bi(n, 0)); - int k; - for(k = 0; k < Kbody; k += kStep) { - // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) - functional::transpose_epi32_Mx16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, N - n); - dst += 1024 * 2; - } - if (Ktails) { - // Ktails part is loaded into A tile right-aligned, so B tile must also load - // Ktails part to bottom-aligned, and fill upper padding with zero - functional::transpose_epi32_MxN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T), N - n); - } - n = N; - } - // second B tile is untouched, need to set to zero - if (n_tail > 0 && n_tail <= 16) { - auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); - for (int k = 0; k < K_padded; k += kStep) { - memset(dst + 1024, 0, 1024); - dst += 1024 * 2; - } - } - } else { - // pack & layout sequentially - int n = 0; - int n_tail = N % N_unit; - for(; n < N - n_tail; n += N_unit) { - auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); - for(int k = 0; k < K; k += kStep) { - // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 - // int8: B0 B1 64x(16+16) => repack as two 16x16x4 - int src_rows = std::min(K - k, kStep); - functional::kpack_tile_B0B1(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows); - dst += 2048; - } - } - // n_tail: (0, 32) - if (N - n > 0) { - auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); - for(int k = 0; k < K; k += kStep) { - // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 - // int8: B0 B1 64x(16+16) => repack as two 16x16x4 - int src_rows = std::min(K - k, kStep); - functional::kpack_tile_B0B1_ntail(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows, N - n); - dst += 2048; - } - n += 16; - } - } -} - -template -struct acc_type {}; -template<> -struct acc_type { typedef float type; }; -template<> -struct acc_type { typedef int32_t type; }; -template<> -struct acc_type { typedef int32_t type; }; - -template -using acc_type_t = typename acc_type::type; - -// matrix multiply with vector -// C_Nx1 = A_MxK * b_Kx1 - -template> -struct MatmulVector { - MatmulVector() {} - constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; - constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; - constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; - constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; - constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; - constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; - constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; - constexpr static int kStep = is_i8_mode ? 64 : 32; - -#define TILE_DP(dst, a, b) \ - if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ - if (is_s8s8) _tile_dpbssd(dst, a, b); \ - if (is_s8u8) _tile_dpbsud(dst, a, b); \ - if (is_u8s8) _tile_dpbusd(dst, a, b); \ - if (is_u8u8) _tile_dpbuud(dst, a, b); - - alignas(64) int8_t KtailBuff[64]; - - template - void kernel(int M, int K, const void * pA, int strideA, const void * vB, void * vC) { - static_assert(tmmN >= 1 && tmmN <= 6, "tmmN must be within [1-6] range"); - const auto * pA0 = reinterpret_cast(pA); - int KLastOffBytes = (K - kStep) * sizeof(TA); - const auto * pB0 = reinterpret_cast(vB); - auto * pC0 = reinterpret_cast(vC); - - const auto * pBLast = pB0 + 64*(tmmN - 1); - int Ktail = K & (kStep - 1); - if (Ktail) { - if (bFallbackKtails) { - // if bContainMtails, the last submatrix needs to use special to prevent A matrix read overflow - // K tails is handled by: - // - zero-padding the last tile of vector B, at the top - // - right-align last tile load from matA - __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull << (kStep - Ktail)*sizeof(TB)); - auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); - _mm512_storeu_epi8(KtailBuff, r); - } else { - // each row of A can be read overflow w/o worrying NaN numbers - // zero-padding the last tile of vector B as bottom is enough - __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull >> (kStep - Ktail)*sizeof(TB)); - KLastOffBytes = (K - Ktail)*sizeof(TA); - auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); - _mm512_storeu_epi8(KtailBuff, r); - } - pBLast = KtailBuff; - } - - // load B tiles outside of loop - if (tmmN == 1) { - _tile_loadd(2, pB0, 4); - } - if (tmmN == 2) { - _tile_loadd(2, pB0, 4); - _tile_loadd(3, pBLast, 4); - } - if (tmmN == 3) { - _tile_loadd(2, pB0, 4); - _tile_loadd(3, pB0 + 64, 4); - _tile_loadd(4, pBLast, 4); - } - if (tmmN == 4) { - _tile_loadd(2, pB0, 4); - _tile_loadd(3, pB0 + 64, 4); - _tile_loadd(4, pB0 + 64*2, 4); - _tile_loadd(5, pBLast, 4); - } - if (tmmN == 5) { - _tile_loadd(2, pB0, 4); - _tile_loadd(3, pB0 + 64, 4); - _tile_loadd(4, pB0 + 64*2, 4); - _tile_loadd(5, pB0 + 64*3, 4); - _tile_loadd(6, pBLast, 4); - } - if (tmmN == 6) { - _tile_loadd(2, pB0, 4); - _tile_loadd(3, pB0 + 64, 4); - _tile_loadd(4, pB0 + 64*2, 4); - _tile_loadd(5, pB0 + 64*3, 4); - _tile_loadd(6, pB0 + 64*4, 4); - _tile_loadd(7, pBLast, 4); - } - //asm("int3"); - for(int m = 0; m < M; m+=16) { - _tile_zero(0); - if (tmmN == 1) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - } - if (tmmN == 2) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 3); - } - if (tmmN == 3) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); - _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 4); - } - if (tmmN == 4) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); - _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); - _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 5); - } - if (tmmN == 5) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); - _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); - _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); - _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 6); - } - if (tmmN == 6) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); - _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); - _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); - _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); - _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 7); - } - _tile_stored(0, pC0, 4); pC0 += 16*4; // C is single column, always take 4 bytes - pA0 += 16 * strideA; - } - } - - void operator()(tensor2D & matA, const TB * vB, TC * vC) { - int M = matA.dims[0]; - int K = matA.dims[1]; - TA * pA = &matA[0]; - int strideA = matA.stride; - - // M tails is handled - assert(K >= kStep && K <= 6*kStep); - - int Ktail = K & (kStep - 1); - int Mtail = M & (16 - 1); - int Mbody = M - Mtail; - int numBtiles = (K + kStep - 1)/kStep; - - // if we have Ktails, then it will always be handled in Mtail, so we split - // Mtail out even if it's zero - if (Ktail) { - if (Mtail == 0) { - Mtail = 16; - Mbody -= 16; - } - } - - if (Mbody) { - tileconfig_t tfg(1, 0, { - {16, 4}, // C:0 M x 1 (4b) - {16, 64}, // A:1 M x 32/64 (64b) - {16, 4}, // B:2 32/64 x 1 (4b) - {16, 4}, // B:3 - {16, 4}, // B:4 - {16, 4}, // B:5 - {16, 4}, // B:6 - {16, 4}, // B:7 - }); - // Ktail fallback will always be done at Mtails loop - switch(numBtiles) { - case 1: kernel<1, false>(Mbody, K, pA, strideA, vB, vC); break; - case 2: kernel<2, false>(Mbody, K, pA, strideA, vB, vC); break; - case 3: kernel<3, false>(Mbody, K, pA, strideA, vB, vC); break; - case 4: kernel<4, false>(Mbody, K, pA, strideA, vB, vC); break; - case 5: kernel<5, false>(Mbody, K, pA, strideA, vB, vC); break; - case 6: kernel<6, false>(Mbody, K, pA, strideA, vB, vC); break; - default: - assert(false); // impossible since (K <= 6*kStep) - } - } - - if (Mtail) { - pA = &matA(Mbody, 0); - tileconfig_t tfg(1, 0, { - {Mtail, 4}, // C:0 M x 1 (4b) - {Mtail, 64}, // A:1 M x 32/64 (64b) - {16, 4}, // B:2 32/64 x 1 (4b) - {16, 4}, // B:3 - {16, 4}, // B:4 - {16, 4}, // B:5 - {16, 4}, // B:6 - {16, 4}, // B:7 - }); - if (Ktail) { - switch(numBtiles) { - case 1: kernel<1, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 2: kernel<2, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 3: kernel<3, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 4: kernel<4, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 5: kernel<5, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 6: kernel<6, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - default: - assert(false); // impossible since (K <= 6*kStep) - } - } else { - switch(numBtiles) { - case 1: kernel<1, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 2: kernel<2, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 3: kernel<3, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 4: kernel<4, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 5: kernel<5, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - case 6: kernel<6, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; - default: - assert(false); // impossible since (K <= 6*kStep) - } - } - } - } -}; - -template::type> -struct Matmul { - // B matrix is orgnized as tensor2D of shape axb where a=round_up_div(N, 32), b=round_up(K,32/64)*32 - // so b is size of submatrix of Kx32 composed of two columns of B0/B1 tiles. - tensor2D internalB; - - bool constB; - bool transposeB; - - constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; - constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; - constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; - constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; - constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; - constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; - constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; - - // AMX bf16 & int8 has same M(=16) in A,C tile and same N(=16) in B tile - // but only different K(32 vs 64) in A,C & B tiles - constexpr static int kStep = is_i8_mode ? 64 : 32; - - // 2x2 C tiles buffer - // most usecase requires post-processing with AVX, thus buffC - // is used to transfer data to AVX register - tensor2D buffC; - - Matmul(bool constB = false, bool transposeB = false) : - constB(constB), transposeB(transposeB), buffC(32, 32) {} - - // ppkernel is a callable which captures the runtime args - // by itself, so no need to pass in any post-process related - // runtime args through this API - // - // n0/n1 allows us for calculating only partial results, so it - // can be used to run on multi-cores in parallel - // - // ppkernel will be invoked with true (m,n) with n0-offset added - // so ppkernel don't need to know which sub-matrix it's working on. - // - // for most ppkernels w/o runtime state, a single ppkernel can be - // shared among all threads. - // - // but for ppkernels doing reductions, it needs separate instance - // for each thread, also a post-merging process to combine the results. - // - // ppkernels are simple to write, further wrapping or structurelize only - // makes the design more complex, so we stop doing that. - - // I cannot find a way to call TDP intrinsic polymophically using overload or template. - // have to use old-macro-tricks, hopefully these compile-time checks can be optimized - // by compiler. -#define TILE_DP(dst, a, b) \ - if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ - if (is_s8s8) _tile_dpbssd(dst, a, b); \ - if (is_s8u8) _tile_dpbsud(dst, a, b); \ - if (is_u8s8) _tile_dpbusd(dst, a, b); \ - if (is_u8u8) _tile_dpbuud(dst, a, b); - - template - void kernel_slimB(int M, int N, int K, int n0, - tensor2D & A, - void * B, - tensor2D & buffC, - PP ppkernel) { - auto * pB0 = reinterpret_cast(B); - auto * pC0 = &buffC[0]; - int8_t * pA0 = reinterpret_cast(&A[0]); - int strideA = A.stride; - int KlastOffBytes = (K - kStep)* sizeof(TA); - // load B tiles outside of loop - if (tmmN > 0) _tile_loadd(2, pB0, 64); pB0 += 1024*2; - if (tmmN > 1) _tile_loadd(3, pB0, 64); pB0 += 1024*2; - if (tmmN > 2) _tile_loadd(4, pB0, 64); pB0 += 1024*2; - if (tmmN > 3) _tile_loadd(5, pB0, 64); pB0 += 1024*2; - if (tmmN > 4) _tile_loadd(6, pB0, 64); pB0 += 1024*2; - if (tmmN > 5) _tile_loadd(7, pB0, 64); pB0 += 1024*2; - //asm("int3"); - for(int m0 = 0; m0 < M; m0+=16) { - int m = m0; - if (M - m0 < 16) { - // shift up to prevent M-tails - pA0 -= (16 - (M - m0))*A.stride; - m = M - 16; - } - _tile_zero(0); - if (tmmN == 1) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - } - if (tmmN == 2) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 3); - } - if (tmmN == 3) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); - _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 4); - } - if (tmmN == 4) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); - _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); - _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 5); - } - if (tmmN == 5) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); - _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); - _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); - _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 6); - } - if (tmmN == 6) { - _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); - _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); - _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); - _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); - _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); - _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 7); - } - _tile_stored(0, pC0, buffC.stride); - (ppkernel)(buffC, m, n0, 16, N); - pA0 += 16*A.stride; - } - } - - template - void operator()(tensor2D & matA, - tensor2D & _matB, - int n0, int n1, - PP ppkernel, - bool skip_repack = false) { - auto matB = getSubMatB(_matB, n0, n1, transposeB); - int M = matA.dims[0]; - int K = matA.dims[1]; - int N = matB.dims[transposeB ? 0 : 1]; - assert(K == matB.dims[transposeB ? 1 : 0]); - // Due to the fact that we load a full tile at tails of K dimension - // we may access memory address beyond the limit of A matrix - // to avoid read in nan values, we backoff to the left to ensure A tile - // contain valid numbers and no overflow access happens, but it requires K>=kStep; - assert(K >= kStep); - int Ktails = K % kStep; - int Kbody = K - Ktails; - int KbackoffBytes = (kStep - Ktails)*sizeof(TA); - - // for non-constB, internalB is updated every time - // for constB, internalB is updated once - if ((!constB && !skip_repack) || (internalB.capacity == 0)) { - repackB_1x2(matB, transposeB, internalB, constB); - } - - // special case when whole B matrix can fit in 6 tiles - // we can load B only once - if (M >= 16 && N <= 16 && K <= 6*kStep) { - // B is zero-padded - // C:0 - // A:1 - // B:2,3,4,5,6,7 - auto * pB0 = reinterpret_cast(&internalB[0]); - tileconfig_t tfg(1, 0, 8, 16, 64); - switch((K + kStep - 1)/kStep) { - case 1: kernel_slimB<1>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 2: kernel_slimB<2>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 3: kernel_slimB<3>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 4: kernel_slimB<4>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 5: kernel_slimB<5>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 6: kernel_slimB<6>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - default: - assert(false); // impossible since (K <= 6*kStep) - } - return; - } - - if (M <= 16) { - // register/cache blocking scheme is simplified when M <= 16 - // C_MxN: 0,1 - // A_MxK: 2, - // B_KxN: 3, 4 - tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); - auto * pB0 = reinterpret_cast(&internalB[0]); - auto * const pC0 = &buffC[0]; - int k; - const auto strideA = matA.stride; - loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { - _tile_zero(0); - _tile_zero(1); - int8_t * pA0 = reinterpret_cast(&matA[0]); - for(k=0; k(pB0); - _tile_loadd(3, pB0, 64); pB0 += 1024; // tile B0 32x16(16x16x2)/64x16(16x16x4) is always 1KB - // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); - _tile_loadd(4, pB0, 64); pB0 += 1024; // tile B1 32x16(16x16x2)/64x16(16x16x4) is always 1KB - TILE_DP(0, 2, 3); // C0 += A*B0 - TILE_DP(1, 2, 4); // C1 += A*B1 - } - if (Ktails) { - _tile_loadd(2, pA0 - KbackoffBytes, strideA); - // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); - _tile_loadd(3, pB0, 64); pB0 += 1024; - // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); - _tile_loadd(4, pB0, 64); pB0 += 1024; - TILE_DP(0, 2, 3); // C0 += A*B0 - TILE_DP(1, 2, 4); // C1 += A*B1 - } - _tile_stored(0, pC0, buffC.stride); - _tile_stored(1, pC0 + 16, buffC.stride); - //int valid_n = std::min(N - n, 32); - (ppkernel)(buffC, 0, n + n0, M, valid_n); - }); - return; - } - - auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { - auto * pA0 = reinterpret_cast(&matA(m, 0)); - auto * pA1 = reinterpret_cast(&matA(m + 16, 0)); - auto strideA = matA.stride; - auto * pB = reinterpret_cast(&internalB(n>>5, 0)); - _tile_zero(0); - _tile_zero(1); - _tile_zero(2); - _tile_zero(3); - // 2x2 - for (int k = 0; k < Kbody; k += kStep) { - _tile_loadd(4, pA0, strideA); pA0 += 64; - _tile_loadd(6, pB, 64); pB += 1024; - // prefetch_bytes<1024>(pB); - TILE_DP(0, 4, 6); - - _tile_loadd(5, pA1, strideA); pA1 += 64; - TILE_DP(2, 5, 6); - _tile_loadd(7, pB, 64); pB += 1024; - // prefetch_bytes<1024>(pB); - TILE_DP(1, 4, 7); - - TILE_DP(3, 5, 7); - } - if (Ktails) { - _tile_loadd(4, pA0 - KbackoffBytes, strideA); - _tile_loadd(6, pB, 64); pB += 1024; - // prefetch_bytes<1024>(pB); - TILE_DP(0, 4, 6); - - _tile_loadd(5, pA1 - KbackoffBytes, strideA); - TILE_DP(2, 5, 6); - _tile_loadd(7, pB, 64); pB += 1024; - // prefetch_bytes<1024>(pB); - TILE_DP(1, 4, 7); - - TILE_DP(3, 5, 7); - } - _tile_stored(0, &buffC(0,0), buffC.stride); - _tile_stored(1, &buffC(0,16), buffC.stride); - _tile_stored(2, &buffC(16,0), buffC.stride); - _tile_stored(3, &buffC(16,16), buffC.stride); - (ppkernel)(buffC, m, n + n0, valid_m, valid_n); - }; - - if (M <= 32 && M >16) { - // 2x2 tile, C:0/1/2/3 A:4/5 B:6/7 no blocking along M dimension - tileconfig_t tfg(1, 0, {16,16,M-16,M-16,16,M-16,16,16}, 64); - loop2D_no_bM<32>(M, N, kernel_2x2); - return; - } - - // generic input shapes with M > 32 - // determine cache blocking scheme - int elesz = sizeof(TA); - int L2 = 2048*1024; // 2MB - int slice_size = 32*rndup(K, 32)*elesz; - int mc = std::max(1, L2/slice_size - 1); - - // M > bM - tileconfig_t tfg(1, 0, 8, 16, 64); - loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); - } -}; - -// specialization: -// TA is ov::bfloat16 and TB is int8_t, decompressed on the fly into ov::bfloat16 by simply convert -template<> -struct Matmul { - tensor2D internalBI8; - - // wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly. - tensor2D weiBuff; - - bool constB; - bool transposeB; - - constexpr static int kStep = 32; - - // 2x2 C tiles buffer - // most usecase requires post-processing with AVX, thus buffC - // is used to transfer data to AVX register - tensor2D buffC; - - Matmul(bool constB = false, bool transposeB = false) : - constB(constB), transposeB(transposeB), buffC(32, 32) {} - - float quant_scale_B; - float dequant_scale_B; - - template - void operator()(tensor2D & matA, - tensor2D & _matB, - int n0, int n1, - PP ppkernel) { - auto matB = getSubMatB(_matB, n0, n1, transposeB); - int M = matA.dims[0]; - int K = matA.dims[1]; - int N = matB.dims[transposeB ? 0 : 1]; - assert(K == matB.dims[transposeB ? 1 : 0]); - // Due to the fact that we load a full tile at tails of K dimension - // we may access memory address beyond the limit of A matrix - // to avoid read in nan values, we backoff to the left to ensure A tile - // contain valid numbers and no overflow access happens, but it requires K>=kStep; - assert(K >= kStep); - int Ktails = K % kStep; - int Kbody = K - Ktails; - int Kbackoff = (kStep - Ktails); - - // for non-constB, internalB is updated every time - // for constB, internalB is updated once - if (!constB || (internalBI8.capacity == 0)) { - // this dynamic quantization of weight matrix using minmax - // is time-consuming, should be used only for constB - if (!constB) { - std::cout << "\t WANING: dynamic quantization of weight matrix for non-constB is time-consuming " << std::endl; - } - // float min, max; - // functional::get_min_max(_matB, min, max); - // max = std::max(std::abs(max), std::abs(min)); - // quant_scale_B = 127 / max; - // dequant_scale_B = max / 127; - - tensor2D internalTmpB; - repackB_1x2(matB, transposeB, internalTmpB, constB); - functional::bf16_to_i8_tensor(internalBI8, internalTmpB, quant_scale_B); - } - - ppkernel.set_deq_scale(dequant_scale_B); - - if (M <= 16) { - // C:0/1 A:2 B:3/4 - // dequantize scale is moved into ppkernel - constexpr int prefetch_ahead = 64*1024; - tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); - auto * pBint = reinterpret_cast(&internalBI8[0]); - auto & B2buff = weiBuff; - B2buff.resize(32*2, 32); - auto * const pB = &B2buff[0]; - auto * pBsrc = pB + (32*32) * 0; - auto * pBdst = pB + (32*32) * 1; - functional::i8_to_bf16_Kx32<32>(pBint, pBsrc); - - auto * const pC0 = &buffC[0]; - const auto strideA = matA.stride; - loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { - // C:Mx32 = A:Mx32 x B:32x32 - _tile_zero(0); - _tile_zero(1); - auto * pA0 = &matA[0]; - for(int k=0; k(pBint); - - functional::i8_to_bf16_Kx32<8>(pBint, pBdst); - _tile_loadd(3, pBsrc, 64); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); - _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 - - prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); - _tile_loadd(4, pBsrc + 16*32, 64); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); - _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 - std::swap(pBsrc, pBdst); - } - if (Ktails) { - _tile_loadd(2, pA0 - Kbackoff, strideA); // backoff to prevent access beyond the end of A - prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); - - functional::i8_to_bf16_Kx32<8>(pBint, pBdst); - _tile_loadd(3, pBsrc, 64); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); - _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 - - prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); - _tile_loadd(4, pBsrc + 16*32, 64); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); - _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 - std::swap(pBsrc, pBdst); - } - //prefetch_bytes<2048, _MM_HINT_T1, prefetch_ahead>(pBint); - _tile_stored(0, pC0, buffC.stride); - _tile_stored(1, pC0 + 16, buffC.stride); - //prefetch_bytes<2048, _MM_HINT_T1, prefetch_ahead>(pBint + 2048); - //int valid_n = std::min(N - n, 32); - (ppkernel)(buffC, 0, n + n0, M, valid_n); - }); - return; - } - - // 4 tiles buffC is reused as decompressed bf16 weights - constexpr int prefetch_ahead = 16*1024; - ov::bfloat16 * pBa = reinterpret_cast(&buffC(0,0)); - ov::bfloat16 * pBb = pBa + (16*32)*2; - auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { - auto strideA = matA.stride; - auto * pA0 = &matA(m, 0); - auto * pA1 = &matA(m + 16, 0); - auto * pBint = reinterpret_cast(&internalBI8(n>>5, 0)); - functional::i8_to_bf16_Kx32<32>(pBint, pBb); - - _tile_zero(0); - _tile_zero(1); - _tile_zero(2); - _tile_zero(3); - int k; - for (k = 0; k < Kbody; k += kStep) { - functional::i8_to_bf16_Kx32<16>(pBint, pBa); - - _tile_loadd(4, pA0 + k, strideA); - _tile_loadd(6, pBb, 64); - _tile_dpbf16ps(0, 4, 6); - - _tile_loadd(5, pA1 + k, strideA); - _tile_dpbf16ps(2, 5, 6); - - functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); - - _tile_loadd(7, pBb + 16*32, 64); - _tile_dpbf16ps(1, 4, 7); - _tile_dpbf16ps(3, 5, 7); - - std::swap(pBa, pBb); - } - if (Ktails) { - functional::i8_to_bf16_Kx32<16>(pBint, pBa); - - _tile_loadd(4, pA0 + k - Kbackoff, strideA); - _tile_loadd(6, pBb, 64); - _tile_dpbf16ps(0, 4, 6); - - _tile_loadd(5, pA1 + k - Kbackoff, strideA); - _tile_dpbf16ps(2, 5, 6); - - functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); - - _tile_loadd(7, pBb + 16*32, 64); - _tile_dpbf16ps(1, 4, 7); - _tile_dpbf16ps(3, 5, 7); - - std::swap(pBa, pBb); - } - _tile_stored(0, &buffC(0,0), buffC.stride); - _tile_stored(1, &buffC(0,16), buffC.stride); - _tile_stored(2, &buffC(16,0), buffC.stride); - _tile_stored(3, &buffC(16,16), buffC.stride); - (ppkernel)(buffC, m, n + n0, valid_m, valid_n); - }; - - if (M <= 32 && M > 16) { - // 2x2 C:0/1/2/3 A:4/5 B:6/7 - tileconfig_t tfg(1, 0, {16, 16, M-16, M-16, 16, M-16, 16, 16}, 64); - loop2D_no_bM<32>(M, N, kernel_2x2); - return; - } - - // determine blocking scheme - int elesz = sizeof(uint16_t); - int L2 = 2048*1024; // 2MB - int slice_size = 32*rndup(K, 32)*elesz; - int mc = std::max(1, L2/slice_size - 1); // if 1 32xK slice cannot fit L2, use 1 slice at least - - // main loop - tileconfig_t tfg(1, 0, 8, 16, 64); - loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); - } -}; - -//https://stackoverflow.com/questions/29519222/how-to-transpose-a-16x16-matrix-using-simd-instructions -// vector multiply with matrix: -// mAvB: A(M, K) * B(K, 1) => C(M, 1) -// vAmB: A(1, K) * B(K, N) => C(1, N) -// -// in mAvB form, block of A (16x32) is transposed in register -// in unit of 2 packed bf16, and then vdpbf16ps was used -// to multiply with broadcasted B (2x1) and accumulate into C (16x1) -// -// B is pre-broadcasted in unit of 2 -// -struct GemAvB { - tensor2D Bpadded; - GemAvB() { - } - - void operator()(tensor2D & matA, - ov::bfloat16 * vecB, - float * vecC) { - int M = matA.dims[0]; - int K = matA.dims[1]; - - constexpr int kStep = 32; - - assert(K >= 32); - int Ktails = K % kStep; - int Kbody = K - Ktails; - int Kbackoff = (kStep - Ktails); - - if (K % 32) { - if (K > Bpadded.dims[1]) - Bpadded.resize(1, rndup(K, 32)); - auto newB = &Bpadded(0, 0); - memset(newB, 0, Bpadded.stride); - memcpy(newB, vecB, K * sizeof(ov::bfloat16)); - vecB = newB; - } - - for(int m = 0; m < M; m += 16) { - auto * pA = reinterpret_cast(&matA(m, 0)); - auto * pBi32 = reinterpret_cast(vecB); - __m512 regC0 = _mm512_setzero(); - __m512 regC1 = _mm512_setzero(); - __mmask16 kmask = _cvtu32_mask16(0xFFFF); - if (M-m < 16) { - kmask = _cvtu32_mask16(0xFFFF >> (16-(M-m))); - } - for(int k = 0; k < K; k += 32, pA += 64, pBi32 += 16) { - // handle Ab: 16x32 - // transposed in register as 16x16x2 - // r0: (a0,a1)(b0,b1).... - // r1: (a2,a3)(b2,b3).... - // ... - // rf: (a30,a31),(b30,b31).... - // - __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; - auto stride = matA.stride; - r0 = _mm512_loadu_epi32(pA); - r1 = _mm512_loadu_epi32(pA + stride); - r2 = _mm512_loadu_epi32(pA + 2*stride); - r3 = _mm512_loadu_epi32(pA + 3*stride); - r4 = _mm512_loadu_epi32(pA + 4*stride); - r5 = _mm512_loadu_epi32(pA + 5*stride); - r6 = _mm512_loadu_epi32(pA + 6*stride); - r7 = _mm512_loadu_epi32(pA + 7*stride); - r8 = _mm512_loadu_epi32(pA + 8*stride); - r9 = _mm512_loadu_epi32(pA + 9*stride); - ra = _mm512_loadu_epi32(pA + 10*stride); - rb = _mm512_loadu_epi32(pA + 11*stride); - rc = _mm512_loadu_epi32(pA + 12*stride); - rd = _mm512_loadu_epi32(pA + 13*stride); - re = _mm512_loadu_epi32(pA + 14*stride); - rf = _mm512_loadu_epi32(pA + 15*stride); - - functional::transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); - - // vdpbf16ps - regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r0), (__m512bh)(_mm512_set1_epi32(pBi32[0]))); - regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r1), (__m512bh)(_mm512_set1_epi32(pBi32[1]))); - regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r2), (__m512bh)(_mm512_set1_epi32(pBi32[2]))); - regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r3), (__m512bh)(_mm512_set1_epi32(pBi32[3]))); - regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r4), (__m512bh)(_mm512_set1_epi32(pBi32[4]))); - regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r5), (__m512bh)(_mm512_set1_epi32(pBi32[5]))); - regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r6), (__m512bh)(_mm512_set1_epi32(pBi32[6]))); - regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r7), (__m512bh)(_mm512_set1_epi32(pBi32[7]))); - regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r8), (__m512bh)(_mm512_set1_epi32(pBi32[8]))); - regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r9), (__m512bh)(_mm512_set1_epi32(pBi32[9]))); - regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(ra), (__m512bh)(_mm512_set1_epi32(pBi32[10]))); - regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rb), (__m512bh)(_mm512_set1_epi32(pBi32[11]))); - regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(rc), (__m512bh)(_mm512_set1_epi32(pBi32[12]))); - regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rd), (__m512bh)(_mm512_set1_epi32(pBi32[13]))); - regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(re), (__m512bh)(_mm512_set1_epi32(pBi32[14]))); - regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rf), (__m512bh)(_mm512_set1_epi32(pBi32[15]))); - } - regC0 = _mm512_add_ps(regC0, regC1); - _mm512_mask_storeu_ps (vecC + m, kmask, regC0); - //auto regOut = _mm512_cvtne2ps_pbh(regC0, regC0); // only 16 ov::bfloat16 results in lower 256bits - //_mm256_storeu_si256(reinterpret_cast<__m256i_u *>(vecC + m), _mm512_extracti64x4_epi64(regOut, 0)); - } - } -}; - -} // namespace amx - -inline std::ostream & operator<<(std::ostream & os, const amx_kernel::PP::Steps & steps) { - os << "amx_kernel::PP::Steps::"; - if (steps == amx_kernel::PP::Steps::NONE) - os << "NONE"; - if (steps & amx_kernel::PP::Steps::DEQUANT) - os << "_DEQUANT"; - if (steps & amx_kernel::PP::Steps::BIAS) - os << "_BIAS"; - if (steps & amx_kernel::PP::Steps::GELU) - os << "_GELU"; - if (steps & amx_kernel::PP::Steps::QUANT) - os << "_QUANT"; - return os; -} +} \ No newline at end of file diff --git a/src/mm_kernel_api.cpp b/src/mm_kernel_api.cpp new file mode 100644 index 0000000..9b2ebff --- /dev/null +++ b/src/mm_kernel_api.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "llm_mm.hpp" +#include "mm_kernel_amx.hpp" + +namespace llmdnn { + +static decltype(&mm_kernel_create) mm_kernel_create_ptr = mm_kernel_create_amx; +static decltype(&mm_kernel_destroy) mm_kernel_destroy_ptr = mm_kernel_destroy_amx; +static decltype(&mm_kernel_execute) mm_kernel_execute_ptr = mm_kernel_execute_amx; + +// interface +bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param) { + return mm_kernel_create_ptr(mm, param); +} + +void mm_kernel_destroy(const mm_kernel* mm) { + mm_kernel_destroy_ptr(mm); +} + +void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K) { + mm_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, lda, ldb, ldc, M, N, K); +} + +} \ No newline at end of file diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp new file mode 100644 index 0000000..d07d2c7 --- /dev/null +++ b/src/mm_kernel_common_amx.hpp @@ -0,0 +1,2455 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "common/bf16.hpp" +#include "common/tensor2d.hpp" +#include "utility_amx.hpp" + +#ifdef _WIN32 +#include +#else +#include +#endif + +#ifdef ENABLE_NUMA +#include "numa.h" +#endif + +namespace amx_kernel { + +namespace functional { + + inline void transpose_m512i_16x16(__m512i &r0, __m512i &r1, __m512i &r2, __m512i &r3, + __m512i &r4, __m512i &r5, __m512i &r6, __m512i &r7, + __m512i &r8, __m512i &r9, __m512i &ra, __m512i &rb, + __m512i &rc, __m512i &rd, __m512i &re, __m512i &rf) { + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + + t0 = _mm512_unpacklo_epi32(r0,r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 + t1 = _mm512_unpackhi_epi32(r0,r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 + t2 = _mm512_unpacklo_epi32(r2,r3); // 32 48 33 49 ... + t3 = _mm512_unpackhi_epi32(r2,r3); // 34 50 35 51 ... + t4 = _mm512_unpacklo_epi32(r4,r5); // 64 80 65 81 ... + t5 = _mm512_unpackhi_epi32(r4,r5); // 66 82 67 83 ... + t6 = _mm512_unpacklo_epi32(r6,r7); // 96 112 97 113 ... + t7 = _mm512_unpackhi_epi32(r6,r7); // 98 114 99 115 ... + t8 = _mm512_unpacklo_epi32(r8,r9); // 128 ... + t9 = _mm512_unpackhi_epi32(r8,r9); // 130 ... + ta = _mm512_unpacklo_epi32(ra,rb); // 160 ... + tb = _mm512_unpackhi_epi32(ra,rb); // 162 ... + tc = _mm512_unpacklo_epi32(rc,rd); // 196 ... + td = _mm512_unpackhi_epi32(rc,rd); // 198 ... + te = _mm512_unpacklo_epi32(re,rf); // 228 ... + tf = _mm512_unpackhi_epi32(re,rf); // 230 ... + + r0 = _mm512_unpacklo_epi64(t0,t2); // 0 16 32 48 ... + r1 = _mm512_unpackhi_epi64(t0,t2); // 1 17 33 49 ... + r2 = _mm512_unpacklo_epi64(t1,t3); // 2 18 34 49 ... + r3 = _mm512_unpackhi_epi64(t1,t3); // 3 19 35 51 ... + r4 = _mm512_unpacklo_epi64(t4,t6); // 64 80 96 112 ... + r5 = _mm512_unpackhi_epi64(t4,t6); // 65 81 97 114 ... + r6 = _mm512_unpacklo_epi64(t5,t7); // 66 82 98 113 ... + r7 = _mm512_unpackhi_epi64(t5,t7); // 67 83 99 115 ... + r8 = _mm512_unpacklo_epi64(t8,ta); // 128 144 160 176 ... + r9 = _mm512_unpackhi_epi64(t8,ta); // 129 145 161 178 ... + ra = _mm512_unpacklo_epi64(t9,tb); // 130 146 162 177 ... + rb = _mm512_unpackhi_epi64(t9,tb); // 131 147 163 179 ... + rc = _mm512_unpacklo_epi64(tc,te); // 192 208 228 240 ... + rd = _mm512_unpackhi_epi64(tc,te); // 193 209 229 241 ... + re = _mm512_unpacklo_epi64(td,tf); // 194 210 230 242 ... + rf = _mm512_unpackhi_epi64(td,tf); // 195 211 231 243 ... + + t0 = _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... + t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... + t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... + t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... + t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... + t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... + t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... + t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... + t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... + t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... + ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... + tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... + tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... + td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... + te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... + tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... + + r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 + r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 + r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 + r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 + r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... + r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... + r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... + r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... + r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... + r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... + ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... + rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... + rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... + rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... + re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... + rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 + } + + inline void transpose_epi32_16x16(void * _dst, const void * src, int stride) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_loadu_epi32(pA + 15*stride); + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // 16xN, N<=16, non-valid part is filled with zeros + inline void transpose_epi32_16xN(void * _dst, const void * src, int stride, int valid_bytes) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull >> (64 - valid_bytes); + __mmask64 mask = _cvtu64_mask64(mask_value); + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + inline void transpose_epi32_Mx16(void * _dst, const void * src, int stride, int valid_m) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + switch (valid_m) { + case 15: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_setzero_epi32(); + break; + case 14: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 13: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 12: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 11: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 10: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 9: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 8: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 7: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 6: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 5: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 4: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 3: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 2: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 1: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_setzero_epi32(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + } + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // 16xN, N<=16, non-valid part is on the left, filled with zeros + inline void transpose_epi32_16xN_right_align(void * _dst, const void * src, int stride, int valid_bytes) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + int invalid_bytes = 64 - valid_bytes; + auto * pA = reinterpret_cast(src) - invalid_bytes; + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; + __mmask64 mask = _cvtu64_mask64(mask_value); + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + inline void transpose_epi32_MxN_right_align(void * _dst, const void * src, int stride, int valid_bytes, int valid_m) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + int invalid_bytes = 64 - valid_bytes; + auto * pA = reinterpret_cast(src) - invalid_bytes; + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; + __mmask64 mask = _cvtu64_mask64(mask_value); + switch (valid_m) { + case 15: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_setzero_epi32(); + break; + case 14: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 13: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 12: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 11: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 10: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 9: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 8: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 7: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 6: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 5: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 4: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 3: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 2: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 1: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_setzero_epi32(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + } + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // gelu_erf_minimax_approx_compute_vector_fwd in oneDNN + // x*0.5*(1+erf(x/sqrt(2))) = x*0.5*(1 + x*Polynomial(x^2)) + inline __m512 gelu_erf_minmax_approx(__m512 & x) { + auto x2 = _mm512_mul_ps(x, x); // x^2 + + auto x_positive = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(x), _mm512_set1_epi32(0x7FFFFFFF))); // clear sign mask + auto x_half = _mm512_mul_ps(x, _mm512_set1_ps(0.5f)); + + auto poly = _mm512_castsi512_ps(_mm512_set1_epi32(0x1f1c83fd)); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa3198977))); // poly * x^2 + xxx + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x268a7927))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa998c963))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x2c67ddb2))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xaf013b2c))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x315d4a4f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb3969b11))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x35a776e9))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb79b0914))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3970b255))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbb1b7399))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3ca3621f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbe082bc7))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3f4c4228))); + + // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2) + poly = _mm512_fmadd_ps(poly, x, _mm512_set1_ps(1.0f)); + // x*0.5*(1 + x*Polynomial(x^2)) + poly = _mm512_mul_ps(poly, x_half); + + // combine: + // zone_id + // 1 -inf; -saturation_lbound : 0.0f + // 2 -saturation_lbound; -linear_ubound : x*0.5*(1 + x*Polynomial(x^2)) + // 3 -linear_ubound, linear_ubound : x*0.5 + // 4 linear_ubound : saturation_lbound : x*0.5*(1 + x*Polynomial(x^2)) + // 5 saturation_lbound: +inf : x + constexpr int neg_saturation_lbound = 0xc0a00000; + constexpr int linear_ubound = 0x33800000; + constexpr int saturation_lbound = 0x40a00000; + + auto mask_x_not_zone1 = _mm512_cmpnlt_ps_mask(x, _mm512_castsi512_ps(_mm512_set1_epi32(neg_saturation_lbound))); + x = _mm512_maskz_mov_ps(mask_x_not_zone1, x); + + auto mask_x_in_zone5 = _mm512_cmpnle_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(saturation_lbound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone5, x); + + auto mask_x_in_zone3 = _mm512_cmple_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(linear_ubound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone3, x_half); + return poly; + } + + inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows) { + #define FROM_B(i) ((1<<4)|(i)) + static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), + 1,5,FROM_B(1),FROM_B(5), + 2,6,FROM_B(2),FROM_B(6), + 3,7,FROM_B(3),FROM_B(7)}; + auto midx = _mm512_loadu_epi64(idx); + __mmask16 mask = _cvtu32_mask16(0xFFFFu); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + if (src_rows == 64) { + for (int row = 0; row < 16; row++) { + // each element (a? or b?) is 32-bits, two lanes in each ymm register + auto a256 = _mm256_loadu_epi8(src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 + auto b256 = _mm256_loadu_epi8(src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 + auto c256 = _mm256_loadu_epi8(src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 + auto d256 = _mm256_loadu_epi8(src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } else { + // less than 64 source lines, + int allzero_dst_rows = (64-src_rows)/4; + int allnonzero_dst_rows = src_rows/4; + // padding zeros at the top + auto rowB0 = _mm512_setzero_si512(); + auto rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + int tails_nz = (src_rows & 3); + if (tails_nz) { + __mmask32 kmask1 = _cvtu32_mask32(0xFFFFFFFF); + auto a256 = _mm256_setzero_si256(); // must be zero + auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 + auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 + auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) + if (tails_nz > 2) { + b256 = _mm256_loadu_epi8 (src); src += stride; + } + if (tails_nz > 1) { + c256 = _mm256_loadu_epi8 (src); src += stride; + } + d256 = _mm256_loadu_epi8 (src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non zeros + for (int i = 0; i < allnonzero_dst_rows; i++) { + auto a256 = _mm256_loadu_epi8 (src); src += stride; + auto b256 = _mm256_loadu_epi8 (src); src += stride; + auto c256 = _mm256_loadu_epi8 (src); src += stride; + auto d256 = _mm256_loadu_epi8 (src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows) { + static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; + auto midx = _mm512_loadu_epi64(idx); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __m512i a,b,rowB0, rowB1; + if (src_rows == 32) { + for (int row = 0; row < 16; row++) { + a = _mm512_loadu_epi16(src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + b = _mm512_loadu_epi16(src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } else { + int allzero_dst_rows = (32-src_rows)/2; + int allnonzero_dst_rows = src_rows/2; + + rowB0 = _mm512_setzero_si512(); + rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + if (src_rows & 1) { + a = _mm512_setzero_si512(); + b = _mm512_loadu_epi16(src); src += stride; + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + auto rowB0 = _mm512_unpacklo_epi16(a, b); + auto rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non-zero rows + for (int i = 0; i < allnonzero_dst_rows; i++) { + a = _mm512_loadu_epi16(src); + b = _mm512_loadu_epi16(src + stride); + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + rowB0 = _mm512_unpacklo_epi16(a, b); + rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows, int valid_n) { + #define FROM_B(i) ((1<<4)|(i)) + static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), + 1,5,FROM_B(1),FROM_B(5), + 2,6,FROM_B(2),FROM_B(6), + 3,7,FROM_B(3),FROM_B(7)}; + auto midx = _mm512_loadu_epi64(idx); + __mmask16 mask = _cvtu32_mask16(0xFFFFu); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __mmask32 mask_n = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); + if (src_rows == 64) { + for (int row = 0; row < 16; row++) { + // each element (a? or b?) is 32-bits, two lanes in each ymm register + auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 + auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 + auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 + auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } else { + // less than 64 source lines, + int allzero_dst_rows = (64-src_rows)/4; + int allnonzero_dst_rows = src_rows/4; + // padding zeros at the top + auto rowB0 = _mm512_setzero_si512(); + auto rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + int tails_nz = (src_rows & 3); + if (tails_nz) { + __mmask32 kmask1 = _cvtu32_mask32(0xFFFFFFFF); + auto a256 = _mm256_setzero_si256(); // must be zero + auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 + auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 + auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) + if (tails_nz > 2) { + b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + } + if (tails_nz > 1) { + c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + } + d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non zeros + for (int i = 0; i < allnonzero_dst_rows; i++) { + auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows, int valid_n) { + static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; + auto midx = _mm512_loadu_epi64(idx); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __mmask32 mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); + __m512i a,b,rowB0, rowB1; + if (src_rows == 32) { + for (int row = 0; row < 16; row++) { + a = _mm512_maskz_loadu_epi16(mask, src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + b = _mm512_maskz_loadu_epi16(mask, src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } else { + int allzero_dst_rows = (32-src_rows)/2; + int allnonzero_dst_rows = src_rows/2; + + rowB0 = _mm512_setzero_si512(); + rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + if (src_rows & 1) { + a = _mm512_setzero_si512(); + b = _mm512_maskz_loadu_epi16(mask, src); src += stride; + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + auto rowB0 = _mm512_unpacklo_epi16(a, b); + auto rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non-zero rows + for (int i = 0; i < allnonzero_dst_rows; i++) { + a = _mm512_maskz_loadu_epi16(mask, src); + b = _mm512_maskz_loadu_epi16(mask, src + stride); + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + rowB0 = _mm512_unpacklo_epi16(a, b); + rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } + } + + // prepare B matrix for C matrix 2x2 blocking (B matrix + // will be accessed in 1x2) + // given 2x2 blocking scheme, Kx32 blocks are always + // accessed sequentially: + // transpose/repack each 32xK ov::bfloat16 submatrix + // into Kx32 slices (each number is a 16x32 bf16-block): + // 0 2 4 6 ... ... + // 1 3 5 7 ... ... + + inline void get_min_max(tensor2D & matB, float& min, float& max) { + int K = matB.dims[0]; + int N = matB.dims[1]; + auto m_max = _mm512_set1_ps(-__FLT_MAX__); + auto m_min = _mm512_set1_ps(__FLT_MAX__); + for (int k = 0; k < K; k++) { + int n = 0; + for (; n < N / 16 * 16; n += 16) { + auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(&matB(k, n))); + a = _mm512_slli_epi32(a, 16); + m_max = _mm512_max_ps((__m512)a, m_max); + m_min = _mm512_min_ps((__m512)a, m_min); + } + if (n != N) { + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - (N - n))); + auto a = _mm512_cvtepi16_epi32(_mm256_maskz_loadu_epi16(msk, &matB(k, n))); + a = _mm512_slli_epi32(a, 16); + m_max = _mm512_mask_max_ps(m_max, msk, (__m512)a, m_max); + m_min = _mm512_mask_min_ps(m_min, msk, (__m512)a, m_min); + } + } + max = _mm512_reduce_max_ps(m_max); + min = _mm512_reduce_min_ps(m_min); + } + + template + void i8_to_bf16_Kx32(int8_t *&src, ov::bfloat16 *dst) + { + for (int k = 0; k < K; k++) + { + auto a = _mm_load_si128((__m128i *)src); // 16 int8 + auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 + auto a_512 = _mm512_cvtepi8_epi32(a); // 16 int32 + auto b_512 = _mm512_cvtepi8_epi32(b); // 16 int32 + auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps + auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps + auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 + _mm512_store_epi32(dst, (__m512i)reg_out); // + src += 32; // 32 int8_t dequantized into 32 bf16 + dst += 32; + } + } + + inline void bf16_to_i8_tensor(tensor2D& dst, tensor2D& src, float quant_scale) { + dst.resize(src.dims[0], src.dims[1]); + auto scale = _mm512_set1_ps(quant_scale); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + for (int n = 0; n < src.dims[1]; n += 16, p_src += 16, p_dst += 16) { + auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(p_src)); // load packed 16 x bf16 + a = _mm512_slli_epi32(a, 16); // bf16 zero-extend to f32 + auto a_f = _mm512_mul_ps((__m512)a, scale); // scale + a = _mm512_cvtps_epi32(a_f); // ps to dw + auto a_128 = _mm512_cvtsepi32_epi8(a); // saturate convert into int8 + _mm_store_si128((__m128i*)(p_dst), a_128); + } + } + } +}; + +// 2x2 tiles post process kernels + +// 4 tiles located at C matrix (m,n) of size (valid_m, valid_n) +// tC00/tC01 +// tC10/tC11 + +namespace PP { + template + struct is_f32i32 : std::false_type {}; + template<> + struct is_f32i32 : std::true_type {}; + template<> + struct is_f32i32 : std::true_type {}; + + enum Steps { + NONE = 0, + DEQUANT = 1<<0, + BIAS = 1<<1, + GELU = 1<<2, + QUANT = 1<<3, + + BIAS_GELU = BIAS | GELU, + DEQUANT_BIAS_GELU = DEQUANT | BIAS_GELU, + DEQUANT_BIAS_GELU_QUANT = DEQUANT_BIAS_GELU | QUANT, + DEQUANT_BIAS_QUANT = DEQUANT | BIAS | QUANT, + DEQUANT_GELU_QUANT = DEQUANT | GELU | QUANT, + DEQUANT_QUANT = DEQUANT | QUANT, + + DEQUANT_GELU = DEQUANT | GELU, + DEQUANT_BIAS = DEQUANT | BIAS + }; + + template + struct BiasGeluStore { + static_assert(std::is_same::value || std::is_same::value || std::is_same::value, + "BiasGeluStore only support output data types ov::bfloat16/int8_t/float"); + + BiasGeluStore(tensor2D & C, float * bias = nullptr) : C(C), bias(bias) {} + + tensor2D & C; + float * bias; + void set_bias(float * _bias) { + assert (steps & BIAS); + bias = _bias; + } + + float deq_scale_common = 1.0f; + float * deq_scale_per_oc = nullptr; + void set_deq_scale(float scale = 1.0f) { + assert (steps & DEQUANT); + deq_scale_common = scale; + deq_scale_per_oc = nullptr; + } + void set_deq_scale(float * scale_per_oc) { + assert (steps & DEQUANT); + deq_scale_common = 0; + deq_scale_per_oc = scale_per_oc; + } + + float q_scale_common = 0.0f; + float * q_scale_per_oc = nullptr; + void set_q_scale(float scale) { + assert (steps & QUANT); + q_scale_common = scale; + q_scale_per_oc = nullptr; + } + void set_q_scale(float * scale_per_oc) { + assert (steps & QUANT); + q_scale_common = 0; + q_scale_per_oc = scale_per_oc; + } + + // source buffC can be i32 or f32 + template::value, bool>::type = true> + void operator()(tensor2D & buffC, int m, int n, int valid_m, int valid_n) { + auto * psrc = &buffC(0,0); + int8_t * pdst = reinterpret_cast(&(C(m, n))); + int stride = C.stride; + + __m512 bias0, bias1; + if (steps & BIAS) { + bias0 = _mm512_loadu_ps(bias + n); + bias1 = _mm512_loadu_ps(bias + n + 16); + } + + __m512 m512_q_scale0; + __m512 m512_q_scale1; + __m512 m512_deq_scale0; + __m512 m512_deq_scale1; + if (steps & DEQUANT) { + if (deq_scale_per_oc) { + m512_deq_scale0 = _mm512_loadu_ps(deq_scale_per_oc + n); + m512_deq_scale1 = _mm512_loadu_ps(deq_scale_per_oc + n + 16); + } else { + m512_deq_scale0 = _mm512_set1_ps(deq_scale_common); + m512_deq_scale1 = _mm512_set1_ps(deq_scale_common); + } + } + if (steps & QUANT) { + if (q_scale_per_oc) { + m512_q_scale0 = _mm512_loadu_ps(q_scale_per_oc + n); + m512_q_scale1 = _mm512_loadu_ps(q_scale_per_oc + n + 16); + } else { + m512_q_scale0 = _mm512_set1_ps(q_scale_common); + m512_q_scale1 = _mm512_set1_ps(q_scale_common); + } + } + + __mmask32 kall; + __mmask16 k0, k1; + if (std::is_same::value) { + if (valid_n >= 16) { + k0 = _cvtu32_mask16(0xFFFF); + k1 = _cvtu32_mask16(0xFFFF >> (32-valid_n)); + } else { + k0 = _cvtu32_mask16(0xFFFF >> (16-valid_n)); + k1 = _cvtu32_mask16(0); + } + } else { + kall = _cvtu32_mask32(0xFFFFFFFF >> (32-valid_n)); + } + + for(int i = 0; i < valid_m; i ++) { + auto r0 = _mm512_loadu_ps(psrc); + auto r1 = _mm512_loadu_ps(psrc + 16); + if (std::is_same::value) { + r0 = _mm512_cvtepi32_ps(_mm512_castps_si512(r0)); // cvt i32=>f32 + r1 = _mm512_cvtepi32_ps(_mm512_castps_si512(r1)); // cvt i32=>f32 + } + if (steps & DEQUANT) { + r0 = _mm512_mul_ps(r0, m512_deq_scale0); // dequantize + r1 = _mm512_mul_ps(r1, m512_deq_scale1); // dequantize + } + if (steps & BIAS) { + r0 = _mm512_add_ps(r0, bias0); + r1 = _mm512_add_ps(r1, bias1); + } + if (steps & GELU) { + r0 = functional::gelu_erf_minmax_approx(r0); + r1 = functional::gelu_erf_minmax_approx(r1); + } + + // quantize & store + if (steps & QUANT) { + r0 = _mm512_mul_ps(r0, m512_q_scale0); + r1 = _mm512_mul_ps(r1, m512_q_scale1); + } + if (std::is_same::value) { + auto c = _mm512_cvtne2ps_pbh(r1, r0); // convert to bf16 + _mm512_mask_storeu_epi16(pdst, kall, (__m512i)c); // store bf16 + } + if (std::is_same::value) { + auto d0 = _mm512_cvtps_epi32(r0); // convert to dword(i32) + auto d1 = _mm512_cvtps_epi32(r1); // convert to dword(i32) + auto b0 = _mm512_cvtsepi32_epi8 (d0); // dword => int8 with Saturate8 + auto b1 = _mm512_cvtsepi32_epi8 (d1); // dword => int8 with Saturate8 + auto b0b1 = _mm256_inserti32x4(_mm256_castsi128_si256(b0), b1, 1); // combine two int8 xmm into a ymm + _mm256_mask_storeu_epi8(pdst, kall, b0b1); // masked store + } + if (std::is_same::value) { + _mm512_mask_storeu_ps(pdst, k0, r0); // store float + _mm512_mask_storeu_ps(pdst + 64, k1, r1); // store float + } + pdst += stride; + psrc += 32; + } + } + }; +} + +template +void prefetch_bytes(void *src) +{ + int8_t *p = reinterpret_cast(src); + for (int i = 0; i < bytes; i+=64) + _mm_prefetch(p + i + advance, sel); +} + +// matmul (FC) +// +// constB constrols whether it's FC or not +// store precision for weight compression, only for BF16 AMX + +template +tensor2D getSubMatB(tensor2D & _matB, int n0, int n1, bool transposeB) { + int Bd0 = transposeB ? (n1-n0) : _matB.dims[0]; + int Bd1 = transposeB ? _matB.dims[1] : (n1-n0); + T * pbase = transposeB ? (&_matB(n0, 0)):(&_matB(0, n0)); + return tensor2D(Bd0, Bd1, pbase, _matB.stride); +} + +template +void loop2D_no_bM(int M, int N, F f) { + for(int n=0; n +void loop2D(int M, int N, int mc, F f) { + for(int m0=0; m0= bM) +template +void loop2D_opt_Mtail(int M, int N, int mc, F f) { + int tailM = (M % (mc*bM)) % bM; + assert(M > bM); + for(int m0=0; m0 +void repackB_1x2(const tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { + int K = Bi.dims[transpose ? 1 : 0]; + int N = Bi.dims[transpose ? 0 : 1]; + + // K_padded : round up to multiple of 32/64 + int kStep = 64 / sizeof(T); + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); + + int n = 0; + int n_tail = N % N_unit; + if (transpose) { + for(; n < N - n_tail; n += N_unit) { + // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + auto* src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + dst += 1024; + functional::transpose_epi32_16x16(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + dst += 1024; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); + dst += 1024; + functional::transpose_epi32_16xN_right_align(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); + dst += 1024; + } + } + // n_tail: [16, 32) + if (N - n >= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + auto* src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T)); + } + n += 16; + } + // n_tail: (0, 16) + if (N - n > 0) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 1024 : 0); + auto* src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_Mx16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, N - n); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_MxN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T), N - n); + } + n = N; + } + // second B tile is untouched, need to set to zero + if (n_tail > 0 && n_tail <= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for (int k = 0; k < K_padded; k += kStep) { + memset(dst + 1024, 0, 1024); + dst += 1024 * 2; + } + } + } else { + // pack & layout sequentially + int n = 0; + int n_tail = N % N_unit; + for(; n < N - n_tail; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + // int8: B0 B1 64x(16+16) => repack as two 16x16x4 + int src_rows = std::min(K - k, kStep); + functional::kpack_tile_B0B1(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows); + dst += 2048; + } + } + // n_tail: (0, 32) + if (N - n > 0) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + // int8: B0 B1 64x(16+16) => repack as two 16x16x4 + int src_rows = std::min(K - k, kStep); + functional::kpack_tile_B0B1_ntail(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows, N - n); + dst += 2048; + } + n += 16; + } + } +} + +template +struct acc_type {}; +template<> +struct acc_type { typedef float type; }; +template<> +struct acc_type { typedef int32_t type; }; +template<> +struct acc_type { typedef int32_t type; }; + +template +using acc_type_t = typename acc_type::type; + +// matrix multiply with vector +// C_Nx1 = A_MxK * b_Kx1 + +template> +struct MatmulVector { + MatmulVector() {} + constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; + constexpr static int kStep = is_i8_mode ? 64 : 32; + +#define TILE_DP(dst, a, b) \ + if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ + if (is_s8s8) _tile_dpbssd(dst, a, b); \ + if (is_s8u8) _tile_dpbsud(dst, a, b); \ + if (is_u8s8) _tile_dpbusd(dst, a, b); \ + if (is_u8u8) _tile_dpbuud(dst, a, b); + + alignas(64) int8_t KtailBuff[64]; + + template + void kernel(int M, int K, const void * pA, int strideA, const void * vB, void * vC) { + static_assert(tmmN >= 1 && tmmN <= 6, "tmmN must be within [1-6] range"); + const auto * pA0 = reinterpret_cast(pA); + int KLastOffBytes = (K - kStep) * sizeof(TA); + const auto * pB0 = reinterpret_cast(vB); + auto * pC0 = reinterpret_cast(vC); + + const auto * pBLast = pB0 + 64*(tmmN - 1); + int Ktail = K & (kStep - 1); + if (Ktail) { + if (bFallbackKtails) { + // if bContainMtails, the last submatrix needs to use special to prevent A matrix read overflow + // K tails is handled by: + // - zero-padding the last tile of vector B, at the top + // - right-align last tile load from matA + __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull << (kStep - Ktail)*sizeof(TB)); + auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); + _mm512_storeu_epi8(KtailBuff, r); + } else { + // each row of A can be read overflow w/o worrying NaN numbers + // zero-padding the last tile of vector B as bottom is enough + __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull >> (kStep - Ktail)*sizeof(TB)); + KLastOffBytes = (K - Ktail)*sizeof(TA); + auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); + _mm512_storeu_epi8(KtailBuff, r); + } + pBLast = KtailBuff; + } + + // load B tiles outside of loop + if (tmmN == 1) { + _tile_loadd(2, pB0, 4); + } + if (tmmN == 2) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pBLast, 4); + } + if (tmmN == 3) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pBLast, 4); + } + if (tmmN == 4) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pBLast, 4); + } + if (tmmN == 5) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pB0 + 64*3, 4); + _tile_loadd(6, pBLast, 4); + } + if (tmmN == 6) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pB0 + 64*3, 4); + _tile_loadd(6, pB0 + 64*4, 4); + _tile_loadd(7, pBLast, 4); + } + //asm("int3"); + for(int m = 0; m < M; m+=16) { + _tile_zero(0); + if (tmmN == 1) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + } + if (tmmN == 2) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 3); + } + if (tmmN == 3) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 4); + } + if (tmmN == 4) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 5); + } + if (tmmN == 5) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 6); + } + if (tmmN == 6) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 7); + } + _tile_stored(0, pC0, 4); pC0 += 16*4; // C is single column, always take 4 bytes + pA0 += 16 * strideA; + } + } + + void operator()(tensor2D & matA, const TB * vB, TC * vC) { + int M = matA.dims[0]; + int K = matA.dims[1]; + TA * pA = &matA[0]; + int strideA = matA.stride; + + // M tails is handled + assert(K >= kStep && K <= 6*kStep); + + int Ktail = K & (kStep - 1); + int Mtail = M & (16 - 1); + int Mbody = M - Mtail; + int numBtiles = (K + kStep - 1)/kStep; + + // if we have Ktails, then it will always be handled in Mtail, so we split + // Mtail out even if it's zero + if (Ktail) { + if (Mtail == 0) { + Mtail = 16; + Mbody -= 16; + } + } + + if (Mbody) { + tileconfig_t tfg(1, 0, { + {16, 4}, // C:0 M x 1 (4b) + {16, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); + // Ktail fallback will always be done at Mtails loop + switch(numBtiles) { + case 1: kernel<1, false>(Mbody, K, pA, strideA, vB, vC); break; + case 2: kernel<2, false>(Mbody, K, pA, strideA, vB, vC); break; + case 3: kernel<3, false>(Mbody, K, pA, strideA, vB, vC); break; + case 4: kernel<4, false>(Mbody, K, pA, strideA, vB, vC); break; + case 5: kernel<5, false>(Mbody, K, pA, strideA, vB, vC); break; + case 6: kernel<6, false>(Mbody, K, pA, strideA, vB, vC); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } + + if (Mtail) { + pA = &matA(Mbody, 0); + tileconfig_t tfg(1, 0, { + {Mtail, 4}, // C:0 M x 1 (4b) + {Mtail, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); + if (Ktail) { + switch(numBtiles) { + case 1: kernel<1, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 2: kernel<2, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 3: kernel<3, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 4: kernel<4, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 5: kernel<5, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 6: kernel<6, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } else { + switch(numBtiles) { + case 1: kernel<1, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 2: kernel<2, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 3: kernel<3, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 4: kernel<4, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 5: kernel<5, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 6: kernel<6, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } + } + } +}; + +template::type> +struct Matmul { + // B matrix is orgnized as tensor2D of shape axb where a=round_up_div(N, 32), b=round_up(K,32/64)*32 + // so b is size of submatrix of Kx32 composed of two columns of B0/B1 tiles. + tensor2D internalB; + + bool constB; + bool transposeB; + + constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; + + // AMX bf16 & int8 has same M(=16) in A,C tile and same N(=16) in B tile + // but only different K(32 vs 64) in A,C & B tiles + constexpr static int kStep = is_i8_mode ? 64 : 32; + + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + tensor2D buffC; + + Matmul(bool constB = false, bool transposeB = false) : + constB(constB), transposeB(transposeB), buffC(32, 32) {} + + // ppkernel is a callable which captures the runtime args + // by itself, so no need to pass in any post-process related + // runtime args through this API + // + // n0/n1 allows us for calculating only partial results, so it + // can be used to run on multi-cores in parallel + // + // ppkernel will be invoked with true (m,n) with n0-offset added + // so ppkernel don't need to know which sub-matrix it's working on. + // + // for most ppkernels w/o runtime state, a single ppkernel can be + // shared among all threads. + // + // but for ppkernels doing reductions, it needs separate instance + // for each thread, also a post-merging process to combine the results. + // + // ppkernels are simple to write, further wrapping or structurelize only + // makes the design more complex, so we stop doing that. + + // I cannot find a way to call TDP intrinsic polymophically using overload or template. + // have to use old-macro-tricks, hopefully these compile-time checks can be optimized + // by compiler. +#define TILE_DP(dst, a, b) \ + if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ + if (is_s8s8) _tile_dpbssd(dst, a, b); \ + if (is_s8u8) _tile_dpbsud(dst, a, b); \ + if (is_u8s8) _tile_dpbusd(dst, a, b); \ + if (is_u8u8) _tile_dpbuud(dst, a, b); + + template + void kernel_slimB(int M, int N, int K, int n0, + tensor2D & A, + void * B, + tensor2D & buffC, + PP ppkernel) { + auto * pB0 = reinterpret_cast(B); + auto * pC0 = &buffC[0]; + int8_t * pA0 = reinterpret_cast(&A[0]); + int strideA = A.stride; + int KlastOffBytes = (K - kStep)* sizeof(TA); + // load B tiles outside of loop + if (tmmN > 0) _tile_loadd(2, pB0, 64); pB0 += 1024*2; + if (tmmN > 1) _tile_loadd(3, pB0, 64); pB0 += 1024*2; + if (tmmN > 2) _tile_loadd(4, pB0, 64); pB0 += 1024*2; + if (tmmN > 3) _tile_loadd(5, pB0, 64); pB0 += 1024*2; + if (tmmN > 4) _tile_loadd(6, pB0, 64); pB0 += 1024*2; + if (tmmN > 5) _tile_loadd(7, pB0, 64); pB0 += 1024*2; + //asm("int3"); + for(int m0 = 0; m0 < M; m0+=16) { + int m = m0; + if (M - m0 < 16) { + // shift up to prevent M-tails + pA0 -= (16 - (M - m0))*A.stride; + m = M - 16; + } + _tile_zero(0); + if (tmmN == 1) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + } + if (tmmN == 2) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 3); + } + if (tmmN == 3) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 4); + } + if (tmmN == 4) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 5); + } + if (tmmN == 5) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 6); + } + if (tmmN == 6) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 7); + } + _tile_stored(0, pC0, buffC.stride); + (ppkernel)(buffC, m, n0, 16, N); + pA0 += 16*A.stride; + } + } + + template + void operator()(tensor2D & matA, + tensor2D & _matB, + int n0, int n1, + PP ppkernel, + bool skip_repack = false) { + auto matB = getSubMatB(_matB, n0, n1, transposeB); + int M = matA.dims[0]; + int K = matA.dims[1]; + int N = matB.dims[transposeB ? 0 : 1]; + assert(K == matB.dims[transposeB ? 1 : 0]); + // Due to the fact that we load a full tile at tails of K dimension + // we may access memory address beyond the limit of A matrix + // to avoid read in nan values, we backoff to the left to ensure A tile + // contain valid numbers and no overflow access happens, but it requires K>=kStep; + assert(K >= kStep); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int KbackoffBytes = (kStep - Ktails)*sizeof(TA); + + // for non-constB, internalB is updated every time + // for constB, internalB is updated once + if ((!constB && !skip_repack) || (internalB.capacity == 0)) { + repackB_1x2(matB, transposeB, internalB, constB); + } + + // special case when whole B matrix can fit in 6 tiles + // we can load B only once + if (M >= 16 && N <= 16 && K <= 6*kStep) { + // B is zero-padded + // C:0 + // A:1 + // B:2,3,4,5,6,7 + auto * pB0 = reinterpret_cast(&internalB[0]); + tileconfig_t tfg(1, 0, 8, 16, 64); + switch((K + kStep - 1)/kStep) { + case 1: kernel_slimB<1>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 2: kernel_slimB<2>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 3: kernel_slimB<3>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 4: kernel_slimB<4>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 5: kernel_slimB<5>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 6: kernel_slimB<6>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + return; + } + + if (M <= 16) { + // register/cache blocking scheme is simplified when M <= 16 + // C_MxN: 0,1 + // A_MxK: 2, + // B_KxN: 3, 4 + tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); + auto * pB0 = reinterpret_cast(&internalB[0]); + auto * const pC0 = &buffC[0]; + int k; + const auto strideA = matA.stride; + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + _tile_zero(0); + _tile_zero(1); + int8_t * pA0 = reinterpret_cast(&matA[0]); + for(k=0; k(pB0); + _tile_loadd(3, pB0, 64); pB0 += 1024; // tile B0 32x16(16x16x2)/64x16(16x16x4) is always 1KB + // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + _tile_loadd(4, pB0, 64); pB0 += 1024; // tile B1 32x16(16x16x2)/64x16(16x16x4) is always 1KB + TILE_DP(0, 2, 3); // C0 += A*B0 + TILE_DP(1, 2, 4); // C1 += A*B1 + } + if (Ktails) { + _tile_loadd(2, pA0 - KbackoffBytes, strideA); + // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + _tile_loadd(3, pB0, 64); pB0 += 1024; + // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + _tile_loadd(4, pB0, 64); pB0 += 1024; + TILE_DP(0, 2, 3); // C0 += A*B0 + TILE_DP(1, 2, 4); // C1 += A*B1 + } + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + return; + } + + auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { + auto * pA0 = reinterpret_cast(&matA(m, 0)); + auto * pA1 = reinterpret_cast(&matA(m + 16, 0)); + auto strideA = matA.stride; + auto * pB = reinterpret_cast(&internalB(n>>5, 0)); + _tile_zero(0); + _tile_zero(1); + _tile_zero(2); + _tile_zero(3); + // 2x2 + for (int k = 0; k < Kbody; k += kStep) { + _tile_loadd(4, pA0, strideA); pA0 += 64; + _tile_loadd(6, pB, 64); pB += 1024; + // prefetch_bytes<1024>(pB); + TILE_DP(0, 4, 6); + + _tile_loadd(5, pA1, strideA); pA1 += 64; + TILE_DP(2, 5, 6); + _tile_loadd(7, pB, 64); pB += 1024; + // prefetch_bytes<1024>(pB); + TILE_DP(1, 4, 7); + + TILE_DP(3, 5, 7); + } + if (Ktails) { + _tile_loadd(4, pA0 - KbackoffBytes, strideA); + _tile_loadd(6, pB, 64); pB += 1024; + // prefetch_bytes<1024>(pB); + TILE_DP(0, 4, 6); + + _tile_loadd(5, pA1 - KbackoffBytes, strideA); + TILE_DP(2, 5, 6); + _tile_loadd(7, pB, 64); pB += 1024; + // prefetch_bytes<1024>(pB); + TILE_DP(1, 4, 7); + + TILE_DP(3, 5, 7); + } + _tile_stored(0, &buffC(0,0), buffC.stride); + _tile_stored(1, &buffC(0,16), buffC.stride); + _tile_stored(2, &buffC(16,0), buffC.stride); + _tile_stored(3, &buffC(16,16), buffC.stride); + (ppkernel)(buffC, m, n + n0, valid_m, valid_n); + }; + + if (M <= 32 && M >16) { + // 2x2 tile, C:0/1/2/3 A:4/5 B:6/7 no blocking along M dimension + tileconfig_t tfg(1, 0, {16,16,M-16,M-16,16,M-16,16,16}, 64); + loop2D_no_bM<32>(M, N, kernel_2x2); + return; + } + + // generic input shapes with M > 32 + // determine cache blocking scheme + int elesz = sizeof(TA); + int L2 = 2048*1024; // 2MB + int slice_size = 32*rndup(K, 32)*elesz; + int mc = std::max(1, L2/slice_size - 1); + + // M > bM + tileconfig_t tfg(1, 0, 8, 16, 64); + loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); + } +}; + +// specialization: +// TA is ov::bfloat16 and TB is int8_t, decompressed on the fly into ov::bfloat16 by simply convert +template<> +struct Matmul { + tensor2D internalBI8; + + // wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly. + tensor2D weiBuff; + + bool constB; + bool transposeB; + + constexpr static int kStep = 32; + + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + tensor2D buffC; + + Matmul(bool constB = false, bool transposeB = false) : + constB(constB), transposeB(transposeB), buffC(32, 32) {} + + float quant_scale_B; + float dequant_scale_B; + + template + void operator()(tensor2D & matA, + tensor2D & _matB, + int n0, int n1, + PP ppkernel) { + auto matB = getSubMatB(_matB, n0, n1, transposeB); + int M = matA.dims[0]; + int K = matA.dims[1]; + int N = matB.dims[transposeB ? 0 : 1]; + assert(K == matB.dims[transposeB ? 1 : 0]); + // Due to the fact that we load a full tile at tails of K dimension + // we may access memory address beyond the limit of A matrix + // to avoid read in nan values, we backoff to the left to ensure A tile + // contain valid numbers and no overflow access happens, but it requires K>=kStep; + assert(K >= kStep); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int Kbackoff = (kStep - Ktails); + + // for non-constB, internalB is updated every time + // for constB, internalB is updated once + if (!constB || (internalBI8.capacity == 0)) { + // this dynamic quantization of weight matrix using minmax + // is time-consuming, should be used only for constB + if (!constB) { + std::cout << "\t WANING: dynamic quantization of weight matrix for non-constB is time-consuming " << std::endl; + } + // float min, max; + // functional::get_min_max(_matB, min, max); + // max = std::max(std::abs(max), std::abs(min)); + // quant_scale_B = 127 / max; + // dequant_scale_B = max / 127; + + tensor2D internalTmpB; + repackB_1x2(matB, transposeB, internalTmpB, constB); + functional::bf16_to_i8_tensor(internalBI8, internalTmpB, quant_scale_B); + } + + ppkernel.set_deq_scale(dequant_scale_B); + + if (M <= 16) { + // C:0/1 A:2 B:3/4 + // dequantize scale is moved into ppkernel + constexpr int prefetch_ahead = 64*1024; + tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); + auto * pBint = reinterpret_cast(&internalBI8[0]); + auto & B2buff = weiBuff; + B2buff.resize(32*2, 32); + auto * const pB = &B2buff[0]; + auto * pBsrc = pB + (32*32) * 0; + auto * pBdst = pB + (32*32) * 1; + functional::i8_to_bf16_Kx32<32>(pBint, pBsrc); + + auto * const pC0 = &buffC[0]; + const auto strideA = matA.stride; + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + // C:Mx32 = A:Mx32 x B:32x32 + _tile_zero(0); + _tile_zero(1); + auto * pA0 = &matA[0]; + for(int k=0; k(pBint); + + functional::i8_to_bf16_Kx32<8>(pBint, pBdst); + _tile_loadd(3, pBsrc, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); + _tile_loadd(4, pBsrc + 16*32, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + if (Ktails) { + _tile_loadd(2, pA0 - Kbackoff, strideA); // backoff to prevent access beyond the end of A + prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + + functional::i8_to_bf16_Kx32<8>(pBint, pBdst); + _tile_loadd(3, pBsrc, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); + _tile_loadd(4, pBsrc + 16*32, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + //prefetch_bytes<2048, _MM_HINT_T1, prefetch_ahead>(pBint); + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //prefetch_bytes<2048, _MM_HINT_T1, prefetch_ahead>(pBint + 2048); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + return; + } + + // 4 tiles buffC is reused as decompressed bf16 weights + constexpr int prefetch_ahead = 16*1024; + ov::bfloat16 * pBa = reinterpret_cast(&buffC(0,0)); + ov::bfloat16 * pBb = pBa + (16*32)*2; + auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { + auto strideA = matA.stride; + auto * pA0 = &matA(m, 0); + auto * pA1 = &matA(m + 16, 0); + auto * pBint = reinterpret_cast(&internalBI8(n>>5, 0)); + functional::i8_to_bf16_Kx32<32>(pBint, pBb); + + _tile_zero(0); + _tile_zero(1); + _tile_zero(2); + _tile_zero(3); + int k; + for (k = 0; k < Kbody; k += kStep) { + functional::i8_to_bf16_Kx32<16>(pBint, pBa); + + _tile_loadd(4, pA0 + k, strideA); + _tile_loadd(6, pBb, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); + + _tile_loadd(7, pBb + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + if (Ktails) { + functional::i8_to_bf16_Kx32<16>(pBint, pBa); + + _tile_loadd(4, pA0 + k - Kbackoff, strideA); + _tile_loadd(6, pBb, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k - Kbackoff, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); + + _tile_loadd(7, pBb + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + _tile_stored(0, &buffC(0,0), buffC.stride); + _tile_stored(1, &buffC(0,16), buffC.stride); + _tile_stored(2, &buffC(16,0), buffC.stride); + _tile_stored(3, &buffC(16,16), buffC.stride); + (ppkernel)(buffC, m, n + n0, valid_m, valid_n); + }; + + if (M <= 32 && M > 16) { + // 2x2 C:0/1/2/3 A:4/5 B:6/7 + tileconfig_t tfg(1, 0, {16, 16, M-16, M-16, 16, M-16, 16, 16}, 64); + loop2D_no_bM<32>(M, N, kernel_2x2); + return; + } + + // determine blocking scheme + int elesz = sizeof(uint16_t); + int L2 = 2048*1024; // 2MB + int slice_size = 32*rndup(K, 32)*elesz; + int mc = std::max(1, L2/slice_size - 1); // if 1 32xK slice cannot fit L2, use 1 slice at least + + // main loop + tileconfig_t tfg(1, 0, 8, 16, 64); + loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); + } +}; + +//https://stackoverflow.com/questions/29519222/how-to-transpose-a-16x16-matrix-using-simd-instructions +// vector multiply with matrix: +// mAvB: A(M, K) * B(K, 1) => C(M, 1) +// vAmB: A(1, K) * B(K, N) => C(1, N) +// +// in mAvB form, block of A (16x32) is transposed in register +// in unit of 2 packed bf16, and then vdpbf16ps was used +// to multiply with broadcasted B (2x1) and accumulate into C (16x1) +// +// B is pre-broadcasted in unit of 2 +// +struct GemAvB { + tensor2D Bpadded; + GemAvB() { + } + + void operator()(tensor2D & matA, + ov::bfloat16 * vecB, + float * vecC) { + int M = matA.dims[0]; + int K = matA.dims[1]; + + constexpr int kStep = 32; + + assert(K >= 32); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int Kbackoff = (kStep - Ktails); + + if (K % 32) { + if (K > Bpadded.dims[1]) + Bpadded.resize(1, rndup(K, 32)); + auto newB = &Bpadded(0, 0); + memset(newB, 0, Bpadded.stride); + memcpy(newB, vecB, K * sizeof(ov::bfloat16)); + vecB = newB; + } + + for(int m = 0; m < M; m += 16) { + auto * pA = reinterpret_cast(&matA(m, 0)); + auto * pBi32 = reinterpret_cast(vecB); + __m512 regC0 = _mm512_setzero(); + __m512 regC1 = _mm512_setzero(); + __mmask16 kmask = _cvtu32_mask16(0xFFFF); + if (M-m < 16) { + kmask = _cvtu32_mask16(0xFFFF >> (16-(M-m))); + } + for(int k = 0; k < K; k += 32, pA += 64, pBi32 += 16) { + // handle Ab: 16x32 + // transposed in register as 16x16x2 + // r0: (a0,a1)(b0,b1).... + // r1: (a2,a3)(b2,b3).... + // ... + // rf: (a30,a31),(b30,b31).... + // + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto stride = matA.stride; + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_loadu_epi32(pA + 15*stride); + + functional::transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + // vdpbf16ps + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r0), (__m512bh)(_mm512_set1_epi32(pBi32[0]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r1), (__m512bh)(_mm512_set1_epi32(pBi32[1]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r2), (__m512bh)(_mm512_set1_epi32(pBi32[2]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r3), (__m512bh)(_mm512_set1_epi32(pBi32[3]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r4), (__m512bh)(_mm512_set1_epi32(pBi32[4]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r5), (__m512bh)(_mm512_set1_epi32(pBi32[5]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r6), (__m512bh)(_mm512_set1_epi32(pBi32[6]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r7), (__m512bh)(_mm512_set1_epi32(pBi32[7]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r8), (__m512bh)(_mm512_set1_epi32(pBi32[8]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r9), (__m512bh)(_mm512_set1_epi32(pBi32[9]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(ra), (__m512bh)(_mm512_set1_epi32(pBi32[10]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rb), (__m512bh)(_mm512_set1_epi32(pBi32[11]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(rc), (__m512bh)(_mm512_set1_epi32(pBi32[12]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rd), (__m512bh)(_mm512_set1_epi32(pBi32[13]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(re), (__m512bh)(_mm512_set1_epi32(pBi32[14]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rf), (__m512bh)(_mm512_set1_epi32(pBi32[15]))); + } + regC0 = _mm512_add_ps(regC0, regC1); + _mm512_mask_storeu_ps (vecC + m, kmask, regC0); + //auto regOut = _mm512_cvtne2ps_pbh(regC0, regC0); // only 16 ov::bfloat16 results in lower 256bits + //_mm256_storeu_si256(reinterpret_cast<__m256i_u *>(vecC + m), _mm512_extracti64x4_epi64(regOut, 0)); + } + } +}; + +} // namespace amx + +inline std::ostream & operator<<(std::ostream & os, const amx_kernel::PP::Steps & steps) { + os << "amx_kernel::PP::Steps::"; + if (steps == amx_kernel::PP::Steps::NONE) + os << "NONE"; + if (steps & amx_kernel::PP::Steps::DEQUANT) + os << "_DEQUANT"; + if (steps & amx_kernel::PP::Steps::BIAS) + os << "_BIAS"; + if (steps & amx_kernel::PP::Steps::GELU) + os << "_GELU"; + if (steps & amx_kernel::PP::Steps::QUANT) + os << "_QUANT"; + return os; +} diff --git a/src/softmax_kernel_avx512.hpp b/src/softmax_kernel_avx512.hpp index f9638d2..85030d0 100644 --- a/src/softmax_kernel_avx512.hpp +++ b/src/softmax_kernel_avx512.hpp @@ -11,12 +11,12 @@ #include #include #endif -#include "bf16.hpp" +#include "common/bf16.hpp" #include "llm_types.hpp" #include "utility_avx512.hpp" namespace llmdnn { - inline void exp_ps(__m512 & src) { + inline void exp_ps_avx512(__m512 & src) { static __m512 exp_ln_flt_min_f = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); // log(FLT_MIN) static __m512 exp_ln_flt_max_f = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); // log(FLT_MAX) static __m512 exp_log2ef = _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) @@ -85,10 +85,10 @@ namespace llmdnn { } template - void softmax(D* dst, float* src, int N, float* s_max=nullptr, float* s_sum=nullptr, float* quant=nullptr) { + void softmax_avx512(D* dst, float* src, int N, float* s_max=nullptr, float* s_sum=nullptr, float* quant=nullptr) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, - "softmax only support output data types ov::bfloat16/uint8_t/int8_t/float"); + "softmax_avx512 only support output data types ov::bfloat16/uint8_t/int8_t/float"); static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f auto tail = N % 16; @@ -115,7 +115,7 @@ namespace llmdnn { for(i = 0; i < N - tail; i += 16) { auto x = _mm512_loadu_ps(src + i); x = _mm512_sub_ps(x, x_max); - exp_ps(x); // exp(x-x_max) + exp_ps_avx512(x); // exp(x-x_max) sum_exp = _mm512_add_ps(sum_exp, x); // sum(exp(x-x_max)) _mm512_storeu_ps(src + i, x); // save exp(x-x_max) } @@ -124,7 +124,7 @@ namespace llmdnn { if (tail) { auto x = _mm512_maskz_loadu_ps(x_mask, src + i); x = _mm512_sub_ps(x, x_max); - exp_ps(x); + exp_ps_avx512(x); x = _mm512_mask_blend_ps(x_mask, _mm512_setzero_ps(), x); sum_exp = _mm512_add_ps(sum_exp, x); _mm512_mask_storeu_ps(src + i, x_mask, x); diff --git a/src/transpose_kernel_avx512.hpp b/src/transpose_kernel_avx512.hpp index 3965d80..3bbc8c5 100644 --- a/src/transpose_kernel_avx512.hpp +++ b/src/transpose_kernel_avx512.hpp @@ -11,19 +11,19 @@ #include #include #endif -#include "bf16.hpp" +#include "common/bf16.hpp" #include "llm_types.hpp" #include "utility_avx512.hpp" namespace llmdnn { template - void memcpy2d_stride(D* dst, S* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr); + void memcpy2d_stride_avx512(D* dst, S* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr); template - void memcpy2d_stride(D* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr) { + void memcpy2d_stride_avx512(D* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, - "softmax only support output data types ov::bfloat16/uint8_t/int8_t/float"); + "memcpy2d_stride_avx512 only support output data types ov::bfloat16/uint8_t/int8_t/float"); auto tail = width % 16; __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); diff --git a/src/utility_amx.hpp b/src/utility_amx.hpp index 2c348d4..9a9778f 100644 --- a/src/utility_amx.hpp +++ b/src/utility_amx.hpp @@ -62,6 +62,7 @@ struct tileconfig_t { } load(); } + tileconfig_t(int palette, int _startRow, int numTiles, int _rows, int columnsBytes) { palette_id = palette; startRow = _startRow; @@ -83,14 +84,15 @@ struct tileconfig_t { ~tileconfig_t() { _tile_release(); } + void __attribute__((noinline)) load() { - //std::cout << "\ttile load config ... " << std::flush; _tile_loadconfig(this); - //std::cout << *this << std::flush << std::endl; } + void store() { _tile_storeconfig(this); } + friend std::ostream& operator<<(std::ostream& out, const tileconfig_t& cfg) { out << " palette_id=" << static_cast(cfg.palette_id); out << " startRow=" << static_cast(cfg.startRow); diff --git a/src/utility_avx512.hpp b/src/utility_avx512.hpp index 8a5507d..ca55c7b 100644 --- a/src/utility_avx512.hpp +++ b/src/utility_avx512.hpp @@ -5,7 +5,7 @@ #pragma once #include -#include "bf16.hpp" +#include "common/bf16.hpp" #ifdef _WIN32 #include #else @@ -53,7 +53,7 @@ inline void store_n(__m128i mm, unsigned int n, void* storage) _mm_maskmoveu_si128(mm, reinterpret_cast< const __m128i& >(masks[n]), static_cast< char* >(storage)); } -inline void quant_i8(void* dst, void* src, size_t ele_num, float scale) { +inline void quant_i8_avx512(void* dst, void* src, size_t ele_num, float scale) { size_t i = 0; auto* a = reinterpret_cast(src); int8_t* d = reinterpret_cast(dst); @@ -81,7 +81,7 @@ inline void quant_i8(void* dst, void* src, size_t ele_num, float scale) { } // NOTE: did not handle tail because there should be enough room -inline void cvt_i32_f32(float* dst, int32_t* src, size_t ele_num) { +inline void cvt_i32_f32_avx512(float* dst, int32_t* src, size_t ele_num) { for (int i = 0; i < (ele_num + 15) / 16 * 16; i += 16) { auto a0 = _mm512_load_epi32(src); auto a_f = _mm512_cvtepi32_ps(a0); @@ -91,7 +91,7 @@ inline void cvt_i32_f32(float* dst, int32_t* src, size_t ele_num) { } } -inline void mul_add_f32(float* dst, float* src, float mul, float* add, int ele_num) { +inline void mul_add_f32_avx512(float* dst, float* src, float mul, float* add, int ele_num) { auto mul_f = _mm512_set1_ps(mul); int i; auto tail = ele_num % 16; diff --git a/tests/script/README.md b/tests/script/README.md index 547a508..b301c08 100644 --- a/tests/script/README.md +++ b/tests/script/README.md @@ -12,8 +12,12 @@ compile extension ``` ./build.sh ``` +debug version extension(llmdnn needs to config debug version also) +``` +DEBUG_EXT=1 ./build.sh +``` run test ``` -python test_mha_gpt.py +pytest ``` \ No newline at end of file diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index 03531f6..b2ed223 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -6,9 +6,9 @@ #include #include "alloca.h" #include "module.hpp" -#include "utility.hpp" +#include "common/utility.hpp" #include "utility_amx.hpp" -#include "mha_gpt.hpp" +#include "llm_mha_gpt.hpp" #include "test_common.hpp" void regclass_mha_gpt(pybind11::module m) { diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp index 19419cd..0feb8f8 100644 --- a/tests/script/ext/module.cpp +++ b/tests/script/ext/module.cpp @@ -6,7 +6,7 @@ #include #include "module.hpp" #include "utility_amx.hpp" -#include "mha_gpt.hpp" +#include "llm_mha_gpt.hpp" #include "test_common.hpp" PYBIND11_MODULE(llmdnn, m) { diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py index aa58078..43fe15c 100644 --- a/tests/script/ext/setup.py +++ b/tests/script/ext/setup.py @@ -5,6 +5,7 @@ from setuptools import setup, Extension from torch.utils import cpp_extension import sys +import os ''' using intel compiler: @@ -12,28 +13,32 @@ export CXX=icx export CC=icx ''' -debug = True +debug = False +if 'DEBUG_EXT' in os.environ: + debug = True if os.environ['DEBUG_EXT'] == '1' else False extra_args = ['-fopenmp', '-march=native'] -llmdnn_lib_dir = '../../../../../../../../bin/intel64/Release' +llmdnn_lib_dir = f'{os.getcwd()}/../../../../../../../../bin/intel64/Release' if debug: - llmdnn_lib_dir = '../../../../../../../../bin/intel64/Debug' + llmdnn_lib_dir = f'{os.getcwd()}/../../../../../../../../bin/intel64/Debug' extra_args += ['-g', '-O0'] + print('install debug version') +else: + print('install release version') + setup(name='llmdnn', ext_modules=[ cpp_extension.CppExtension( 'llmdnn', - ['module.cpp', 'mha_gpt.cpp', '../../src/test_common.cpp'], + ['module.cpp', 'mha_gpt.cpp', f'../../src/test_common.cpp'], extra_compile_args=extra_args, - #extra_link_args=['-lgomp'], - include_dirs=['../../src', - '../../../include', - '../../../src'], + include_dirs=[f'{os.getcwd()}/../../src', + f'{os.getcwd()}/../../../include', + f'{os.getcwd()}/../../../src'], library_dirs=[f'{sys.prefix}/lib', llmdnn_lib_dir], - #runtime_library_dirs=[ f'{sys.prefix}/lib', ], libraries=['llmdnn', 'stdc++']), ], - cmdclass={'build_ext': cpp_extension.BuildExtension.with_options(use_ninja=False)} + cmdclass={'build_ext': cpp_extension.BuildExtension} ) \ No newline at end of file diff --git a/tests/script/requirements.txt b/tests/script/requirements.txt index 15ff051..b526b7c 100644 --- a/tests/script/requirements.txt +++ b/tests/script/requirements.txt @@ -1,3 +1,5 @@ -f https://download.pytorch.org/whl/torch_stable.html numpy==1.24.2 torch==2.0.1+cpu +pytest +ninja \ No newline at end of file diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py index bb2edf9..d08d05a 100644 --- a/tests/script/test_mha_gpt.py +++ b/tests/script/test_mha_gpt.py @@ -113,7 +113,26 @@ def forward(self, query, key, value, attention_mask=None): SIZE_PER_HEAD = 80 HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD MAX_POSITION_EMBEDDINGS = 1024 #2048 -def test(inputs): +def test_gpt_neox(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [1, MAX_POSITION_EMBEDDINGS] + (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([1, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([1, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([1, 200], dtype=np.float32)), + ] class FakeConfig: def __init__(self): self.num_attention_heads = HEAD_NUM @@ -134,28 +153,10 @@ def __init__(self): output = net.forward(q, k, v, attn_mask) if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) print('done.') return if __name__ == "__main__": - inputs = [ - # q, k, v, attn_mask - # q: [batch, num_heads, query_seq_len, head_size] - # k: [batch, num_heads, key_seq_len, head_size] - # v: [batch, num_heads, value_seq_len, head_size] - # attn: [1, MAX_POSITION_EMBEDDINGS] - (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([1, 32], dtype=np.float32)), - (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([1, 200], dtype=np.float32)), - (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([1, 200], dtype=np.float32)), - ] - test(inputs) \ No newline at end of file + test_gpt_neox() \ No newline at end of file diff --git a/tests/src/test_common.cpp b/tests/src/test_common.cpp index 595b243..a96a268 100644 --- a/tests/src/test_common.cpp +++ b/tests/src/test_common.cpp @@ -10,8 +10,8 @@ #include #include #include +#include "common/simple_parallel.hpp" #include "test_common.hpp" -#include "simple_parallel.hpp" #ifndef _GNU_SOURCE #define _GNU_SOURCE /* See feature_test_macros(7) */ diff --git a/tests/src/test_common.hpp b/tests/src/test_common.hpp index e557de3..1071158 100644 --- a/tests/src/test_common.hpp +++ b/tests/src/test_common.hpp @@ -12,8 +12,8 @@ #include #include "llm_types.hpp" #include "llm_fc.hpp" -#include "tensor2d.hpp" -#include "bf16.hpp" +#include "common/tensor2d.hpp" +#include "common/bf16.hpp" #ifdef _WIN32 #include #else diff --git a/tests/src/test_fc_kernel.cpp b/tests/src/test_fc_kernel.cpp index 04583f7..a0a0532 100644 --- a/tests/src/test_fc_kernel.cpp +++ b/tests/src/test_fc_kernel.cpp @@ -11,8 +11,8 @@ #include #include "gtest/gtest.h" #include "llm_fc.hpp" -#include "tensor2d.hpp" -#include "tensor2d_helper.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" #include "test_common.hpp" using namespace std; diff --git a/tests/src/test_mm_kernel.cpp b/tests/src/test_mm_kernel.cpp index 9a081d3..9cd3507 100644 --- a/tests/src/test_mm_kernel.cpp +++ b/tests/src/test_mm_kernel.cpp @@ -11,8 +11,8 @@ #include #include "gtest/gtest.h" #include "llm_mm.hpp" -#include "tensor2d.hpp" -#include "tensor2d_helper.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" #include "test_common.hpp" using namespace std; diff --git a/tests/src/test_softmax_kernel_avx512.cpp b/tests/src/test_softmax_kernel_avx512.cpp index fef8577..6f2ffb8 100644 --- a/tests/src/test_softmax_kernel_avx512.cpp +++ b/tests/src/test_softmax_kernel_avx512.cpp @@ -11,8 +11,8 @@ #include #include "gtest/gtest.h" #include "llm_mm.hpp" -#include "tensor2d.hpp" -#include "tensor2d_helper.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" #include "softmax_kernel_avx512.hpp" #include "test_common.hpp" @@ -94,7 +94,7 @@ class SoftmaxTest : public TestWithParam { } quant = 128.f; gen_ref(A, out_ref, quant); - llmdnn::softmax(out.data, A.data, n, nullptr, nullptr, quant.data); + llmdnn::softmax_avx512(out.data, A.data, n, nullptr, nullptr, quant.data); for (int i = 0; i < n; i++) { float a = out[i]; float b = out_ref[i]; diff --git a/tests/src/test_transpose_kernel_avx512.cpp b/tests/src/test_transpose_kernel_avx512.cpp index 2dfe988..30e8ee7 100644 --- a/tests/src/test_transpose_kernel_avx512.cpp +++ b/tests/src/test_transpose_kernel_avx512.cpp @@ -11,8 +11,8 @@ #include #include "gtest/gtest.h" #include "llm_mm.hpp" -#include "tensor2d.hpp" -#include "tensor2d_helper.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" #include "transpose_kernel_avx512.hpp" #include "test_common.hpp" @@ -90,7 +90,7 @@ class TransposeTest : public TestWithParam { auto* dst_p_ref = dst_ref.data; for (int i = 0; i < num_heads; i++) { auto* src_p = &src(i, 0); - llmdnn::memcpy2d_stride(dst_p, src_p, query_seq_len, head_size, + llmdnn::memcpy2d_stride_avx512(dst_p, src_p, query_seq_len, head_size, head_size * sizeof(float), num_heads * head_size * sizeof(T), quant.data); gen_ref(dst_p_ref, src_p, query_seq_len, head_size, head_size * sizeof(float), num_heads * head_size * sizeof(T), quant.data); diff --git a/tests/src/test_utility.cpp b/tests/src/test_utility.cpp index e3b91e0..c46a7ff 100644 --- a/tests/src/test_utility.cpp +++ b/tests/src/test_utility.cpp @@ -10,8 +10,8 @@ #include #include "gtest/gtest.h" #include "llm_mm.hpp" -#include "tensor2d.hpp" -#include "tensor2d_helper.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" #include "utility_avx512.hpp" #include "test_common.hpp" @@ -29,7 +29,7 @@ TEST(smoke_Utility, muladd) { bias[i] = -100.0f + i; ref[i] = x[i] * normal_factor + bias[i]; } - mul_add_f32(x_out.data(), x.data(), normal_factor, bias.data(), len); + mul_add_f32_avx512(x_out.data(), x.data(), normal_factor, bias.data(), len); for (size_t i = 0; i < x.size(); i++) { ASSERT_TRUE(std::abs(x_out[i] - ref[i]) < 0.0001f) << " length: " << len << " pos: " << i << " cur: " << x[i] << " ref: " << ref[i]; } diff --git a/tests/src/test_utility_repack1x2.cpp b/tests/src/test_utility_repack1x2.cpp index 4597b5b..5b0a645 100644 --- a/tests/src/test_utility_repack1x2.cpp +++ b/tests/src/test_utility_repack1x2.cpp @@ -11,9 +11,9 @@ #include #include "gtest/gtest.h" #include "llm_mm.hpp" -#include "tensor2d.hpp" -#include "tensor2d_helper.hpp" -#include "mm_kernel_amx.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "mm_kernel_common_amx.hpp" #include "test_common.hpp" using namespace std; From 3cfc8de84115e873ad01311c8acb2209cbf7cee1 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 07/54] use add_subdir to simplify cmake --- CMakeLists.txt | 17 +++++---- src/CMakeLists.txt | 6 +-- src/common/tensor2d.hpp | 2 +- src/fc_kernel_amx.cpp | 2 +- src/fc_kernel_api.cpp | 2 +- src/mha_gpt_amx.cpp | 9 +++-- src/mm_kernel_amx.cpp | 2 +- src/mm_kernel_common_amx.hpp | 37 ++++++++++--------- src/softmax_kernel_avx512.hpp | 2 +- src/transpose_kernel_avx512.hpp | 4 +- ...utility_amx.hpp => utility_kernel_amx.hpp} | 0 ...y_avx512.hpp => utility_kernel_avx512.hpp} | 2 +- tests/CMakeLists.txt | 22 ++++++----- tests/script/build.sh | 7 ++-- tests/script/ext/mha_gpt.cpp | 2 +- tests/script/ext/module.cpp | 2 +- ...t_fc_kernel.cpp => test_fc_kernel_amx.cpp} | 2 +- ...t_mm_kernel.cpp => test_mm_kernel_amx.cpp} | 2 +- tests/src/test_softmax_kernel_avx512.cpp | 4 +- tests/src/test_transpose_kernel_avx512.cpp | 14 +++---- ...ity.cpp => test_utility_kernel_avx512.cpp} | 2 +- ... test_utility_kernel_repack1x2_avx512.cpp} | 3 +- 22 files changed, 75 insertions(+), 70 deletions(-) rename src/{utility_amx.hpp => utility_kernel_amx.hpp} (100%) rename src/{utility_avx512.hpp => utility_kernel_avx512.hpp} (98%) rename tests/src/{test_fc_kernel.cpp => test_fc_kernel_amx.cpp} (99%) rename tests/src/{test_mm_kernel.cpp => test_mm_kernel_amx.cpp} (99%) rename tests/src/{test_utility.cpp => test_utility_kernel_avx512.cpp} (96%) rename tests/src/{test_utility_repack1x2.cpp => test_utility_kernel_repack1x2_avx512.cpp} (98%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 84387af..d1d46ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,17 +7,18 @@ cmake_minimum_required(VERSION 3.13) project(root) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -option(BUILD_TESTS "Build with tests" ON) -option(BUILD_PYTHON_TESTS "Build with tests need python extension" OFF) +option(LLMDNN_BUILD_TESTS "Build with tests" ON) message(INFO "--------------------------------") -message(STATUS "Build with tests: ${BUILD_TESTS}") +message(STATUS "Build with tests: ${LLMDNN_BUILD_TESTS}") message(INFO "--------------------------------") set(CMAKE_CXX_STANDARD 11) if(MSVC) - # TODO - message(FATAL_ERROR "Not support yet. Use intel compiler 2023.0+.") + # TODO: validate + if(MSVC_VERSION VERSION_LESS 1928) + message(FATAL_ERROR "Insufficient msvc compiler version, current ${MSVC_VERSION}, minimum 1928.") + endif() # Force to always compile with W4 if(CMAKE_CXX_FLAGS MATCHES "/W[0-4]") string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") @@ -25,8 +26,8 @@ if(MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") endif() elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) - if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.0") - message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 12.0.") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.3") + message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 11.3.") endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") @@ -45,6 +46,6 @@ if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) endif() add_subdirectory(src) -if (BUILD_TESTS) +if (LLMDNN_BUILD_TESTS) add_subdirectory(tests) endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2ece614..8bd7d53 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -8,10 +8,8 @@ project(llmdnn) file(GLOB_RECURSE SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) add_library(llmdnn STATIC ${SOURCE_FILES}) -add_compile_definitions(DNNL_CPU_THREADING_RUNTIME=DNNL_RUNTIME_TBB) -target_compile_definitions(llmdnn PRIVATE LLMDNN_EXPORT) -target_include_directories(llmdnn INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) +target_include_directories(llmdnn PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../include) set_target_properties(llmdnn PROPERTIES POSITION_INDEPENDENT_CODE ON ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) -install(TARGETS llmdnn DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index e24c8c7..0c4a8f8 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -108,7 +108,7 @@ struct tensor2D { data = reinterpret_cast(aligned_alloc(64, capacity)); } if (is_const) - memset(data, 0, need_capacity); + memset(static_cast(data), 0, need_capacity); if (reinterpret_cast(data) % 64) std::cout << "WARNING: resize(), data is not cache-line aligned!" << std::endl; } diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index 0843fc0..ba5c41f 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -16,7 +16,7 @@ #include "llm_fc.hpp" #include "mm_kernel_common_amx.hpp" -#include "utility_avx512.hpp" +#include "utility_kernel_avx512.hpp" #include "fc_kernel_amx.hpp" namespace llmdnn { diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp index 12fe913..a063f7e 100644 --- a/src/fc_kernel_api.cpp +++ b/src/fc_kernel_api.cpp @@ -17,7 +17,7 @@ #include "llm_fc.hpp" #include "fc_kernel_amx.hpp" #include "mm_kernel_common_amx.hpp" -#include "utility_avx512.hpp" +#include "utility_kernel_avx512.hpp" namespace llmdnn { diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index dd84591..668515c 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -7,11 +7,12 @@ #include "common/simple_parallel.hpp" #include "common/utility.hpp" -#include "utility_avx512.hpp" +#include "utility_kernel_avx512.hpp" #include "mm_kernel_common_amx.hpp" #include "softmax_kernel_avx512.hpp" #include "transpose_kernel_avx512.hpp" #include "llm_mha_gpt.hpp" +#include "mha_gpt_amx.hpp" using namespace ov::cpu; @@ -179,7 +180,7 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { auto pMatMul0Out = bufferMatMul0Out_local; // loop along K dimension size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; - for (size_t m = 0; m < seq_cout; m++) { + for (int m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); @@ -187,7 +188,7 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { // attn_scores = torch.where(causal_mask, attn_scores, mask_value) if (param.key_seq_len > valid_softmax_items) { auto *invalidPtr = dst + valid_softmax_items; - memset(invalidPtr, 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); + memset(static_cast(invalidPtr), 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); } } @@ -305,7 +306,7 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { auto pMatMul0Out = bufferMatMul0Out_local; // loop along K dimension size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; - for (size_t m = 0; m < seq_cout; m++) { + for (int m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); diff --git a/src/mm_kernel_amx.cpp b/src/mm_kernel_amx.cpp index 3c0a810..d4adf8f 100644 --- a/src/mm_kernel_amx.cpp +++ b/src/mm_kernel_amx.cpp @@ -16,7 +16,7 @@ #include "llm_mm.hpp" #include "mm_kernel_common_amx.hpp" -#include "utility_avx512.hpp" +#include "utility_kernel_avx512.hpp" #include "mm_kernel_amx.hpp" namespace llmdnn { diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index d07d2c7..c6a8389 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -6,7 +6,7 @@ #include "common/bf16.hpp" #include "common/tensor2d.hpp" -#include "utility_amx.hpp" +#include "utility_kernel_amx.hpp" #ifdef _WIN32 #include @@ -918,7 +918,6 @@ namespace functional { // mixed row int tails_nz = (src_rows & 3); if (tails_nz) { - __mmask32 kmask1 = _cvtu32_mask32(0xFFFFFFFF); auto a256 = _mm256_setzero_si256(); // must be zero auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 @@ -1082,7 +1081,6 @@ namespace functional { // mixed row int tails_nz = (src_rows & 3); if (tails_nz) { - __mmask32 kmask1 = _cvtu32_mask32(0xFFFFFFFF); auto a256 = _mm256_setzero_si256(); // must be zero auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 @@ -1487,7 +1485,6 @@ void loop2D(int M, int N, int mc, F f) { // but it works only when (M >= bM) template void loop2D_opt_Mtail(int M, int N, int mc, F f) { - int tailM = (M % (mc*bM)) % bM; assert(M > bM); for(int m0=0; m0 0) _tile_loadd(2, pB0, 64); pB0 += 1024*2; - if (tmmN > 1) _tile_loadd(3, pB0, 64); pB0 += 1024*2; - if (tmmN > 2) _tile_loadd(4, pB0, 64); pB0 += 1024*2; - if (tmmN > 3) _tile_loadd(5, pB0, 64); pB0 += 1024*2; - if (tmmN > 4) _tile_loadd(6, pB0, 64); pB0 += 1024*2; - if (tmmN > 5) _tile_loadd(7, pB0, 64); pB0 += 1024*2; + if (tmmN > 0) { + _tile_loadd(2, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 1) { + _tile_loadd(3, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 2) { + _tile_loadd(4, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 3) { + _tile_loadd(5, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 4) { + _tile_loadd(6, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 5) { + _tile_loadd(7, pB0, 64); pB0 += 1024*2; + } //asm("int3"); for(int m0 = 0; m0 < M; m0+=16) { int m = m0; @@ -2258,7 +2267,6 @@ struct Matmul { } // 4 tiles buffC is reused as decompressed bf16 weights - constexpr int prefetch_ahead = 16*1024; ov::bfloat16 * pBa = reinterpret_cast(&buffC(0,0)); ov::bfloat16 * pBb = pBa + (16*32)*2; auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { @@ -2357,18 +2365,13 @@ struct GemAvB { int M = matA.dims[0]; int K = matA.dims[1]; - constexpr int kStep = 32; - assert(K >= 32); - int Ktails = K % kStep; - int Kbody = K - Ktails; - int Kbackoff = (kStep - Ktails); if (K % 32) { if (K > Bpadded.dims[1]) Bpadded.resize(1, rndup(K, 32)); auto newB = &Bpadded(0, 0); - memset(newB, 0, Bpadded.stride); + memset(static_cast(newB), 0, Bpadded.stride); memcpy(newB, vecB, K * sizeof(ov::bfloat16)); vecB = newB; } diff --git a/src/softmax_kernel_avx512.hpp b/src/softmax_kernel_avx512.hpp index 85030d0..92e0e79 100644 --- a/src/softmax_kernel_avx512.hpp +++ b/src/softmax_kernel_avx512.hpp @@ -13,7 +13,7 @@ #endif #include "common/bf16.hpp" #include "llm_types.hpp" -#include "utility_avx512.hpp" +#include "utility_kernel_avx512.hpp" namespace llmdnn { inline void exp_ps_avx512(__m512 & src) { diff --git a/src/transpose_kernel_avx512.hpp b/src/transpose_kernel_avx512.hpp index 3bbc8c5..abcd37d 100644 --- a/src/transpose_kernel_avx512.hpp +++ b/src/transpose_kernel_avx512.hpp @@ -13,7 +13,7 @@ #endif #include "common/bf16.hpp" #include "llm_types.hpp" -#include "utility_avx512.hpp" +#include "utility_kernel_avx512.hpp" namespace llmdnn { template @@ -29,7 +29,7 @@ namespace llmdnn { __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); for (size_t j = 0; j < height; j++) { - int i; + size_t i; if (std::is_same::value) { for(i = 0; i < width - tail; i += 16) { auto x = _mm512_loadu_ps(src + i); diff --git a/src/utility_amx.hpp b/src/utility_kernel_amx.hpp similarity index 100% rename from src/utility_amx.hpp rename to src/utility_kernel_amx.hpp diff --git a/src/utility_avx512.hpp b/src/utility_kernel_avx512.hpp similarity index 98% rename from src/utility_avx512.hpp rename to src/utility_kernel_avx512.hpp index ca55c7b..e47f5c1 100644 --- a/src/utility_avx512.hpp +++ b/src/utility_kernel_avx512.hpp @@ -82,7 +82,7 @@ inline void quant_i8_avx512(void* dst, void* src, size_t ele_num, float scale) { // NOTE: did not handle tail because there should be enough room inline void cvt_i32_f32_avx512(float* dst, int32_t* src, size_t ele_num) { - for (int i = 0; i < (ele_num + 15) / 16 * 16; i += 16) { + for (size_t i = 0; i < (ele_num + 15) / 16 * 16; i += 16) { auto a0 = _mm512_load_epi32(src); auto a_f = _mm512_cvtepi32_ps(a0); _mm512_storeu_ps(dst, a_f); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09b7295..ceb0369 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -5,15 +5,18 @@ cmake_minimum_required(VERSION 3.13) project(llmdnn_tests) -include(FetchContent) - -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG release-1.11.0 - GIT_SHALLOW TRUE - GIT_PROGRESS TRUE) -FetchContent_MakeAvailable(googletest) +enable_testing() + +if (NOT TARGET gtest_main) + include(FetchContent) + FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.11.0 + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE) + FetchContent_MakeAvailable(googletest) +endif() set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) @@ -25,4 +28,3 @@ find_package(OpenMP REQUIRED) add_executable(llmdnn_tests ${TEST_SOURCE_FILES}) target_link_libraries(llmdnn_tests llmdnn gtest_main stdc++ OpenMP::OpenMP_CXX) -install(TARGETS llmdnn_tests DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) diff --git a/tests/script/build.sh b/tests/script/build.sh index b466a7c..ab29766 100755 --- a/tests/script/build.sh +++ b/tests/script/build.sh @@ -1,7 +1,8 @@ #!/bin/bash pip uninstall -y llmdnn -cd ../../../../../../../build/ && make -j 20 llmdnnlib -cd - -cd ext +cd ../../../../../../../build/ || exit +make -j 20 llmdnn +cd - || exit +cd ext || exit python setup.py clean --all install diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index b2ed223..16dd95b 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -7,7 +7,7 @@ #include "alloca.h" #include "module.hpp" #include "common/utility.hpp" -#include "utility_amx.hpp" +#include "utility_kernel_amx.hpp" #include "llm_mha_gpt.hpp" #include "test_common.hpp" diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp index 0feb8f8..a3ff62b 100644 --- a/tests/script/ext/module.cpp +++ b/tests/script/ext/module.cpp @@ -5,7 +5,7 @@ #include #include #include "module.hpp" -#include "utility_amx.hpp" +#include "utility_kernel_amx.hpp" #include "llm_mha_gpt.hpp" #include "test_common.hpp" diff --git a/tests/src/test_fc_kernel.cpp b/tests/src/test_fc_kernel_amx.cpp similarity index 99% rename from tests/src/test_fc_kernel.cpp rename to tests/src/test_fc_kernel_amx.cpp index a0a0532..308bc02 100644 --- a/tests/src/test_fc_kernel.cpp +++ b/tests/src/test_fc_kernel_amx.cpp @@ -50,7 +50,7 @@ class FCKernelTest : public TestWithParam { } protected: - virtual void SetUp() { + virtual void SetUp() override { initXTILE(); FCKernelTestShape shape; diff --git a/tests/src/test_mm_kernel.cpp b/tests/src/test_mm_kernel_amx.cpp similarity index 99% rename from tests/src/test_mm_kernel.cpp rename to tests/src/test_mm_kernel_amx.cpp index 9cd3507..61fdeb6 100644 --- a/tests/src/test_mm_kernel.cpp +++ b/tests/src/test_mm_kernel_amx.cpp @@ -46,7 +46,7 @@ class GemmKernelTest : public TestWithParam { } protected: - virtual void SetUp() { + virtual void SetUp() override { initXTILE(); MMKernelTestShape shape; diff --git a/tests/src/test_softmax_kernel_avx512.cpp b/tests/src/test_softmax_kernel_avx512.cpp index 6f2ffb8..bfb0b45 100644 --- a/tests/src/test_softmax_kernel_avx512.cpp +++ b/tests/src/test_softmax_kernel_avx512.cpp @@ -38,7 +38,7 @@ class SoftmaxTest : public TestWithParam { } protected: - virtual void SetUp() { + virtual void SetUp() override { std::tie(_types) = GetParam(); }; @@ -59,7 +59,7 @@ class SoftmaxTest : public TestWithParam { } out.resize(x.dims[0], x.dims[1], true); if (std::is_same::value) { - memcpy(out.data, y.data, x.dims[0] * x.dims[1] * sizeof(float)); + memcpy(static_cast(out.data), y.data, x.dims[0] * x.dims[1] * sizeof(float)); } if (std::is_same::value) { for(int i = 0; i < x.dims[1]; i++) { diff --git a/tests/src/test_transpose_kernel_avx512.cpp b/tests/src/test_transpose_kernel_avx512.cpp index 30e8ee7..b2b36e8 100644 --- a/tests/src/test_transpose_kernel_avx512.cpp +++ b/tests/src/test_transpose_kernel_avx512.cpp @@ -38,31 +38,31 @@ class TransposeTest : public TestWithParam { } protected: - virtual void SetUp() { + virtual void SetUp() override { std::tie(_types) = GetParam(); }; template static void gen_ref(T* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant) { - for (int j = 0; j < height; j++) { + for (size_t j = 0; j < height; j++) { if (std::is_same::value) { - memcpy(dst, src, width * sizeof(float)); + memcpy(static_cast(dst), src, width * sizeof(float)); } if (std::is_same::value) { - for(int i = 0; i < width; i++) { + for(size_t i = 0; i < width; i++) { dst[i] = src[i]; } } #define CLIP(x, low, high) \ (x < low ? low : (x > high ? high : x)) if (std::is_same::value) { - for(int i = 0; i < width; i++) { + for(size_t i = 0; i < width; i++) { auto tmp = src[i] * quant[i]; dst[i] = static_cast(CLIP(tmp, -128, 127)); } } if (std::is_same::value) { - for(int i = 0; i < width; i++) { + for(size_t i = 0; i < width; i++) { auto tmp = src[i] * quant[i]; dst[i] = static_cast(CLIP(tmp, 0, 255)); } @@ -76,7 +76,7 @@ class TransposeTest : public TestWithParam { template void test(float thresh) { // [num_heads, query_seq_len, head_size] => [query_seq_len, num_heads * head_size] - int num_heads = 2, query_seq_len = 10, head_size; + int num_heads = 2, query_seq_len = 10; for (int head_size = 1; head_size < 129; head_size++) { tensor2D src(num_heads, head_size * query_seq_len, true); tensor2D quant(1, head_size, true); diff --git a/tests/src/test_utility.cpp b/tests/src/test_utility_kernel_avx512.cpp similarity index 96% rename from tests/src/test_utility.cpp rename to tests/src/test_utility_kernel_avx512.cpp index c46a7ff..615a5b0 100644 --- a/tests/src/test_utility.cpp +++ b/tests/src/test_utility_kernel_avx512.cpp @@ -12,7 +12,7 @@ #include "llm_mm.hpp" #include "common/tensor2d.hpp" #include "common/tensor2d_helper.hpp" -#include "utility_avx512.hpp" +#include "utility_kernel_avx512.hpp" #include "test_common.hpp" using namespace std; diff --git a/tests/src/test_utility_repack1x2.cpp b/tests/src/test_utility_kernel_repack1x2_avx512.cpp similarity index 98% rename from tests/src/test_utility_repack1x2.cpp rename to tests/src/test_utility_kernel_repack1x2_avx512.cpp index 5b0a645..f140ba4 100644 --- a/tests/src/test_utility_repack1x2.cpp +++ b/tests/src/test_utility_kernel_repack1x2_avx512.cpp @@ -38,7 +38,7 @@ class RepackTest : public TestWithParam { } protected: - virtual void SetUp() { + virtual void SetUp() override { std::tie(_types) = GetParam(); }; @@ -49,7 +49,6 @@ class RepackTest : public TestWithParam { int kStep = 64 / sizeof(T); int K_padded = (K + kStep - 1) / kStep * kStep; int Ktails = K % kStep; - int Kbody = K - Ktails; // N_padded : round up to multiple of (2*16) int N_unit = 2 * 16; From 2a64a5ae871939c1ad41433b90adde2e3bf9bb89 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 08/54] add clang12 support & apply review comments. --- CMakeLists.txt | 7 ++++++- src/CMakeLists.txt | 5 +---- src/mm_kernel_common_amx.hpp | 39 +++++++++++++++++------------------ src/softmax_kernel_avx512.hpp | 4 ++-- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d1d46ba..b825c82 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,7 +29,12 @@ elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.3") message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 11.3.") endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids -flax-vector-conversions") +elseif(OV_COMPILER_IS_CLANG) + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "12") + message(FATAL_ERROR "Insufficient clang compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 12.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids -flax-vector-conversions") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "2023.0") message(FATAL_ERROR "Insufficient intel compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 2023.0.") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8bd7d53..3b970cf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,7 +9,4 @@ file(GLOB_RECURSE SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) add_library(llmdnn STATIC ${SOURCE_FILES}) target_include_directories(llmdnn PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/../include) -set_target_properties(llmdnn PROPERTIES - POSITION_INDEPENDENT_CODE ON - ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + PUBLIC $) diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index c6a8389..a8968e8 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -1435,12 +1435,11 @@ namespace PP { }; } -template -void prefetch_bytes(void *src) -{ - int8_t *p = reinterpret_cast(src); - for (int i = 0; i < bytes; i+=64) - _mm_prefetch(p + i + advance, sel); +// WA: clang could not find _mm_hint definition +#define prefetch_bytes(bytes, sel, advance, src) { \ + int8_t *p = reinterpret_cast(src); \ + for (int i = 0; i < bytes; i+=64) \ + _mm_prefetch(p + i + advance, sel); \ } // matmul (FC) @@ -2049,18 +2048,18 @@ struct Matmul { int8_t * pA0 = reinterpret_cast(&matA[0]); for(k=0; k(pB0); + // prefetch_bytes(1024, _MM_HINT_T1, 4096*48, pB0); _tile_loadd(3, pB0, 64); pB0 += 1024; // tile B0 32x16(16x16x2)/64x16(16x16x4) is always 1KB - // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + // prefetch_bytes(1024, _MM_HINT_T1, 4096*48, pB0); _tile_loadd(4, pB0, 64); pB0 += 1024; // tile B1 32x16(16x16x2)/64x16(16x16x4) is always 1KB TILE_DP(0, 2, 3); // C0 += A*B0 TILE_DP(1, 2, 4); // C1 += A*B1 } if (Ktails) { _tile_loadd(2, pA0 - KbackoffBytes, strideA); - // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + // prefetch_bytes(1024, _MM_HINT_T1, 4096*48, pB0); _tile_loadd(3, pB0, 64); pB0 += 1024; - // prefetch_bytes<1024, _MM_HINT_T1, 4096*48>(pB0); + // prefetch_bytes(1024, _MM_HINT_T1, 4096*48, pB0); _tile_loadd(4, pB0, 64); pB0 += 1024; TILE_DP(0, 2, 3); // C0 += A*B0 TILE_DP(1, 2, 4); // C1 += A*B1 @@ -2086,13 +2085,13 @@ struct Matmul { for (int k = 0; k < Kbody; k += kStep) { _tile_loadd(4, pA0, strideA); pA0 += 64; _tile_loadd(6, pB, 64); pB += 1024; - // prefetch_bytes<1024>(pB); + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); TILE_DP(0, 4, 6); _tile_loadd(5, pA1, strideA); pA1 += 64; TILE_DP(2, 5, 6); _tile_loadd(7, pB, 64); pB += 1024; - // prefetch_bytes<1024>(pB); + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); TILE_DP(1, 4, 7); TILE_DP(3, 5, 7); @@ -2100,13 +2099,13 @@ struct Matmul { if (Ktails) { _tile_loadd(4, pA0 - KbackoffBytes, strideA); _tile_loadd(6, pB, 64); pB += 1024; - // prefetch_bytes<1024>(pB); + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); TILE_DP(0, 4, 6); _tile_loadd(5, pA1 - KbackoffBytes, strideA); TILE_DP(2, 5, 6); _tile_loadd(7, pB, 64); pB += 1024; - // prefetch_bytes<1024>(pB); + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); TILE_DP(1, 4, 7); TILE_DP(3, 5, 7); @@ -2226,14 +2225,14 @@ struct Matmul { for(int k=0; k(pBint); + prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); functional::i8_to_bf16_Kx32<8>(pBint, pBdst); _tile_loadd(3, pBsrc, 64); functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 - prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); _tile_loadd(4, pBsrc + 16*32, 64); functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); @@ -2242,24 +2241,24 @@ struct Matmul { } if (Ktails) { _tile_loadd(2, pA0 - Kbackoff, strideA); // backoff to prevent access beyond the end of A - prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); functional::i8_to_bf16_Kx32<8>(pBint, pBdst); _tile_loadd(3, pBsrc, 64); functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 - prefetch_bytes<512, _MM_HINT_T1, prefetch_ahead>(pBint); + prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); _tile_loadd(4, pBsrc + 16*32, 64); functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 std::swap(pBsrc, pBdst); } - //prefetch_bytes<2048, _MM_HINT_T1, prefetch_ahead>(pBint); + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint); _tile_stored(0, pC0, buffC.stride); _tile_stored(1, pC0 + 16, buffC.stride); - //prefetch_bytes<2048, _MM_HINT_T1, prefetch_ahead>(pBint + 2048); + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint + 2048); //int valid_n = std::min(N - n, 32); (ppkernel)(buffC, 0, n + n0, M, valid_n); }); diff --git a/src/softmax_kernel_avx512.hpp b/src/softmax_kernel_avx512.hpp index 92e0e79..55237c1 100644 --- a/src/softmax_kernel_avx512.hpp +++ b/src/softmax_kernel_avx512.hpp @@ -162,7 +162,7 @@ namespace llmdnn { auto x = _mm512_loadu_ps(src + i); x = _mm512_mul_ps(x, reciprocal_sum_exp); auto out = _mm512_cvtne2ps_pbh(x, x); - _mm256_storeu_si256(reinterpret_cast<__m256i *>(dst + i), _mm512_extracti64x4_epi64(out, 0)); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); i += 16; } // handle tails @@ -170,7 +170,7 @@ namespace llmdnn { auto x = _mm512_maskz_loadu_ps(x_mask, src + i); x = _mm512_mul_ps(x, reciprocal_sum_exp); auto out = _mm512_cvtne2ps_pbh(x, x); - _mm256_mask_storeu_epi16(dst + i, x_mask, _mm512_extracti64x4_epi64(out, 0)); + _mm256_mask_storeu_epi16(dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); } } if (std::is_same::value) { From 4aa2441b45c82ff01c5f6a26312236aee5a92b25 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 09/54] mha int8 support --- CMakeLists.txt | 2 +- include/llm_mha_gpt.hpp | 12 +-- src/mha_gpt_amx.cpp | 40 +++++--- src/mha_gpt_api.cpp | 4 +- src/softmax_kernel_avx512.hpp | 24 +++-- tests/script/ext/mha_gpt.cpp | 64 +++++++++++- tests/script/test_mha_gpt.py | 121 +++++++++++++++++++---- tests/src/test_softmax_kernel_avx512.cpp | 10 +- 8 files changed, 223 insertions(+), 54 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b825c82..45b4f49 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ message(INFO "--------------------------------") message(STATUS "Build with tests: ${LLMDNN_BUILD_TESTS}") message(INFO "--------------------------------") -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) if(MSVC) # TODO: validate if(MSVC_VERSION VERSION_LESS 1928) diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp index 0822120..f30dc08 100644 --- a/include/llm_mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -47,6 +47,7 @@ class mha_gpt { size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv size_t max_seq_len; // max seq length for computing the size of matmul tmp result float normal_factor; + // supported (qkv, dst): (bf16, bf16), (s8, s8) data_type_t qkv_precision; data_type_t dst_precision; }; @@ -63,22 +64,21 @@ class mha_gpt { // attention_mask[0] shape: [1, max_seq_len] uint8_t* attn_output; // output, compact, shape: [batch, query_seq_len, num_heads * head_size] size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer + // expected quant schema: + // q,k,v use per tensor quant, attn_output may use per tensor/channel quant float q_dequant; float k_dequant; float v_dequant; float qk_quant; - std::vector qkv_quant; // per channel - // float* qk_normal_dq; // per channel, each item = normal_factor * q_dequant * k_dequant, used for softmax input - // float* qk_quant; // per channel, used for softmax output - // float* qkv_dq_q; // per channel, each item = 1 / qk_quant * v_dequant * qkv_quant, used for matmul2 output + std::vector qkv_quant; // size==1 per tensor, size==head_size per channel }; mha_gpt(); - void create(const create_param& param); + bool create(const create_param& param); void exec(const exec_param& param); struct impl { - virtual void create(const create_param& param) = 0; + virtual bool create(const create_param& param) = 0; virtual void exec(const exec_param& param) = 0; }; protected: diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 668515c..8431058 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -19,7 +19,7 @@ using namespace ov::cpu; namespace llmdnn { struct mha_gpt_impl_amx : public mha_gpt::impl { - void create(const mha_gpt::create_param& param) override; + bool create(const mha_gpt::create_param& param) override; void exec(const mha_gpt::exec_param& param) override; mha_gpt::create_param _create_param; @@ -32,6 +32,7 @@ struct mha_gpt_impl_amx : public mha_gpt::impl { std::shared_ptr bufferMatMul0Out; std::shared_ptr bufferMatMul1Out; + std::shared_ptr qkvQuantBuf; std::vector>> gemAvB_BF16xBF16; std::vector>> qKtrGemm_BF16xBF16; @@ -42,7 +43,15 @@ struct mha_gpt_impl_amx : public mha_gpt::impl { std::vector>> gemAvB_i8xi8; }; -void mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { +bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { + if (param.qkv_precision != dnnl_bf16 && param.dst_precision != dnnl_s8) { + std::cout << "input precision must be bf16 or int8.\n"; + return false; + } + if (param.dst_precision != dnnl_bf16 && param.dst_precision != dnnl_s8) { + std::cout << "dst precision must be bf16 or int8.\n"; + return false; + } _create_param = param; // q: [batch, num_heads, query_seq_len, head_size] @@ -65,6 +74,10 @@ void mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { for (size_t i = 0; i < numThreads; i++) { gemAvB_i8xi8[i] = std::make_shared>(); } + qkvQuantBuf = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, param.head_size * sizeof(float))), + [](void * p) { ::free(p); }); + memset(qkvQuantBuf.get(), 0, sizeof(param.head_size * sizeof(float))); } else { gemAvB_BF16xBF16.resize(numThreads); for (size_t i = 0; i < numThreads; i++) { @@ -86,9 +99,12 @@ void mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { bufferMatMul0Out = std::shared_ptr( reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul0OutSize)), [](void * p) { ::free(p); }); + memset(bufferMatMul0Out.get(), 0, numThreads * bufferMatMul0OutSize); bufferMatMul1Out = std::shared_ptr( reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul1OutSize)), [](void * p) { ::free(p); }); + memset(bufferMatMul1Out.get(), 0, numThreads * bufferMatMul1OutSize); + return true; } void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { @@ -129,7 +145,7 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, _create_param.normal_factor, pAddIn1_aux, param.key_seq_len); - softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr, nullptr, nullptr); + softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr); auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); @@ -184,7 +200,7 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); - softmax_avx512(dst, src, valid_softmax_items, nullptr, nullptr, nullptr); + softmax_avx512(dst, src, valid_softmax_items, nullptr); // attn_scores = torch.where(causal_mask, attn_scores, mask_value) if (param.key_seq_len > valid_softmax_items) { auto *invalidPtr = dst + valid_softmax_items; @@ -223,10 +239,12 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { // dequant param auto mul_scales = _create_param.normal_factor * param.q_dequant * param.k_dequant; // prepare for per channel - auto qkv_quant = param.qkv_quant; - std::vector qk_quant_vec(_create_param.head_size, param.qk_quant); + assert(param.qkv_quant.size() == 1 || param.qkv_quant.size() == _create_param.head_size); for (size_t i = 0; i < param.qkv_quant.size(); i++) { - qkv_quant[i] *= param.v_dequant / param.qk_quant; + (qkvQuantBuf.get())[i] = param.qkv_quant[i] * param.v_dequant / param.qk_quant; + } + if (param.qkv_quant.size() == 1) { + std::fill(qkvQuantBuf.get() + 1, qkvQuantBuf.get() + _create_param.head_size, *qkvQuantBuf.get()); } size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; @@ -255,7 +273,7 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, mul_scales, pAddIn1_aux, param.key_seq_len); - softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr, nullptr, qk_quant_vec.data()); + softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, param.qk_quant); auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); @@ -263,7 +281,7 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { amx_kernel::PP::BiasGeluStore pp(matQKV); (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, - _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkv_quant.data()); + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); }); } else { auto numThreads = getTotalThreads(); @@ -310,7 +328,7 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); - softmax_avx512(dst, src, valid_softmax_items, nullptr, nullptr, qk_quant_vec.data()); + softmax_avx512(dst, src, valid_softmax_items, param.qk_quant); // attn_scores = torch.where(causal_mask, attn_scores, mask_value) if (param.key_seq_len > valid_softmax_items) { auto *invalidPtr = dst + valid_softmax_items; @@ -329,7 +347,7 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { // matmul1: [batch, num_heads, query_seq_len, head_size] // attn_output: [batch, query_seq_len, num_heads * head_size] memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, - _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkv_quant.data()); + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); } }); diff --git a/src/mha_gpt_api.cpp b/src/mha_gpt_api.cpp index a979a5b..e702e8e 100644 --- a/src/mha_gpt_api.cpp +++ b/src/mha_gpt_api.cpp @@ -13,8 +13,8 @@ namespace llmdnn { mha_gpt::mha_gpt(): _impl(new_impl_amx()) { } -void mha_gpt::create(const create_param& param) { - _impl->create(param); +bool mha_gpt::create(const create_param& param) { + return _impl->create(param); } void mha_gpt::exec(const exec_param& param) { diff --git a/src/softmax_kernel_avx512.hpp b/src/softmax_kernel_avx512.hpp index 55237c1..81c52c5 100644 --- a/src/softmax_kernel_avx512.hpp +++ b/src/softmax_kernel_avx512.hpp @@ -84,8 +84,8 @@ namespace llmdnn { src = _mm512_mul_ps(src, two); } - template - void softmax_avx512(D* dst, float* src, int N, float* s_max=nullptr, float* s_sum=nullptr, float* quant=nullptr) { + template + void softmax_avx512(D* dst, float* src, int N, QD quant) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, "softmax_avx512 only support output data types ov::bfloat16/uint8_t/int8_t/float"); @@ -107,7 +107,6 @@ namespace llmdnn { x_max = _mm512_mask_max_ps(x_max, x_mask, x_max, x); } auto max = _mm512_reduce_max_ps(x_max); - if (s_max) *s_max = max; x_max = _mm512_set1_ps(max); // softmax @@ -131,7 +130,6 @@ namespace llmdnn { } auto sum = _mm512_reduce_add_ps(sum_exp); - if (s_sum) *s_sum = sum; sum_exp = _mm512_set1_ps(sum); auto reciprocal_sum_exp = _mm512_div_ps(one, sum_exp); // 1/sum_exp @@ -174,8 +172,12 @@ namespace llmdnn { } } if (std::is_same::value) { + __m512 q; + if constexpr (std::is_same::value) + q = _mm512_set1_ps(quant); for(i = 0; i < N - tail; i += 16) { - auto q = _mm512_loadu_ps(quant + i); + if constexpr (std::is_same::value) + q = _mm512_loadu_ps(quant + i); auto x = _mm512_loadu_ps(src + i); x = _mm512_mul_ps(x, reciprocal_sum_exp); x = _mm512_mul_ps(x, q); @@ -185,7 +187,8 @@ namespace llmdnn { // handle tails if (tail) { auto x = _mm512_maskz_loadu_ps(x_mask, src + i); - auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + if constexpr (std::is_same::value) + q = _mm512_maskz_loadu_ps(x_mask, quant + i); x = _mm512_mul_ps(x, reciprocal_sum_exp); x = _mm512_mul_ps(x, q); auto x_i = _mm512_cvtps_epi32(x); @@ -194,8 +197,12 @@ namespace llmdnn { } if (std::is_same::value) { auto zero = _mm512_setzero_epi32(); + __m512 q; + if constexpr (std::is_same::value) + q = _mm512_set1_ps(quant); for(i = 0; i < N - tail; i += 16) { - auto q = _mm512_loadu_ps(quant + i); + if constexpr (std::is_same::value) + q = _mm512_loadu_ps(quant + i); auto x = _mm512_loadu_ps(src + i); x = _mm512_mul_ps(x, reciprocal_sum_exp); x = _mm512_mul_ps(x, q); @@ -206,7 +213,8 @@ namespace llmdnn { // handle tails if (tail) { auto x = _mm512_maskz_loadu_ps(x_mask, src + i); - auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + if constexpr (std::is_same::value) + q = _mm512_maskz_loadu_ps(x_mask, quant + i); x = _mm512_mul_ps(x, reciprocal_sum_exp); x = _mm512_mul_ps(x, q); auto x_i = _mm512_cvtps_epi32(x); diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index 16dd95b..c5b397a 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -34,7 +34,8 @@ void regclass_mha_gpt(pybind11::module m) { throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); if (param.dst_precision == llmdnn::dnnl_data_type_undef) throw pybind11::type_error("Incorrect dst type " + dst_precision_name); - self.create(param); + if (!self.create(param)) + throw pybind11::type_error("Incorrect param"); }, py::arg("num_heads"), py::arg("head_size"), @@ -49,7 +50,7 @@ void regclass_mha_gpt(pybind11::module m) { :param num_heads: heads number. :type num_heads: int )"); - cls.def("exec", [] (llmdnn::mha_gpt& self, torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor attn_mask) { + cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask) { // q: [batch, num_heads, query_seq_len, head_size] // k: [batch, num_heads, key_seq_len, head_size] // v: [batch, num_heads, value_seq_len, head_size] @@ -94,6 +95,65 @@ void regclass_mha_gpt(pybind11::module m) { R"( exec mha + :param num_heads: heads number. + :type num_heads: int + )"); + cls.def("exec_quant", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask, + float q_dequant, float k_dequant, float v_dequant, float qk_quant, const std::vector& qkv_quant) { + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, key_seq_len, head_size] + // v: [batch, num_heads, value_seq_len, head_size] + // attn_mask: [batch, MAX_POSITION_EMBEDDINGS] + // out: [batch, query_seq_len, num_heads * head_size] + AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 2); + auto batch = q.size(0); + auto num_heads = q.size(1); + auto query_seq_len = q.size(2); + auto head_size = q.size(3); + auto key_seq_len = k.size(2); + auto attn_len = attn_mask.size(1); + AT_ASSERT(key_seq_len == v.size(2) && key_seq_len == attn_len && + batch == k.size(0) && batch == v.size(0) && 1 == attn_mask.size(0) && + num_heads == k.size(1) && num_heads == v.size(1) && + head_size == k.size(3) && head_size == v.size(3)); + + auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}, torch::TensorOptions(torch::kInt8)); + llmdnn::mha_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.key_seq_len = key_seq_len; + param.q = reinterpret_cast(q.data_ptr()); + param.attn_output = reinterpret_cast(out.data_ptr()); + param.head_stride_in_kv = key_seq_len * head_size; + param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.attention_mask = reinterpret_cast(alloca(batch * sizeof(float*))); + param.q_dequant = q_dequant; + param.k_dequant = k_dequant; + param.v_dequant = v_dequant; + param.qk_quant = qk_quant; + param.qkv_quant = qkv_quant; + for (int i = 0; i < batch; i++) { + param.k[i] = reinterpret_cast(k.data_ptr()) + i * num_heads * key_seq_len * head_size; + param.v[i] = reinterpret_cast(v.data_ptr()) + i * num_heads * key_seq_len * head_size; + param.attention_mask[i] = attn_mask.data_ptr(); + } + + self.exec(param); + return out; + }, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("attn_mask"), + py::arg("q_dequant"), + py::arg("k_dequant"), + py::arg("v_dequant"), + py::arg("qk_quant"), + py::arg("qkv_quant"), + R"( + exec mha quant + :param num_heads: heads number. :type num_heads: int )"); diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py index d08d05a..184d753 100644 --- a/tests/script/test_mha_gpt.py +++ b/tests/script/test_mha_gpt.py @@ -25,9 +25,17 @@ def __init__(self, config): ) self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) - def forward(self, query, key, value, attention_mask=None): + def forward(self, query, key, value, attention_mask, q_quant=None, k_quant=None, qk_quant=None, v_quant=None, requant=None): + if q_quant: + # quant + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + # Compute attention - attn_output, attn_weights = self._attn(query, key, value, attention_mask) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant) + if q_quant: + attn_output = (attn_output * requant).round().clamp(-128, 127).to(torch.int8) # Reshape outputs attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) @@ -46,7 +54,7 @@ def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): # -> [bs, seq_len, hidden_size] return tensor - def _attn(self, query, key, value, attention_mask=None, head_mask=None): + def _attn(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant): # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] # compute causal mask from causal mask buffer batch_size, num_attention_heads, query_length, attn_head_size = query.size() @@ -63,12 +71,16 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): dtype=query.dtype, device=key.device, ) + if q_quant: + norm_factor = torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor * (1 / q_quant) * (1 / k_quant) + else: + norm_factor = torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor attn_scores = torch.baddbmm( attn_scores, query, key.transpose(1, 2), beta=1.0, - alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor), + alpha=(norm_factor), ) attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) @@ -85,15 +97,15 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = nn.functional.softmax(attn_scores, dim=-1) attn_weights = attn_weights.to(value.dtype) - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - + if q_quant: + attn_weights = (attn_weights * qk_quant).round().clamp(0, 255).to(torch.uint8).to(torch.float32) attn_output = torch.matmul(attn_weights, value) + if q_quant: + attn_output = attn_output * ((1 / qk_quant) * (1 / v_quant)) return attn_output, attn_weights class GPTNeoXAttentionExt: - def __init__(self, num_attention_heads, hidden_size, max_position_embeddings): + def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, is_int8=False): self.mha = ld.mha_gpt() num_heads = num_attention_heads head_size = hidden_size // num_attention_heads @@ -101,18 +113,33 @@ def __init__(self, num_attention_heads, hidden_size, max_position_embeddings): head_size_aligned = head_size normal_factor = 1.0 / math.sqrt(head_size) - qkv_precision_name = 'bf16' - dst_precision_name = 'bf16' + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, dst_precision_name, max_seq_len) - def forward(self, query, key, value, attention_mask=None): + def forward(self, query, key, value, attention_mask): return self.mha.exec(query, key, value, attention_mask) + def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): + # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant + return self.mha.exec_quant(query, key, value, attention_mask, 1.0 / q_quant, 1.0 / k_quant, 1.0 / v_quant, qk_quant, requant) + HEAD_NUM = 32 SIZE_PER_HEAD = 80 HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD MAX_POSITION_EMBEDDINGS = 1024 #2048 +def get_ref_model(): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + def test_gpt_neox(): inputs = [ # q, k, v, attn_mask @@ -133,14 +160,7 @@ def test_gpt_neox(): np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), np.zeros([1, 200], dtype=np.float32)), ] - class FakeConfig: - def __init__(self): - self.num_attention_heads = HEAD_NUM - self.hidden_size = HIDDEN_SIZE - self.max_position_embeddings = MAX_POSITION_EMBEDDINGS - config = FakeConfig() - ref_net = GPTNeoXAttention(config) - ref_net = ref_net.to(dtype=torch.bfloat16) + ref_net = get_ref_model() net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS) with torch.cpu.amp.autocast(): for (i, input) in enumerate(inputs): @@ -158,5 +178,64 @@ def __init__(self): print('done.') return +def test_gpt_neox_int8(): + low = -4 + high = 4 + range_ = high - low + q_quant = 127.0 / high + qs = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + low = -2 + high = 2 + range_ = high - low + k_quant = 127.0 / high + ks = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + low = -8 + high = 8 + range_ = high - low + v_quant = 127.0 / high + vs = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [1, MAX_POSITION_EMBEDDINGS] + inputs = [] + for i in range(len(qs)): + inputs.append((qs[i], ks[i], vs[i], np.zeros([1, ks[i].shape[-2]], dtype=np.float32))) + + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, True) + qk_quant, requant = 255.0, 10.0 + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q) + k = torch.from_numpy(k) + v = torch.from_numpy(v) + attn_mask = torch.from_numpy(attn_mask) + q = (q * q_quant).round().clamp(-128, 127).to(torch.int8) + k = (k * k_quant).round().clamp(-128, 127).to(torch.int8) + v = (v * v_quant).round().clamp(-128, 127).to(torch.int8) + ref_output = ref_net.forward(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, requant) + output = net.forward_quant(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, [requant,]) + if (torch.abs(ref_output- output) > 2).any(): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + if __name__ == "__main__": - test_gpt_neox() \ No newline at end of file + test_gpt_neox_int8() \ No newline at end of file diff --git a/tests/src/test_softmax_kernel_avx512.cpp b/tests/src/test_softmax_kernel_avx512.cpp index bfb0b45..a78fa1f 100644 --- a/tests/src/test_softmax_kernel_avx512.cpp +++ b/tests/src/test_softmax_kernel_avx512.cpp @@ -86,15 +86,16 @@ class SoftmaxTest : public TestWithParam { template void test(float thresh) { for (int n = 1; n < 129; n++) { - tensor2D A(1, n, true); + tensor2D A(1, n, true), A_scalar(1, n, true); tensor2D quant(1, n, true); - tensor2D out(1, n, true), out_ref; + tensor2D out(1, n, true), out_ref, out_scalar(1, n, true); for (int i = 0; i < n; i++) { A[i] = static_cast(i) - n / 2; + A_scalar[i] = A[i]; } quant = 128.f; gen_ref(A, out_ref, quant); - llmdnn::softmax_avx512(out.data, A.data, n, nullptr, nullptr, quant.data); + llmdnn::softmax_avx512(out.data, A.data, n, quant.data); for (int i = 0; i < n; i++) { float a = out[i]; float b = out_ref[i]; @@ -102,6 +103,9 @@ class SoftmaxTest : public TestWithParam { FAIL() << " N: " << n << " pos: " << i << " opt: " << a << " ref: " << b; } } + // input is scalar + llmdnn::softmax_avx512(out_scalar.data, A_scalar.data, n, quant.data[0]); + ASSERT_TRUE(out == out_scalar); } } From 7ed062a09f1e88d8243f65132b1c6c0d17295dd5 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 10/54] rotary bf16 gpt initial support --- include/llm_emb_gpt.hpp | 49 +++++ src/emb_gpt_api.cpp | 24 +++ src/emb_gpt_avx512.cpp | 183 ++++++++++++++++++ src/emb_gpt_avx512.hpp | 17 ++ src/mha_gpt_amx.cpp | 2 +- src/rotary_kernel_avx512.hpp | 135 +++++++++++++ tests/script/ext/emb_gpt.cpp | 108 +++++++++++ tests/script/ext/module.cpp | 1 + tests/script/ext/module.hpp | 3 +- tests/script/ext/setup.py | 3 +- tests/script/test_rotary_pastkv.py | 213 +++++++++++++++++++++ tests/src/test_rotary_kernel_avx512.cpp | 109 +++++++++++ tests/src/test_transpose_kernel_avx512.cpp | 2 +- 13 files changed, 845 insertions(+), 4 deletions(-) create mode 100644 include/llm_emb_gpt.hpp create mode 100644 src/emb_gpt_api.cpp create mode 100644 src/emb_gpt_avx512.cpp create mode 100644 src/emb_gpt_avx512.hpp create mode 100644 src/rotary_kernel_avx512.hpp create mode 100644 tests/script/ext/emb_gpt.cpp create mode 100644 tests/script/test_rotary_pastkv.py create mode 100644 tests/src/test_rotary_kernel_avx512.cpp diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp new file mode 100644 index 0000000..a23f3d5 --- /dev/null +++ b/include/llm_emb_gpt.hpp @@ -0,0 +1,49 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" + +namespace llmdnn { + +class emb_gpt { +public: + struct create_param { + size_t num_heads; + size_t head_size; + size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv + size_t max_seq_len; // max seq length for computing the size of matmul tmp result + // supported (qkv, dst): (bf16, bf16) + data_type_t qkv_precision; + data_type_t dst_precision; + size_t rotary_emb_base; + float rotary_pct; + }; + struct exec_param { + size_t batch; + size_t query_seq_len; + size_t past_seq_len; + uint8_t* qkv; + uint8_t* query_dst; + uint8_t** layer_past_key_padded; + uint8_t** layer_past_value_padded; + }; + + emb_gpt(); + bool create(const create_param& param); + void exec(const exec_param& param); + + struct impl { + virtual bool create(const create_param& param) = 0; + virtual void exec(const exec_param& param) = 0; + }; +protected: + std::shared_ptr _impl; +}; + +} diff --git a/src/emb_gpt_api.cpp b/src/emb_gpt_api.cpp new file mode 100644 index 0000000..b8a7b84 --- /dev/null +++ b/src/emb_gpt_api.cpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "emb_gpt_avx512.hpp" + +namespace llmdnn { + +// interface +emb_gpt::emb_gpt(): _impl(new_impl_avx512()) { +} + +bool emb_gpt::create(const create_param& param) { + return _impl->create(param); +} + +void emb_gpt::exec(const exec_param& param) { + _impl->exec(param); +} + +} \ No newline at end of file diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp new file mode 100644 index 0000000..b38c484 --- /dev/null +++ b/src/emb_gpt_avx512.cpp @@ -0,0 +1,183 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include "common/simple_parallel.hpp" +#include "common/utility.hpp" +#include "utility_kernel_avx512.hpp" +#include "transpose_kernel_avx512.hpp" +#include "llm_emb_gpt.hpp" +#include "emb_gpt_avx512.hpp" +#include "rotary_kernel_avx512.hpp" + +using namespace ov::cpu; + +namespace llmdnn { + +struct emb_gpt_impl_avx512 : public emb_gpt::impl { + bool create(const emb_gpt::create_param& param) override; + void exec(const emb_gpt::exec_param& param) override; + + void initRotery(size_t max_seq_len); + void applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + size_t batch, size_t q_seq_len, size_t past_seq_len); + + emb_gpt::create_param _create_param; + size_t _head_num = 32; + size_t _size_per_head = 80; + size_t _hidden_size = 32 * 80; + size_t _rotary_emb_base = 10000; + float _rotary_pct = 0.25; + size_t _max_seq_len = 400; + // aligned to cache line + size_t _size_per_head_aligned = 80; + int _rotary_ndims = 0; + std::shared_ptr _cos_cached; + std::shared_ptr _sin_cached; + int64_t _input_type_size = 1; + int64_t _output_type_size = 1; +}; + +bool emb_gpt_impl_avx512::create(const emb_gpt::create_param& param) { + if (param.qkv_precision != dnnl_bf16) { + std::cout << "input precision must be bf16 or int8.\n"; + return false; + } + // TODO: support s8 + // if (param.dst_precision != dnnl_bf16 && param.dst_precision != dnnl_s8) { + // std::cout << "dst precision must be bf16 or int8.\n"; + // return false; + // } + _create_param = param; + + _head_num = param.num_heads; + _size_per_head = param.head_size; + _size_per_head_aligned = param.head_size_aligned; + _hidden_size = param.head_size * param.num_heads; + _rotary_emb_base = param.rotary_emb_base; + _rotary_pct = param.rotary_pct; + _max_seq_len = param.max_seq_len; + _input_type_size = sizeof(ov::bfloat16); + _output_type_size = sizeof(ov::bfloat16); + if (param.dst_precision == dnnl_s8) + _output_type_size = sizeof(int8_t); + + _rotary_ndims = static_cast(_size_per_head * _rotary_pct); + initRotery(_max_seq_len); + + return true; +} + +void emb_gpt_impl_avx512::initRotery(size_t max_seq_len) { + std::vector inv_freq; + for (int i = 0; i < _rotary_ndims; i += 2) { + inv_freq.push_back(1.0f / (powf(_rotary_emb_base, static_cast(i) / _rotary_ndims))); + } + std::vector t; + for (size_t i = 0; i < max_seq_len * 2; i++) { + t.push_back(static_cast(i)); + } + auto width = _rotary_ndims / 2 * 2; + auto height = max_seq_len * 2; + auto capacity = height * width * sizeof(float); + _cos_cached = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, capacity)), + [](void * p) { ::free(p); }); + _sin_cached = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, capacity)), + [](void * p) { ::free(p); }); + + auto* cos_p = _cos_cached.get(); + auto* sin_p = _sin_cached.get(); + for (size_t i = 0; i < height; i++) { + for (int j = 0; j < width / 2; j++) { + cos_p[i * width + j] = cosf(t[i] * inv_freq[j]); + cos_p[i * width + j + width / 2] = cosf(t[i] * inv_freq[j]); + sin_p[i * width + j] = sinf(t[i] * inv_freq[j]); + sin_p[i * width + j + width / 2] = sinf(t[i] * inv_freq[j]); + } + } +} + +// q_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] +// q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] +// kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] +// kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] +void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + size_t batch, size_t q_seq_len, size_t past_seq_len) { + auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; + auto* cos_cached = _cos_cached.get() + past_seq_len * _rotary_ndims; + auto* sin_cached = _sin_cached.get() + past_seq_len * _rotary_ndims; + parallel_for3d(batch, _head_num, q_seq_len, [&](size_t b, size_t h, size_t s) { + // q, k rotary encoding + auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; + auto k_dst_batch = k_dst[b] + key_offset; + auto v_dst_batch = v_dst[b] + key_offset; + auto q_src_batch = q_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; + auto k_src_batch = k_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; + auto v_src_batch = v_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; + auto q_dst_seq = q_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto q_src_seq = q_src_batch + s * _hidden_size * 3 * _input_type_size; + auto k_src_seq = k_src_batch + s * _hidden_size * 3 * _input_type_size; + auto v_src_seq = v_src_batch + s * _hidden_size * 3 * _input_type_size; + auto* q_src_f = reinterpret_cast(q_src_seq + h * _size_per_head * 3 * _input_type_size); + auto* k_src_f = reinterpret_cast(k_src_seq + h * _size_per_head * 3 * _input_type_size); + auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); + auto* k_dst_f = reinterpret_cast(k_dst_seq + h * _max_seq_len * _size_per_head_aligned * _output_type_size); + rotary_avx512(_rotary_ndims, cos_cached + s * _rotary_ndims, sin_cached + s * _rotary_ndims, q_src_f, k_src_f, q_dst_f, k_dst_f); + + // q, k concat + memcpy(reinterpret_cast(q_dst_f) + _rotary_ndims * _output_type_size, reinterpret_cast(q_src_f) + _rotary_ndims * _input_type_size, _output_type_size * (_size_per_head - _rotary_ndims)); + memcpy(reinterpret_cast(k_dst_f) + _rotary_ndims * _output_type_size, reinterpret_cast(k_src_f) + _rotary_ndims * _input_type_size, _output_type_size * (_size_per_head - _rotary_ndims)); + // v concat + memcpy(static_cast(v_dst_seq) + h * _max_seq_len * _size_per_head_aligned * _output_type_size, + static_cast(v_src_seq) + h * _size_per_head * 3 * _input_type_size, + _size_per_head * _output_type_size); + }); +} + +void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { + // [batch, seq_len, (num_heads * 3 * head_size)] + // --> [batch, seq_len, num_heads, 3 * head_size] + auto* qkv = param.qkv; + auto query = qkv; // qkv[..., : self.head_size].permute(0, 2, 1, 3) + auto key = qkv + _size_per_head * _input_type_size; // qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + auto value = qkv + 2 * _size_per_head * _input_type_size; // qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + auto query_dst = param.query_dst; + auto key_dst = param.layer_past_key_padded; + auto value_dst = param.layer_past_value_padded; + auto batch = param.batch; + auto query_seq_len = param.query_seq_len; + auto past_seq_len = param.past_seq_len; + // transpose + rotary embbeding: + // transpose: [batch, seq_len, num_attention_heads, 3 * head_size] --> + // 3 [batch, num_attention_heads, seq_len, head_size] + // rotary embbeding: part of key will write to past_key, part of query will write to tempory buffer + if (_create_param.dst_precision == dnnl_s8) { + // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) + // key pass part(past_key): key = torch.cat((key, key_pass), dim=-1) + // value(pastKeys): value = torch.cat((past_value, value), dim=-2) + // applyRotaryPosEmbMemcpyQuant(query, key, queryTranspose.get(), current_k_bufs, _output_type_size * new_seq_offset * _size_per_head_aligned, + // _cos_cached.get(), _sin_cached.get(), batch, seq_len, new_seq_offset, value, current_v_bufs); + assert(false); + } else { + // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) + // key pass part(past_key): key = torch.cat((key, key_pass), dim=-1) + // value(pastKeys): value = torch.cat((past_value, value), dim=-2) + // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] + // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] + applyRotaryPosEmbMemcpy(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len); + } +} + +std::shared_ptr new_impl_avx512() { + return std::make_shared(); +} + +} \ No newline at end of file diff --git a/src/emb_gpt_avx512.hpp b/src/emb_gpt_avx512.hpp new file mode 100644 index 0000000..62e3bea --- /dev/null +++ b/src/emb_gpt_avx512.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_emb_gpt.hpp" + +namespace llmdnn { + +std::shared_ptr new_impl_avx512(); + +} diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 8431058..46601e4 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -44,7 +44,7 @@ struct mha_gpt_impl_amx : public mha_gpt::impl { }; bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { - if (param.qkv_precision != dnnl_bf16 && param.dst_precision != dnnl_s8) { + if (param.qkv_precision != dnnl_bf16 && param.qkv_precision != dnnl_s8) { std::cout << "input precision must be bf16 or int8.\n"; return false; } diff --git a/src/rotary_kernel_avx512.hpp b/src/rotary_kernel_avx512.hpp new file mode 100644 index 0000000..9f1a6ce --- /dev/null +++ b/src/rotary_kernel_avx512.hpp @@ -0,0 +1,135 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + template + void rotary_avx512(size_t N, float* cos, float* sin, T* q_src, T* k_src, T* q_dst, T* k_dst) { + static_assert(std::is_same::value, + "rotary_avx512 only support output data types ov::bfloat16/int8_t"); + auto half = N / 2; + // for (size_t i = 0; i < half; i++) { + // q_dst[i] = q_src[i] * cos[i] - q_src[i + half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] - k_src[i + half] * sin[i]; + // } + // for (size_t i = half; i < N; i++) { + // q_dst[i] = q_src[i] * cos[i] + q_src[i - half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] + k_src[i - half] * sin[i]; + // } + size_t tail = half % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + size_t i; + for (i = 0; i < half - tail; i += 16) { + auto q = _mm256_loadu_epi16(q_src + i + half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_loadu_epi16(k_src + i + half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_loadu_ps(cos + i); + auto sin_f = _mm512_loadu_ps(sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_loadu_epi16(q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_loadu_epi16(k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmsub_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(q_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(k_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + if (tail) { + auto q = _mm256_maskz_loadu_epi16(x_mask, q_src + i + half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_maskz_loadu_epi16(x_mask, k_src + i + half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_maskz_loadu_ps(x_mask, cos + i); + auto sin_f = _mm512_maskz_loadu_ps(x_mask, sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_maskz_loadu_epi16(x_mask, q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_maskz_loadu_epi16(x_mask, k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmsub_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_mask_storeu_epi16(q_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_mask_storeu_epi16(k_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + // second half + q_src += half; + k_src += half; + cos += half; + sin += half; + q_dst += half; + k_dst += half; + for (i = 0; i < half - tail; i += 16) { + auto q = _mm256_loadu_epi16(q_src + i - half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_loadu_epi16(k_src + i - half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_loadu_ps(cos + i); + auto sin_f = _mm512_loadu_ps(sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_loadu_epi16(q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_loadu_epi16(k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmadd_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(q_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(k_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + if (tail) { + auto q = _mm256_maskz_loadu_epi16(x_mask, q_src + i - half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_maskz_loadu_epi16(x_mask, k_src + i - half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_maskz_loadu_ps(x_mask, cos + i); + auto sin_f = _mm512_maskz_loadu_ps(x_mask, sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_maskz_loadu_epi16(x_mask, q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_maskz_loadu_epi16(x_mask, k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmadd_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_mask_storeu_epi16(q_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_mask_storeu_epi16(k_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + } +} \ No newline at end of file diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp new file mode 100644 index 0000000..7818751 --- /dev/null +++ b/tests/script/ext/emb_gpt.cpp @@ -0,0 +1,108 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "alloca.h" +#include "module.hpp" +#include "common/utility.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_emb_gpt.hpp" +#include "test_common.hpp" + +using namespace torch::indexing; + +void regclass_emb_gpt(pybind11::module m) { + py::class_> cls(m, "emb_gpt"); + cls.def(py::init<>()); + cls.def("create", [] (llmdnn::emb_gpt& self, + const size_t num_heads, + const size_t head_size, + const size_t head_size_aligned, + const std::string qkv_precision_name, + const std::string dst_precision_name, + const size_t max_seq_len, + const size_t rotary_emb_base, + float rotary_pct) { + llmdnn::emb_gpt::create_param param; + param.num_heads = num_heads; + param.head_size = head_size; + param.head_size_aligned = head_size_aligned; + param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); + param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); + param.max_seq_len = max_seq_len; + param.rotary_emb_base = rotary_emb_base; + param.rotary_pct = rotary_pct; + if (param.qkv_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); + if (param.dst_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect dst type " + dst_precision_name); + if (!self.create(param)) + throw pybind11::type_error("Incorrect param"); + }, + py::arg("num_heads"), + py::arg("head_size"), + py::arg("head_size_aligned"), + py::arg("qkv_precision_name"), + py::arg("dst_precision_name"), + py::arg("max_seq_len"), + py::arg("rotary_emb_base"), + py::arg("rotary_pct"), + R"( + Create emb + + :param num_heads: heads number. + :type num_heads: int + )"); + // torch::List + cls.def("exec", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_padded, + const torch::Tensor& layer_past_value_padded, const torch::Tensor& query_padded, size_t past_seq_len) { + // qkv: [batch, seq_len, (num_heads * 3 * head_size)] + // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + // past_seq_len: past_seq_len==layer_past.shape[-2] + // query_padded: [batch, num_heads, query_seq_len, head_size_aligned] + // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] + AT_ASSERT(qkv.dim() == 3 && layer_past_key_padded.dim() == 4 && layer_past_value_padded.dim() == 4 && + qkv.size(0) == layer_past_key_padded.size(0) && + layer_past_key_padded.dim() == layer_past_value_padded.dim()); + AT_ASSERT(query_padded.dim() == 4 && query_padded.size(0) == qkv.size(0) && + query_padded.size(1) == layer_past_key_padded.size(1) && query_padded.size(2) == qkv.size(1) && + query_padded.size(3) == layer_past_key_padded.size(3)); + auto batch = qkv.size(0); + auto num_heads = layer_past_key_padded.size(1); + auto query_seq_len = qkv.size(1); + auto head_size = qkv.size(2) / 3 / num_heads; + AT_ASSERT(past_seq_len <= layer_past_key_padded.size(2) && head_size <= layer_past_key_padded.size(3) && + query_seq_len <= layer_past_key_padded.size(2)); + + llmdnn::emb_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.past_seq_len = past_seq_len; + param.qkv = reinterpret_cast(qkv.data_ptr()); + param.query_dst = reinterpret_cast(query_padded.data_ptr()); + param.layer_past_key_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + for (int i = 0; i < batch; i++) { + param.layer_past_key_padded[i] = reinterpret_cast(layer_past_key_padded[i].data_ptr()); + param.layer_past_value_padded[i] = reinterpret_cast(layer_past_value_padded[i].data_ptr()); + } + + self.exec(param); + + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + }, + py::arg("qkv"), + py::arg("layer_past_key_padded"), + py::arg("layer_past_value_padded"), + py::arg("query_padded"), + py::arg("past_seq_len"), + R"( + exec emb + + :param num_heads: heads number. + :type num_heads: int + )"); +} \ No newline at end of file diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp index a3ff62b..d016470 100644 --- a/tests/script/ext/module.cpp +++ b/tests/script/ext/module.cpp @@ -15,4 +15,5 @@ PYBIND11_MODULE(llmdnn, m) { std::cout << "init amx failed.\n"; } regclass_mha_gpt(m); + regclass_emb_gpt(m); } \ No newline at end of file diff --git a/tests/script/ext/module.hpp b/tests/script/ext/module.hpp index 8f8f720..d9f8850 100644 --- a/tests/script/ext/module.hpp +++ b/tests/script/ext/module.hpp @@ -6,4 +6,5 @@ #include -void regclass_mha_gpt(pybind11::module m); \ No newline at end of file +void regclass_mha_gpt(pybind11::module m); +void regclass_emb_gpt(pybind11::module m); \ No newline at end of file diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py index 43fe15c..ebca40f 100644 --- a/tests/script/ext/setup.py +++ b/tests/script/ext/setup.py @@ -30,7 +30,8 @@ ext_modules=[ cpp_extension.CppExtension( 'llmdnn', - ['module.cpp', 'mha_gpt.cpp', f'../../src/test_common.cpp'], + ['module.cpp', f'../../src/test_common.cpp', + 'mha_gpt.cpp', 'emb_gpt.cpp'], extra_compile_args=extra_args, include_dirs=[f'{os.getcwd()}/../../src', f'{os.getcwd()}/../../../include', diff --git a/tests/script/test_rotary_pastkv.py b/tests/script/test_rotary_pastkv.py new file mode 100644 index 0000000..f8ee6e2 --- /dev/null +++ b/tests/script/test_rotary_pastkv.py @@ -0,0 +1,213 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn + +# copy from transformers/models/gpt_neox/modeling_gpt_neox.py +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos = cos[..., offset : q.shape[-2] + offset, :] + sin = sin[..., offset : q.shape[-2] + offset, :] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.rotary_ndims = int(self.head_size * config.rotary_pct) + self.rotary_emb = RotaryEmbedding( + self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base + ) + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + # return: (key, value), key/value: [batch, num_attention_heads, seq_len+layer_past[0].shape[-2], head_size] + def forward(self, qkv, layer_past): + has_layer_past = layer_past is not None + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + offset = 0 + if has_layer_past: + offset = layer_past[0].shape[-2] + seq_len += offset + cos, sin = self.rotary_emb(value, seq_len=seq_len) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset) + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) + + return present, query, key, value + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): + self.emd = ld.emb_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + self.emd.create(num_heads, head_size, head_size_aligned, qkv_precision_name, + dst_precision_name, max_seq_len, rotary_emb_base, rotary_pct) + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # return: + # 0: (k, v): ([batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned], [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned]) + # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] + # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len): + self.emd.exec(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len) + + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 96 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.25 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.rotary_pct = ROTARY_PCT + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + self.rotary_emb_base = ROTARY_EMB_BASE + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_gpt_neox(): + inputs = [ + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32)), + (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + past_seq_len = layer_past_key.shape[-2] + shape = list(layer_past_key.shape) + shape[-2] = MAX_SEQ_LEN + shape[-1] = SIZE_PER_HEAD_ALIGN + layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) + layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) + query_shape = list(shape) + query_shape[2] = qkv.shape[1] + query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key + layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value + ref_output, query_ref, key_ref, value_ref = ref_net.forward(qkv, (layer_past_key, layer_past_value)) + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + net.forward(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len) + output, query, key, value = (layer_past_key_padded, layer_past_value_padded), query_padded, layer_past_key_padded, layer_past_value_padded + # check output + if not torch.allclose(ref_output[0].to(dtype=torch.bfloat16), output[0][:,:,:ref_output[0].shape[-2],:ref_output[0].shape[-1]], rtol=0.001, atol=0.01): + print(f"error at past key index {i} ref:\n{ref_output[0]} \ncur:\n {output[0]} ") + assert(False) + if not torch.allclose(ref_output[1], output[1][:,:,:ref_output[1].shape[-2],:ref_output[1].shape[-1]], rtol=0.001, atol=0.01): + print(f"error at past value index {i} ref:\n{ref_output[1]} \ncur:\n {output[1]} ") + assert(False) + # check query + if not torch.allclose(query_ref, query[:,:,:,:query_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at query index {i} ref:\n{query_ref} \ncur:\n {query} ") + assert(False) + # check key + if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_gpt_neox() \ No newline at end of file diff --git a/tests/src/test_rotary_kernel_avx512.cpp b/tests/src/test_rotary_kernel_avx512.cpp new file mode 100644 index 0000000..6be5829 --- /dev/null +++ b/tests/src/test_rotary_kernel_avx512.cpp @@ -0,0 +1,109 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "rotary_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using RotaryTestParamSet = std::tuple< + data_type_t // data type + >; + +class RotaryTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void rotary_emb(size_t rotaryNdims, float* cos, float* sin, T* q_src, T* k_src, T* q_dst, T* k_dst) { + auto halfRotaryNdims = rotaryNdims / 2; + for (size_t i = 0; i < halfRotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] - q_src[i + halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] - k_src[i + halfRotaryNdims] * sin[i]; + } + for (size_t i = halfRotaryNdims; i < rotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] + q_src[i - halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] + k_src[i - halfRotaryNdims] * sin[i]; + } + } + + template + void test(float thresh) { + for (int n = 6; n < 129; n += 2) { + tensor2D q_src(1, n, true); + tensor2D k_src(1, n, true); + tensor2D q_dst(1, n, true); + tensor2D k_dst(1, n, true); + tensor2D q_dst_ref(1, n, true); + tensor2D k_dst_ref(1, n, true); + tensor2D cos(1, n, true); + tensor2D sin(1, n, true); + for (int i = 0; i < n; i++) { + q_src[i] = i % 19 - 10; + k_src[i] = i % 19 - 9; + cos[i] = i % 19 - 8; + sin[i] = i % 19 - 7; + } + rotary_emb(n, cos.data, sin.data, q_src.data, k_src.data, q_dst_ref.data, k_dst_ref.data); + rotary_avx512(n, cos.data, sin.data, q_src.data, k_src.data, q_dst.data, k_dst.data); + for (int i = 0; i < n; i++) { + float q = q_dst[i]; + float q_ref = q_dst_ref[i]; + float k = k_dst[i]; + float k_ref = k_dst_ref[i]; + if (std::abs(q - q_ref) > thresh) { + FAIL() << " q is not equal, N: " << n << " pos: " << i << " opt: " << q << " ref: " << q_ref; + } + if (std::abs(k - k_ref) > thresh) { + FAIL() << " k is not equal, N: " << n << " pos: " << i << " opt: " << k << " ref: " << k_ref; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(RotaryTest, rotary) { + if (_types == dnnl_s8) { + ASSERT_TRUE(false); + } else { + test(0.01f); + } +} + +const std::vector types = { + dnnl_bf16 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Rotary, RotaryTest, + ::testing::Combine(ValuesIn(types)), + RotaryTest::getTestCaseName); diff --git a/tests/src/test_transpose_kernel_avx512.cpp b/tests/src/test_transpose_kernel_avx512.cpp index b2b36e8..46035b5 100644 --- a/tests/src/test_transpose_kernel_avx512.cpp +++ b/tests/src/test_transpose_kernel_avx512.cpp @@ -100,7 +100,7 @@ class TransposeTest : public TestWithParam { for (int i = 0; i < num_heads * head_size * query_seq_len; i++) { float a = dst[i]; float b = dst_ref[i]; - if (a - b > thresh) { + if (std::abs(a - b) > thresh) { FAIL() << " N: " << head_size << " pos: " << i << " opt: " << a << " ref: " << b; } } From 28ba2d547c43569cf9af7f968d032ce1058373f3 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 11/54] add chatglm support --- include/llm_emb_gpt.hpp | 2 + include/llm_mha_gpt.hpp | 6 +- src/emb_gpt_avx512.cpp | 60 +- src/mha_gpt_amx.cpp | 81 +- tests/script/ext/attn_gpt.cpp | 232 +++ tests/script/ext/emb_gpt.cpp | 58 +- tests/script/ext/mha_gpt.cpp | 84 +- tests/script/ext/module.cpp | 1 + tests/script/ext/module.hpp | 3 +- tests/script/ext/setup.py | 5 +- tests/script/models/chatglm-6b/.gitattributes | 34 + tests/script/models/chatglm-6b/LICENSE | 201 +++ tests/script/models/chatglm-6b/MODEL_LICENSE | 65 + tests/script/models/chatglm-6b/README.md | 89 + tests/script/models/chatglm-6b/config.json | 28 + .../chatglm-6b/configuration_chatglm.py | 103 ++ .../models/chatglm-6b/modeling_chatglm.org.py | 1450 ++++++++++++++++ .../models/chatglm-6b/modeling_chatglm.py | 1479 +++++++++++++++++ .../chatglm-6b/pytorch_model.bin.index.json | 375 +++++ .../script/models/chatglm-6b/quantization.py | 201 +++ .../chatglm-6b/test_modeling_chatglm.py | 245 +++ .../models/chatglm-6b/tokenization_chatglm.py | 443 +++++ .../models/chatglm-6b/tokenizer_config.json | 20 + tests/script/pytest.ini | 2 + tests/script/test_attn_chatglm.py | 436 +++++ tests/script/test_mha_chatglm.py | 259 +++ tests/script/test_mha_gpt.py | 14 +- tests/script/test_rotary_pastkv_chatglm.py | 352 ++++ 28 files changed, 6246 insertions(+), 82 deletions(-) create mode 100644 tests/script/ext/attn_gpt.cpp create mode 100644 tests/script/models/chatglm-6b/.gitattributes create mode 100644 tests/script/models/chatglm-6b/LICENSE create mode 100644 tests/script/models/chatglm-6b/MODEL_LICENSE create mode 100644 tests/script/models/chatglm-6b/README.md create mode 100644 tests/script/models/chatglm-6b/config.json create mode 100644 tests/script/models/chatglm-6b/configuration_chatglm.py create mode 100644 tests/script/models/chatglm-6b/modeling_chatglm.org.py create mode 100644 tests/script/models/chatglm-6b/modeling_chatglm.py create mode 100644 tests/script/models/chatglm-6b/pytorch_model.bin.index.json create mode 100644 tests/script/models/chatglm-6b/quantization.py create mode 100644 tests/script/models/chatglm-6b/test_modeling_chatglm.py create mode 100644 tests/script/models/chatglm-6b/tokenization_chatglm.py create mode 100644 tests/script/models/chatglm-6b/tokenizer_config.json create mode 100644 tests/script/pytest.ini create mode 100644 tests/script/test_attn_chatglm.py create mode 100644 tests/script/test_mha_chatglm.py create mode 100644 tests/script/test_rotary_pastkv_chatglm.py diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp index a23f3d5..2d3c09c 100644 --- a/include/llm_emb_gpt.hpp +++ b/include/llm_emb_gpt.hpp @@ -23,6 +23,7 @@ class emb_gpt { data_type_t dst_precision; size_t rotary_emb_base; float rotary_pct; + bool use_position2d; }; struct exec_param { size_t batch; @@ -32,6 +33,7 @@ class emb_gpt { uint8_t* query_dst; uint8_t** layer_past_key_padded; uint8_t** layer_past_value_padded; + int* position2d_ids; // shape: [batch, 2, query_seq_len] }; emb_gpt(); diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp index f30dc08..befa348 100644 --- a/include/llm_mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -55,13 +55,15 @@ class mha_gpt { size_t batch; size_t query_seq_len; size_t key_seq_len; + bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. uint8_t* q; // q buffer, compact, shape: [batch, num_heads, query_seq_len, head_size] uint8_t** k; // k buffer, k[N] stands different batch which may be discreted // k[0] shape: [batch, num_heads, key_seq_len, head_size] uint8_t** v; // v buffer, v[N] stands different batch which may be discreted // v[0] shape: [batch, num_heads, value_seq_len, head_size] - float** attention_mask; // attention mask, attention_mask[N] is the batch - // attention_mask[0] shape: [1, max_seq_len] + float* attention_mask; // attention mask, attention_mask[0] shape: + // [batch, 1, 1, key_seq_len], when is_causal_in_attention is false + // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true uint8_t* attn_output; // output, compact, shape: [batch, query_seq_len, num_heads * head_size] size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer // expected quant schema: diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index b38c484..5d48e16 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -25,6 +25,8 @@ struct emb_gpt_impl_avx512 : public emb_gpt::impl { void initRotery(size_t max_seq_len); void applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, size_t batch, size_t q_seq_len, size_t past_seq_len); + void applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids); emb_gpt::create_param _create_param; size_t _head_num = 32; @@ -40,6 +42,7 @@ struct emb_gpt_impl_avx512 : public emb_gpt::impl { std::shared_ptr _sin_cached; int64_t _input_type_size = 1; int64_t _output_type_size = 1; + bool _use_position2d = false; }; bool emb_gpt_impl_avx512::create(const emb_gpt::create_param& param) { @@ -66,7 +69,12 @@ bool emb_gpt_impl_avx512::create(const emb_gpt::create_param& param) { if (param.dst_precision == dnnl_s8) _output_type_size = sizeof(int8_t); - _rotary_ndims = static_cast(_size_per_head * _rotary_pct); + _use_position2d = param.use_position2d; + if (_use_position2d) { + _rotary_ndims = static_cast(_size_per_head / 2); + } else { + _rotary_ndims = static_cast(_size_per_head * _rotary_pct); + } initRotery(_max_seq_len); return true; @@ -142,6 +150,50 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src }); } +// q_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] +// q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] +// kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] +// kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] +// position2d_ids: [batch, 2, q_seq_len] +void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids) { + auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; + auto* cos_cached = _cos_cached.get(); + auto* sin_cached = _sin_cached.get(); + parallel_for3d(batch, _head_num, q_seq_len, [&](size_t b, size_t h, size_t s) { + // q, k rotary encoding + auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; + auto k_dst_batch = k_dst[b] + key_offset; + auto v_dst_batch = v_dst[b] + key_offset; + auto pos_batch = position2d_ids + b * 2 * q_seq_len; + auto block_batch = pos_batch + q_seq_len; + auto q_src_batch = q_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; + auto k_src_batch = k_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; + auto v_src_batch = v_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; + auto q_dst_seq = q_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto q_src_seq = q_src_batch + s * _hidden_size * 3 * _input_type_size; + auto k_src_seq = k_src_batch + s * _hidden_size * 3 * _input_type_size; + auto v_src_seq = v_src_batch + s * _hidden_size * 3 * _input_type_size; + auto* q_src_f = reinterpret_cast(q_src_seq + h * _size_per_head * 3 * _input_type_size); + auto* k_src_f = reinterpret_cast(k_src_seq + h * _size_per_head * 3 * _input_type_size); + auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); + auto* k_dst_f = reinterpret_cast(k_dst_seq + h * _max_seq_len * _size_per_head_aligned * _output_type_size); + rotary_avx512(_rotary_ndims, cos_cached + pos_batch[s] * _rotary_ndims, sin_cached + pos_batch[s] * _rotary_ndims, q_src_f, k_src_f, q_dst_f, k_dst_f); + rotary_avx512(_rotary_ndims, cos_cached + block_batch[s] * _rotary_ndims, sin_cached + block_batch[s] * _rotary_ndims, + q_src_f + _rotary_ndims, + k_src_f + _rotary_ndims, + q_dst_f + _rotary_ndims, + k_dst_f + _rotary_ndims); + + // v concat + memcpy(static_cast(v_dst_seq) + h * _max_seq_len * _size_per_head_aligned * _output_type_size, + static_cast(v_src_seq) + h * _size_per_head * 3 * _input_type_size, + _size_per_head * _output_type_size); + }); +} + void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { // [batch, seq_len, (num_heads * 3 * head_size)] // --> [batch, seq_len, num_heads, 3 * head_size] @@ -172,7 +224,11 @@ void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { // value(pastKeys): value = torch.cat((past_value, value), dim=-2) // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] - applyRotaryPosEmbMemcpy(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len); + if (_use_position2d) { + applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, param.position2d_ids); + } else { + applyRotaryPosEmbMemcpy(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len); + } } } diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 46601e4..6047290 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -131,7 +131,7 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); - auto pAddIn1_aux = attn_masks[i0]; + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); @@ -181,8 +181,6 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); - auto pAddIn1_aux = attn_masks[i0]; - auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); @@ -194,20 +192,33 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { prev_k = pKIn0_aux; auto pMatMul0Out = bufferMatMul0Out_local; - // loop along K dimension - size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; - for (int m = 0; m < seq_cout; m++) { - float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); - ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); - softmax_avx512(dst, src, valid_softmax_items, nullptr); - // attn_scores = torch.where(causal_mask, attn_scores, mask_value) - if (param.key_seq_len > valid_softmax_items) { - auto *invalidPtr = dst + valid_softmax_items; - memset(static_cast(invalidPtr), 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); - valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + if (param.is_causal_in_attention) { + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len * param.query_seq_len; + // loop along K dimension + for (int m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + softmax_avx512(dst, src, param.key_seq_len, nullptr); + } + } else { + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; + // loop along K dimension + size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; + for (int m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); + softmax_avx512(dst, src, valid_softmax_items, nullptr); + // attn_scores = torch.where(causal_mask, attn_scores, mask_value) + if (param.key_seq_len > valid_softmax_items) { + auto *invalidPtr = dst + valid_softmax_items; + memset(static_cast(invalidPtr), 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); + valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + } } } + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; tensor2D matQKBF16(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); @@ -258,7 +269,7 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); - auto pAddIn1_aux = attn_masks[i0]; + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); @@ -309,8 +320,6 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv; auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv; - auto pAddIn1_aux = attn_masks[i0]; - auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); @@ -322,18 +331,30 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { prev_k = pKIn0_aux; auto pMatMul0Out = bufferMatMul0Out_local; - // loop along K dimension - size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; - for (int m = 0; m < seq_cout; m++) { - float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); - uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); - mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); - softmax_avx512(dst, src, valid_softmax_items, param.qk_quant); - // attn_scores = torch.where(causal_mask, attn_scores, mask_value) - if (param.key_seq_len > valid_softmax_items) { - auto *invalidPtr = dst + valid_softmax_items; - memset(invalidPtr, 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); - valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + if (param.is_causal_in_attention) { + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len * param.query_seq_len; + // loop along K dimension + for (int m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); + mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + softmax_avx512(dst, src, param.key_seq_len, param.qk_quant); + } + } else { + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; + // loop along K dimension + size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; + for (int m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); + mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); + softmax_avx512(dst, src, valid_softmax_items, param.qk_quant); + // attn_scores = torch.where(causal_mask, attn_scores, mask_value) + if (param.key_seq_len > valid_softmax_items) { + auto *invalidPtr = dst + valid_softmax_items; + memset(invalidPtr, 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); + valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + } } } auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn diff --git a/tests/script/ext/attn_gpt.cpp b/tests/script/ext/attn_gpt.cpp new file mode 100644 index 0000000..4b4ea8c --- /dev/null +++ b/tests/script/ext/attn_gpt.cpp @@ -0,0 +1,232 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "alloca.h" +#include "module.hpp" +#include "common/utility.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_emb_gpt.hpp" +#include "llm_mha_gpt.hpp" +#include "test_common.hpp" + +using namespace torch::indexing; + +class attn_gpt { +public: + struct create_param { + size_t num_heads; + size_t head_size; + size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv + size_t max_seq_len; // max seq length for computing the size of matmul tmp result + // supported (qkv, dst): (bf16, bf16) + llmdnn::data_type_t qkv_precision; + llmdnn::data_type_t dst_precision; + size_t rotary_emb_base; + float normal_factor; + float rotary_pct; + bool use_position2d; + }; + struct exec_param { + size_t batch; + size_t query_seq_len; + size_t past_seq_len; + bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. + uint8_t* qkv; + uint8_t** layer_past_key_padded; + uint8_t** layer_past_value_padded; + int* position2d_ids; // shape: [batch, 2, query_seq_len] + float* attention_mask; // attention mask, attention_mask[0] shape: + // [batch, 1, 1, key_seq_len], when is_causal_in_attention is false + // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true + uint8_t* attn_output; + size_t head_stride_in_kv; + }; + + attn_gpt(); + bool create(const create_param& param); + void exec(const exec_param& param); + +private: + create_param _create_param; + std::shared_ptr _emb_gpt; + std::shared_ptr _mha_gpt; + std::shared_ptr _query_dst; + size_t _query_cached_batch = 0; +}; + +attn_gpt::attn_gpt(): _emb_gpt(std::make_shared()), + _mha_gpt(std::make_shared()) { + +} + +bool attn_gpt::create(const attn_gpt::create_param& param) { + _create_param = param; + llmdnn::emb_gpt::create_param emb_param; + emb_param.num_heads = param.num_heads; + emb_param.head_size = param.head_size; + emb_param.head_size_aligned = param.head_size_aligned; + emb_param.qkv_precision = param.qkv_precision; + emb_param.dst_precision = param.dst_precision; + emb_param.max_seq_len = param.max_seq_len; + emb_param.rotary_emb_base = param.rotary_emb_base; + emb_param.rotary_pct = param.rotary_pct; + emb_param.use_position2d = param.use_position2d; + + if (!_emb_gpt->create(emb_param)) + return false; + + llmdnn::mha_gpt::create_param mha_param; + mha_param.num_heads = param.num_heads; + mha_param.head_size = param.head_size; + mha_param.head_size_aligned = param.head_size_aligned; + mha_param.normal_factor = param.normal_factor; + mha_param.qkv_precision = param.qkv_precision; + mha_param.dst_precision = param.dst_precision; + mha_param.max_seq_len = param.max_seq_len; + + return _mha_gpt->create(mha_param); +} + +void attn_gpt::exec(const attn_gpt::exec_param& param) { + if (_query_cached_batch < param.batch) { + auto capacity = param.batch * _create_param.max_seq_len * (_create_param.num_heads * _create_param.head_size_aligned) * + llmdnn::get_precision_size(_create_param.qkv_precision); + _query_dst = std::shared_ptr(reinterpret_cast(aligned_alloc(64, capacity)), + [](void * p) { ::free(p); }); + memset(_query_dst.get(), 0, capacity); + _query_cached_batch = param.batch; + } + + llmdnn::emb_gpt::exec_param emb_param; + emb_param.batch = param.batch; + emb_param.query_seq_len = param.query_seq_len; + emb_param.past_seq_len = param.past_seq_len; + emb_param.qkv = param.qkv; + emb_param.query_dst = _query_dst.get(); + emb_param.layer_past_key_padded = param.layer_past_key_padded; + emb_param.layer_past_value_padded = param.layer_past_value_padded; + emb_param.position2d_ids = param.position2d_ids; + _emb_gpt->exec(emb_param); + + llmdnn::mha_gpt::exec_param mha_param; + mha_param.batch = param.batch; + mha_param.query_seq_len = param.query_seq_len; + mha_param.key_seq_len = param.query_seq_len + param.past_seq_len; + mha_param.q = emb_param.query_dst; + mha_param.attn_output = param.attn_output; + mha_param.head_stride_in_kv = param.head_stride_in_kv; + mha_param.is_causal_in_attention = param.is_causal_in_attention; + mha_param.attention_mask = param.attention_mask; + mha_param.k = emb_param.layer_past_key_padded; + mha_param.v = emb_param.layer_past_value_padded; + _mha_gpt->exec(mha_param); +} + +void regclass_attn_gpt(pybind11::module m) { + py::class_> cls(m, "attn_gpt"); + cls.def(py::init<>()); + cls.def("create", [] (attn_gpt& self, + const size_t num_heads, + const size_t head_size, + const size_t head_size_aligned, + float normal_factor, + const std::string qkv_precision_name, + const std::string dst_precision_name, + const size_t max_seq_len, + const size_t rotary_emb_base, + float rotary_pct, + bool use_position2d) { + attn_gpt::create_param param; + param.num_heads = num_heads; + param.head_size = head_size; + param.head_size_aligned = head_size_aligned; + param.normal_factor = normal_factor; + param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); + param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); + param.max_seq_len = max_seq_len; + param.rotary_emb_base = rotary_emb_base; + param.rotary_pct = rotary_pct; + param.use_position2d = use_position2d; + if (param.qkv_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); + if (param.dst_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect dst type " + dst_precision_name); + if (!self.create(param)) + throw pybind11::type_error("Incorrect param"); + }, + py::arg("num_heads"), + py::arg("head_size"), + py::arg("head_size_aligned"), + py::arg("normal_factor"), + py::arg("qkv_precision_name"), + py::arg("dst_precision_name"), + py::arg("max_seq_len"), + py::arg("rotary_emb_base"), + py::arg("rotary_pct"), + py::arg("use_position2d") = false, + R"( + Create emb + + :param num_heads: heads number. + :type num_heads: int + )"); + cls.def("exec_position", [] (attn_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_padded, + const torch::Tensor& layer_past_value_padded, int64_t past_seq_len, const torch::Tensor& attn_mask, const torch::Tensor& position2d_ids) { + // qkv: [batch, seq_len, (num_heads * 3 * head_size)] + // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + // past_seq_len: past_seq_len==layer_past.shape[-2] + // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] + // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] + AT_ASSERT(qkv.dim() == 3 && layer_past_key_padded.dim() == 4 && layer_past_value_padded.dim() == 4 && attn_mask.dim() == 4 && + qkv.size(0) == layer_past_key_padded.size(0) && + layer_past_key_padded.dim() == layer_past_value_padded.dim()); + auto batch = qkv.size(0); + auto num_heads = layer_past_key_padded.size(1); + auto query_seq_len = qkv.size(1); + auto head_size = qkv.size(2) / 3 / num_heads; + auto head_size_aligned = layer_past_key_padded.size(3); + auto max_seq_len = layer_past_key_padded.size(2); + AT_ASSERT(past_seq_len <= layer_past_key_padded.size(2) && head_size <= layer_past_key_padded.size(3) && + query_seq_len <= layer_past_key_padded.size(2)); + + attn_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.past_seq_len = past_seq_len; + param.qkv = reinterpret_cast(qkv.data_ptr()); + param.layer_past_key_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + for (int i = 0; i < batch; i++) { + param.layer_past_key_padded[i] = reinterpret_cast(layer_past_key_padded[i].data_ptr()); + param.layer_past_value_padded[i] = reinterpret_cast(layer_past_value_padded[i].data_ptr()); + } + param.position2d_ids = reinterpret_cast(position2d_ids.data_ptr()); + + param.is_causal_in_attention = attn_mask.size(2) != 1; + param.attention_mask = attn_mask.data_ptr(); + param.head_stride_in_kv = max_seq_len * head_size_aligned; + auto out = qkv.new_empty({batch, query_seq_len, num_heads * head_size}); + param.attn_output = reinterpret_cast(out.data_ptr()); + + self.exec(param); + + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + return out; + }, + py::arg("qkv"), + py::arg("layer_past_key_padded"), + py::arg("layer_past_value_padded"), + py::arg("past_seq_len"), + py::arg("attn_mask"), + py::arg("position2d_ids"), + R"( + exec emb + + :param num_heads: heads number. + :type num_heads: int + )"); +} \ No newline at end of file diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp index 7818751..66ec8c1 100644 --- a/tests/script/ext/emb_gpt.cpp +++ b/tests/script/ext/emb_gpt.cpp @@ -24,7 +24,8 @@ void regclass_emb_gpt(pybind11::module m) { const std::string dst_precision_name, const size_t max_seq_len, const size_t rotary_emb_base, - float rotary_pct) { + float rotary_pct, + bool use_position2d) { llmdnn::emb_gpt::create_param param; param.num_heads = num_heads; param.head_size = head_size; @@ -34,6 +35,7 @@ void regclass_emb_gpt(pybind11::module m) { param.max_seq_len = max_seq_len; param.rotary_emb_base = rotary_emb_base; param.rotary_pct = rotary_pct; + param.use_position2d = use_position2d; if (param.qkv_precision == llmdnn::dnnl_data_type_undef) throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); if (param.dst_precision == llmdnn::dnnl_data_type_undef) @@ -49,6 +51,7 @@ void regclass_emb_gpt(pybind11::module m) { py::arg("max_seq_len"), py::arg("rotary_emb_base"), py::arg("rotary_pct"), + py::arg("use_position2d") = false, R"( Create emb @@ -57,7 +60,7 @@ void regclass_emb_gpt(pybind11::module m) { )"); // torch::List cls.def("exec", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_padded, - const torch::Tensor& layer_past_value_padded, const torch::Tensor& query_padded, size_t past_seq_len) { + const torch::Tensor& layer_past_value_padded, const torch::Tensor& query_padded, int64_t past_seq_len) { // qkv: [batch, seq_len, (num_heads * 3 * head_size)] // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] // past_seq_len: past_seq_len==layer_past.shape[-2] @@ -102,6 +105,57 @@ void regclass_emb_gpt(pybind11::module m) { R"( exec emb + :param num_heads: heads number. + :type num_heads: int + )"); + cls.def("exec_position", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_padded, + const torch::Tensor& layer_past_value_padded, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor position2d_ids) { + // qkv: [batch, seq_len, (num_heads * 3 * head_size)] + // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + // past_seq_len: past_seq_len==layer_past.shape[-2] + // query_padded: [batch, num_heads, query_seq_len, head_size_aligned] + // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] + AT_ASSERT(qkv.dim() == 3 && layer_past_key_padded.dim() == 4 && layer_past_value_padded.dim() == 4 && + qkv.size(0) == layer_past_key_padded.size(0) && + layer_past_key_padded.dim() == layer_past_value_padded.dim()); + AT_ASSERT(query_padded.dim() == 4 && query_padded.size(0) == qkv.size(0) && + query_padded.size(1) == layer_past_key_padded.size(1) && query_padded.size(2) == qkv.size(1) && + query_padded.size(3) == layer_past_key_padded.size(3)); + auto batch = qkv.size(0); + auto num_heads = layer_past_key_padded.size(1); + auto query_seq_len = qkv.size(1); + auto head_size = qkv.size(2) / 3 / num_heads; + AT_ASSERT(past_seq_len <= layer_past_key_padded.size(2) && head_size <= layer_past_key_padded.size(3) && + query_seq_len <= layer_past_key_padded.size(2)); + + llmdnn::emb_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.past_seq_len = past_seq_len; + param.qkv = reinterpret_cast(qkv.data_ptr()); + param.query_dst = reinterpret_cast(query_padded.data_ptr()); + param.layer_past_key_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + for (int i = 0; i < batch; i++) { + param.layer_past_key_padded[i] = reinterpret_cast(layer_past_key_padded[i].data_ptr()); + param.layer_past_value_padded[i] = reinterpret_cast(layer_past_value_padded[i].data_ptr()); + } + param.position2d_ids = reinterpret_cast(position2d_ids.data_ptr()); + + self.exec(param); + + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + }, + py::arg("qkv"), + py::arg("layer_past_key_padded"), + py::arg("layer_past_value_padded"), + py::arg("query_padded"), + py::arg("past_seq_len"), + py::arg("position2d_ids"), + R"( + exec emb + :param num_heads: heads number. :type num_heads: int )"); diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index c5b397a..704a1fb 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -50,39 +50,41 @@ void regclass_mha_gpt(pybind11::module m) { :param num_heads: heads number. :type num_heads: int )"); - cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask) { - // q: [batch, num_heads, query_seq_len, head_size] - // k: [batch, num_heads, key_seq_len, head_size] - // v: [batch, num_heads, value_seq_len, head_size] - // attn_mask: [batch, MAX_POSITION_EMBEDDINGS] + cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask, int64_t head_size, int64_t key_seq_len) { + // q: [batch, num_heads, query_seq_len, head_size_aligned] + // k: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // v: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] // out: [batch, query_seq_len, num_heads * head_size] - AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 2); + AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 4); auto batch = q.size(0); auto num_heads = q.size(1); auto query_seq_len = q.size(2); - auto head_size = q.size(3); - auto key_seq_len = k.size(2); - auto attn_len = attn_mask.size(1); - AT_ASSERT(key_seq_len == v.size(2) && key_seq_len == attn_len && - batch == k.size(0) && batch == v.size(0) && 1 == attn_mask.size(0) && + auto head_size_aligned = q.size(3); + auto max_seq_len = k.size(2); + auto attn_len = attn_mask.size(3); + AT_ASSERT(max_seq_len == v.size(2) && + batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && num_heads == k.size(1) && num_heads == v.size(1) && - head_size == k.size(3) && head_size == v.size(3)); + head_size_aligned == k.size(3) && head_size_aligned == v.size(3)); - auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}); llmdnn::mha_gpt::exec_param param; param.batch = batch; param.query_seq_len = query_seq_len; - param.key_seq_len = key_seq_len; + param.key_seq_len = key_seq_len == 0 ? max_seq_len : key_seq_len; + head_size = head_size == 0 ? head_size_aligned : head_size; + auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}); + AT_ASSERT((int64_t)param.key_seq_len == attn_len); param.q = reinterpret_cast(q.data_ptr()); param.attn_output = reinterpret_cast(out.data_ptr()); - param.head_stride_in_kv = key_seq_len * head_size; + param.head_stride_in_kv = max_seq_len * head_size_aligned; + param.is_causal_in_attention = attn_mask.size(2) != 1; + param.attention_mask = attn_mask.data_ptr(); param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.attention_mask = reinterpret_cast(alloca(batch * sizeof(float*))); for (int i = 0; i < batch; i++) { - param.k[i] = reinterpret_cast(k.data_ptr()) + i * num_heads * key_seq_len * head_size * sizeof(ov::bfloat16); - param.v[i] = reinterpret_cast(v.data_ptr()) + i * num_heads * key_seq_len * head_size * sizeof(ov::bfloat16); - param.attention_mask[i] = attn_mask.data_ptr(); + param.k[i] = reinterpret_cast(k[i].data_ptr()); + param.v[i] = reinterpret_cast(v[i].data_ptr()); } self.exec(param); @@ -92,6 +94,8 @@ void regclass_mha_gpt(pybind11::module m) { py::arg("k"), py::arg("v"), py::arg("attn_mask"), + py::arg("head_size") = 0, + py::arg("key_seq_len") = 0, R"( exec mha @@ -99,44 +103,46 @@ void regclass_mha_gpt(pybind11::module m) { :type num_heads: int )"); cls.def("exec_quant", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask, - float q_dequant, float k_dequant, float v_dequant, float qk_quant, const std::vector& qkv_quant) { - // q: [batch, num_heads, query_seq_len, head_size] - // k: [batch, num_heads, key_seq_len, head_size] - // v: [batch, num_heads, value_seq_len, head_size] - // attn_mask: [batch, MAX_POSITION_EMBEDDINGS] + float q_dequant, float k_dequant, float v_dequant, float qk_quant, const std::vector& qkv_quant, int64_t head_size, int64_t key_seq_len) { + // q: [batch, num_heads, query_seq_len, head_size_aligned] + // k: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // v: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] // out: [batch, query_seq_len, num_heads * head_size] - AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 2); + AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 4); auto batch = q.size(0); auto num_heads = q.size(1); auto query_seq_len = q.size(2); - auto head_size = q.size(3); - auto key_seq_len = k.size(2); - auto attn_len = attn_mask.size(1); - AT_ASSERT(key_seq_len == v.size(2) && key_seq_len == attn_len && - batch == k.size(0) && batch == v.size(0) && 1 == attn_mask.size(0) && + auto head_size_aligned = q.size(3); + auto max_seq_len = k.size(2); + auto attn_len = attn_mask.size(3); + AT_ASSERT(max_seq_len == v.size(2) && + batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && num_heads == k.size(1) && num_heads == v.size(1) && - head_size == k.size(3) && head_size == v.size(3)); + head_size_aligned == k.size(3) && head_size_aligned == v.size(3)); - auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}, torch::TensorOptions(torch::kInt8)); llmdnn::mha_gpt::exec_param param; param.batch = batch; param.query_seq_len = query_seq_len; - param.key_seq_len = key_seq_len; + param.key_seq_len = key_seq_len == 0 ? max_seq_len : key_seq_len; + head_size = head_size == 0 ? head_size_aligned : head_size; + auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}, torch::TensorOptions(torch::kInt8)); + AT_ASSERT((int64_t)param.key_seq_len == attn_len); param.q = reinterpret_cast(q.data_ptr()); param.attn_output = reinterpret_cast(out.data_ptr()); - param.head_stride_in_kv = key_seq_len * head_size; + param.head_stride_in_kv = max_seq_len * head_size_aligned; param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.attention_mask = reinterpret_cast(alloca(batch * sizeof(float*))); + param.is_causal_in_attention = attn_mask.size(2) != 1; + param.attention_mask = attn_mask.data_ptr(); param.q_dequant = q_dequant; param.k_dequant = k_dequant; param.v_dequant = v_dequant; param.qk_quant = qk_quant; param.qkv_quant = qkv_quant; for (int i = 0; i < batch; i++) { - param.k[i] = reinterpret_cast(k.data_ptr()) + i * num_heads * key_seq_len * head_size; - param.v[i] = reinterpret_cast(v.data_ptr()) + i * num_heads * key_seq_len * head_size; - param.attention_mask[i] = attn_mask.data_ptr(); + param.k[i] = reinterpret_cast(k[i].data_ptr()); + param.v[i] = reinterpret_cast(v[i].data_ptr()); } self.exec(param); @@ -151,6 +157,8 @@ void regclass_mha_gpt(pybind11::module m) { py::arg("v_dequant"), py::arg("qk_quant"), py::arg("qkv_quant"), + py::arg("head_size") = 0, + py::arg("key_seq_len") = 0, R"( exec mha quant diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp index d016470..58c0ba1 100644 --- a/tests/script/ext/module.cpp +++ b/tests/script/ext/module.cpp @@ -16,4 +16,5 @@ PYBIND11_MODULE(llmdnn, m) { } regclass_mha_gpt(m); regclass_emb_gpt(m); + regclass_attn_gpt(m); } \ No newline at end of file diff --git a/tests/script/ext/module.hpp b/tests/script/ext/module.hpp index d9f8850..12efeed 100644 --- a/tests/script/ext/module.hpp +++ b/tests/script/ext/module.hpp @@ -7,4 +7,5 @@ #include void regclass_mha_gpt(pybind11::module m); -void regclass_emb_gpt(pybind11::module m); \ No newline at end of file +void regclass_emb_gpt(pybind11::module m); +void regclass_attn_gpt(pybind11::module m); \ No newline at end of file diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py index ebca40f..dadcf5d 100644 --- a/tests/script/ext/setup.py +++ b/tests/script/ext/setup.py @@ -31,7 +31,10 @@ cpp_extension.CppExtension( 'llmdnn', ['module.cpp', f'../../src/test_common.cpp', - 'mha_gpt.cpp', 'emb_gpt.cpp'], + 'mha_gpt.cpp', + 'emb_gpt.cpp', + 'attn_gpt.cpp', + ], extra_compile_args=extra_args, include_dirs=[f'{os.getcwd()}/../../src', f'{os.getcwd()}/../../../include', diff --git a/tests/script/models/chatglm-6b/.gitattributes b/tests/script/models/chatglm-6b/.gitattributes new file mode 100644 index 0000000..c7d9f33 --- /dev/null +++ b/tests/script/models/chatglm-6b/.gitattributes @@ -0,0 +1,34 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/tests/script/models/chatglm-6b/LICENSE b/tests/script/models/chatglm-6b/LICENSE new file mode 100644 index 0000000..ac4aee5 --- /dev/null +++ b/tests/script/models/chatglm-6b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright Zhengxiao Du + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/MODEL_LICENSE b/tests/script/models/chatglm-6b/MODEL_LICENSE new file mode 100644 index 0000000..f8e2731 --- /dev/null +++ b/tests/script/models/chatglm-6b/MODEL_LICENSE @@ -0,0 +1,65 @@ +The ChatGLM-6B License + +一、定义 + +“许可方”是指分发其软件的 ChatGLM-6B 模型团队。 + +“软件”是指根据本许可提供的 ChatGLM-6B 模型参数。 + +2. 许可授予 + +根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可,仅用于您的非商业研究目的。 + +上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 + +3.限制 + +您不得出于任何商业、军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 + +您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 + +4.免责声明 + +本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 + +5. 责任限制 + +除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 + +6.争议解决 + +本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 + +请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 glm-130b@googlegroups.com 与我们联系。 + +1. Definitions + +“Licensor” means the ChatGLM-6B Model Team that distributes its Software. + +“Software” means the ChatGLM-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. diff --git a/tests/script/models/chatglm-6b/README.md b/tests/script/models/chatglm-6b/README.md new file mode 100644 index 0000000..857edd7 --- /dev/null +++ b/tests/script/models/chatglm-6b/README.md @@ -0,0 +1,89 @@ +--- +language: +- zh +- en +tags: +- glm +- chatglm +- thudm +--- +# ChatGLM-6B +

+ 🌐 Blog • 💻 Github Repo • 🐦 Twitter • 📃 [GLM@ACL 22] [GitHub] • 📃 [GLM-130B@ICLR 23] [GitHub]
+

+ +

+ 👋 Join our Slack and WeChat +

+ +## 介绍 +ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。 + +ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference. + +## 软件依赖 + +```shell +pip install protobuf==3.20.0 transformers==4.27.1 icetk cpm_kernels +``` + +## 代码调用 + +可以通过如下代码调用 ChatGLM-6B 模型来生成对话: + +```ipython +>>> from transformers import AutoTokenizer, AutoModel +>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +>>> response, history = model.chat(tokenizer, "你好", history=[]) +>>> print(response) +你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。 +>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history) +>>> print(response) +晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法: + +1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。 +2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。 +3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。 +4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。 +5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。 +6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。 + +如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。 +``` + +关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。 + +For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B). + +## Change Log +* v1.1.0 ([942945d](https://huggingface.co/THUDM/chatglm-6b/commit/942945df047dee66f653c68ae0e56655045f1741)): 更新 v1.1 版本 checkpoint +* v0.1.0 ([f831824](https://huggingface.co/THUDM/chatglm-6b/commit/f83182484538e663a03d3f73647f10f89878f438)) + +## 协议 + +本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。 + +## 引用 + +如果你觉得我们的工作有帮助的话,请考虑引用下列论文: + +``` +@inproceedings{ + zeng2023glm-130b, + title={{GLM}-130B: An Open Bilingual Pre-trained Model}, + author={Aohan Zeng and Xiao Liu and Zhengxiao Du and Zihan Wang and Hanyu Lai and Ming Ding and Zhuoyi Yang and Yifan Xu and Wendi Zheng and Xiao Xia and Weng Lam Tam and Zixuan Ma and Yufei Xue and Jidong Zhai and Wenguang Chen and Zhiyuan Liu and Peng Zhang and Yuxiao Dong and Jie Tang}, + booktitle={The Eleventh International Conference on Learning Representations (ICLR)}, + year={2023}, + url={https://openreview.net/forum?id=-Aw0rrrPUF} +} +``` +``` +@inproceedings{du2022glm, + title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling}, + author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, + pages={320--335}, + year={2022} +} +``` \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/config.json b/tests/script/models/chatglm-6b/config.json new file mode 100644 index 0000000..7cc6e70 --- /dev/null +++ b/tests/script/models/chatglm-6b/config.json @@ -0,0 +1,28 @@ +{ + "_name_or_path": "THUDM/chatglm-6b", + "architectures": [ + "ChatGLMModel" + ], + "auto_map": { + "AutoConfig": "configuration_chatglm.ChatGLMConfig", + "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration" + }, + "bos_token_id": 130004, + "eos_token_id": 130005, + "mask_token_id": 130000, + "gmask_token_id": 130001, + "pad_token_id": 3, + "hidden_size": 4096, + "inner_hidden_size": 16384, + "layernorm_epsilon": 1e-05, + "max_sequence_length": 2048, + "model_type": "chatglm", + "num_attention_heads": 32, + "num_layers": 28, + "position_encoding_2d": true, + "torch_dtype": "float32", + "transformers_version": "4.23.1", + "use_cache": true, + "vocab_size": 130528 +} diff --git a/tests/script/models/chatglm-6b/configuration_chatglm.py b/tests/script/models/chatglm-6b/configuration_chatglm.py new file mode 100644 index 0000000..78f3425 --- /dev/null +++ b/tests/script/models/chatglm-6b/configuration_chatglm.py @@ -0,0 +1,103 @@ +""" ChatGLM model configuration """ + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class ChatGLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~ChatGLMModel`]. + It is used to instantiate an ChatGLM model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 150528): + Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ChatGLMModel`] or + [`~TFChatGLMModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + inner_hidden_size (`int`, *optional*, defaults to 16384): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + layernorm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values attentions (not used by all models). + Example: + + ```python + >>> from configuration_chatglm import ChatGLMConfig + >>> from modeling_chatglm import ChatGLMModel + + >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration + >>> configuration = ChatGLMConfig() + + >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration + >>> model = ChatGLMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` +""" + model_type = "chatglm" + + def __init__( + self, + vocab_size=150528, + hidden_size=4096, + num_layers=28, + num_attention_heads=32, + layernorm_epsilon=1e-5, + use_cache=False, + bos_token_id=150004, + eos_token_id=150005, + mask_token_id=150000, + gmask_token_id=150001, + pad_token_id=0, + max_sequence_length=2048, + inner_hidden_size=16384, + position_encoding_2d=True, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.layernorm_epsilon = layernorm_epsilon + self.inner_hidden_size = inner_hidden_size + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.mask_token_id = mask_token_id + self.gmask_token_id = gmask_token_id + self.position_encoding_2d = position_encoding_2d + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs + ) diff --git a/tests/script/models/chatglm-6b/modeling_chatglm.org.py b/tests/script/models/chatglm-6b/modeling_chatglm.org.py new file mode 100644 index 0000000..d24bf52 --- /dev/null +++ b/tests/script/models/chatglm-6b/modeling_chatglm.org.py @@ -0,0 +1,1450 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import os +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm-6b", + # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm +] + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) + + +def gelu(x): + return gelu_impl(x) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + + +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, +): + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + matmul_result = torch.zeros( + 1, 1, 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + # [b, np, sk, hn] -> [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, present, attention_probs) + + return outputs + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class SelfAttention(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + if empty_init: + init_method = skip_init + else: + init_method = default_init + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + # Strided linear layer. + self.query_key_value = init_method( + torch.nn.Linear, + hidden_size, + 3 * self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.dense = init_method( + torch.nn.Linear, + self.inner_hidden_size, + hidden_size, + bias=bias, + dtype=params_dtype, + ) + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + # [batch, seq_len, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache + ) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs,) + + return outputs # output, present, attention_probs + + +class GEGLU(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation_fn = F.gelu + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class GLU(torch.nn.Module): + def __init__(self, hidden_size, inner_hidden_size=None, + layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): + super(GLU, self).__init__() + if empty_init: + init_method = skip_init + else: + init_method = default_init + self.layer_id = layer_id + self.activation_func = activation_func + + # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.dense_h_to_4h = init_method( + torch.nn.Linear, + self.hidden_size, + self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + # Project back to h. + self.dense_4h_to_h = init_method( + torch.nn.Linear, + self.inner_hidden_size, + self.hidden_size, + bias=bias, + dtype=params_dtype, + ) + + def forward(self, hidden_states): + """ + hidden_states: [seq_len, batch, hidden_size] + """ + + # [seq_len, batch, inner_hidden_size] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + + intermediate_parallel = self.activation_func(intermediate_parallel) + + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class GLMBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True, + empty_init=True + ): + super(GLMBlock, self).__init__() + # Set output layer initialization if not provided. + + self.layer_id = layer_id + + # Layernorm on the input data. + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.position_encoding_2d = position_encoding_2d + + # Self attention. + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + bias=use_bias, + params_dtype=params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init + ) + + # Layernorm on the input data. + self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.num_layers = num_layers + + # GLU + self.mlp = GLU( + hidden_size, + inner_hidden_size=inner_hidden_size, + bias=use_bias, + layer_id=layer_id, + params_dtype=params_dtype, + empty_init=empty_init + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # Layer norm at the begining of the transformer layer. + # [batch, seq_len, hidden_size] + attention_input = self.input_layernorm(hidden_states) + + # Self attention. + attention_outputs = self.attention( + attention_input, + position_ids, + attention_mask=attention_mask, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] + + # Residual connection. + alpha = (2 * self.num_layers) ** 0.5 + hidden_states = attention_input * alpha + attention_output + + mlp_input = self.post_attention_layernorm(hidden_states) + + # MLP. + mlp_output = self.mlp(mlp_input) + + # Second residual connection. + output = mlp_input * alpha + mlp_output + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): + batch_size, seq_length = input_ids.shape + if use_gmasks is None: + use_gmasks = [False] * batch_size + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + if self.position_encoding_2d: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [torch.cat(( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 + )) for context_length in context_lengths] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + if not use_gmasks[i]: + position_ids[i, context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + +CHATGLM_6B_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHATGLM_6B_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ChatGLM6BTokenizer`]. + See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", + CHATGLM_6B_START_DOCSTRING, +) +class ChatGLMModel(ChatGLMPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. + To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + # recording parameters + self.max_sequence_length = config.max_sequence_length + self.hidden_size = config.hidden_size + self.params_dtype = torch.half + self.num_attention_heads = config.num_attention_heads + self.vocab_size = config.vocab_size + self.num_layers = config.num_layers + self.layernorm_epsilon = config.layernorm_epsilon + self.inner_hidden_size = config.inner_hidden_size + self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads + self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + + self.word_embeddings = init_method( + torch.nn.Embedding, + num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, + dtype=self.params_dtype + ) + self.gradient_checkpointing = False + + def get_layer(layer_id): + return GLMBlock( + self.hidden_size, + self.num_attention_heads, + self.layernorm_epsilon, + layer_id, + inner_hidden_size=self.inner_hidden_size, + hidden_size_per_attention_head=self.hidden_size_per_attention_head, + layernorm=LayerNorm, + use_bias=True, + params_dtype=self.params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init + ) + + self.layers = torch.nn.ModuleList( + [get_layer(layer_id) for layer_id in range(self.num_layers)] + ) + + # Final layer norm before output. + self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) + + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values + + @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if past_key_values is None: + if self.pre_seq_len is not None: + past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, + dtype=inputs_embeds.dtype) + else: + past_key_values = tuple([None] * len(self.layers)) + + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + + + if position_ids is None: + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + position_ids = self.get_position_ids( + input_ids, + mask_positions=mask_positions, + device=input_ids.device, + use_gmasks=use_gmasks + ) + + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( + attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) + + # [seq_len, batch, hidden_size] + # hidden_states = inputs_embeds.transpose(0, 1) + # xxx + hidden_states = inputs_embeds + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if attention_mask is None: + attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() + else: + attention_mask = attention_mask.to(hidden_states.device) + + for i, layer in enumerate(self.layers): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_past = past_key_values[i] + + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + position_ids, + attention_mask, + torch.tensor(i), + layer_past, + use_cache, + output_attentions + ) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + hidden_states = layer_ret[0] + + if use_cache: + presents = presents + (layer_ret[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + + # self.hidden_size = config.hidden_size + # self.params_dtype = torch.half + # self.vocab_size = config.vocab_size + self.max_sequence_length = config.max_sequence_length + + self.position_encoding_2d = config.position_encoding_2d + + self.transformer = ChatGLMModel(config, empty_init=empty_init) + + self.lm_head = init_method( + nn.Linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=torch.half + ) + + self.config = config + + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = False + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, new_attention_mask], dim=2 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs + ) -> dict: + batch_size, seq_length = input_ids.shape + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + # only last token for input_ids if past is not None + if past is not None or past_key_values is not None: + last_token = input_ids[:, -1].unsqueeze(-1) + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] + else: + attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + else: + context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] + if self.position_encoding_2d: + position_ids = torch.tensor( + [[mask_position, seq_length - context_length] for mask_position, context_length in + zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + else: + position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, + device=input_ids.device).unsqueeze(-1) + + if past is None: + past = past_key_values + return { + "input_ids": last_token, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + else: + if attention_mask is not None and attention_mask.dtype != torch.bool: + logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + attention_mask = None + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, + device=input_ids.device, + mask_positions=mask_positions, + use_gmasks=use_gmasks + ) + + return { + "input_ids": input_ids, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + @torch.no_grad() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + yield input_ids + + def quantize(self, bits: int, empty_init=False, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) + return self diff --git a/tests/script/models/chatglm-6b/modeling_chatglm.py b/tests/script/models/chatglm-6b/modeling_chatglm.py new file mode 100644 index 0000000..b220192 --- /dev/null +++ b/tests/script/models/chatglm-6b/modeling_chatglm.py @@ -0,0 +1,1479 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import os +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any +import llmdnn as ld + +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm-6b", + # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm +] + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) + + +def gelu(x): + return gelu_impl(x) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + + +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, +): + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + matmul_result = torch.zeros( + 1, 1, 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + # [b, np, sk, hn] -> [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, present, attention_probs) + + return outputs + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class SelfAttention(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True, max_sequence_length=2048): + if empty_init: + init_method = skip_init + else: + init_method = default_init + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + # Strided linear layer. + self.query_key_value = init_method( + torch.nn.Linear, + hidden_size, + 3 * self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.dense = init_method( + torch.nn.Linear, + self.inner_hidden_size, + hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.attn = ld.attn_gpt() + head_size = hidden_size // num_attention_heads + self.head_size_aligned = (head_size + 31) // 32 * 32 + self.max_sequence_length = max_sequence_length + normal_factor = 1.0 / math.sqrt(head_size) + self.attn.create(num_attention_heads, head_size, self.head_size_aligned, + normal_factor, 'bf16', 'bf16', max_sequence_length, 10000, 0.5, True) + self.layer_past_key_padded = None + self.layer_past_value_padded = None + self.past_seq_len = 0 + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + position_ids = position_ids.contiguous() + if self.layer_past_key_padded is None: + shape = (hidden_states.size(0), self.num_attention_heads, self.max_sequence_length, self.head_size_aligned) + self.layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) + self.layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) + if layer_past is None: + self.past_seq_len = 0 + context_layer = self.attn.exec_position(mixed_raw_layer, self.layer_past_key_padded, self.layer_past_value_padded, self.past_seq_len, attention_mask, position_ids) + present = (self.layer_past_key_padded, self.layer_past_value_padded) + attention_probs = None + self.past_seq_len += mixed_raw_layer.size(1) + + if 0: + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + # [batch, seq_len, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache + ) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs,) + + return outputs # output, present, attention_probs + + +class GEGLU(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation_fn = F.gelu + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class GLU(torch.nn.Module): + def __init__(self, hidden_size, inner_hidden_size=None, + layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): + super(GLU, self).__init__() + if empty_init: + init_method = skip_init + else: + init_method = default_init + self.layer_id = layer_id + self.activation_func = activation_func + + # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.dense_h_to_4h = init_method( + torch.nn.Linear, + self.hidden_size, + self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + # Project back to h. + self.dense_4h_to_h = init_method( + torch.nn.Linear, + self.inner_hidden_size, + self.hidden_size, + bias=bias, + dtype=params_dtype, + ) + + def forward(self, hidden_states): + """ + hidden_states: [seq_len, batch, hidden_size] + """ + + # [seq_len, batch, inner_hidden_size] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + + intermediate_parallel = self.activation_func(intermediate_parallel) + + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class GLMBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True, + empty_init=True, + max_sequence_length=2048 + ): + super(GLMBlock, self).__init__() + # Set output layer initialization if not provided. + + self.layer_id = layer_id + + # Layernorm on the input data. + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.position_encoding_2d = position_encoding_2d + + # Self attention. + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + bias=use_bias, + params_dtype=params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init, + max_sequence_length=max_sequence_length + ) + + # Layernorm on the input data. + self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.num_layers = num_layers + + # GLU + self.mlp = GLU( + hidden_size, + inner_hidden_size=inner_hidden_size, + bias=use_bias, + layer_id=layer_id, + params_dtype=params_dtype, + empty_init=empty_init + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # Layer norm at the begining of the transformer layer. + # [batch, seq_len, hidden_size] + attention_input = self.input_layernorm(hidden_states) + + # Self attention. + attention_outputs = self.attention( + attention_input, + position_ids, + attention_mask=attention_mask, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] + + # Residual connection. + alpha = (2 * self.num_layers) ** 0.5 + hidden_states = attention_input * alpha + attention_output + + mlp_input = self.post_attention_layernorm(hidden_states) + + # MLP. + mlp_output = self.mlp(mlp_input) + + # Second residual connection. + output = mlp_input * alpha + mlp_output + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): + batch_size, seq_length = input_ids.shape + if use_gmasks is None: + use_gmasks = [False] * batch_size + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + if self.position_encoding_2d: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [torch.cat(( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 + )) for context_length in context_lengths] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + if not use_gmasks[i]: + position_ids[i, context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + +CHATGLM_6B_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHATGLM_6B_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ChatGLM6BTokenizer`]. + See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", + CHATGLM_6B_START_DOCSTRING, +) +class ChatGLMModel(ChatGLMPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. + To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + # recording parameters + self.max_sequence_length = config.max_sequence_length + self.hidden_size = config.hidden_size + self.params_dtype = torch.half + self.num_attention_heads = config.num_attention_heads + self.vocab_size = config.vocab_size + self.num_layers = config.num_layers + self.layernorm_epsilon = config.layernorm_epsilon + self.inner_hidden_size = config.inner_hidden_size + self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads + self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + + self.word_embeddings = init_method( + torch.nn.Embedding, + num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, + dtype=self.params_dtype + ) + self.gradient_checkpointing = False + + def get_layer(layer_id): + return GLMBlock( + self.hidden_size, + self.num_attention_heads, + self.layernorm_epsilon, + layer_id, + inner_hidden_size=self.inner_hidden_size, + hidden_size_per_attention_head=self.hidden_size_per_attention_head, + layernorm=LayerNorm, + use_bias=True, + params_dtype=self.params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init, + max_sequence_length=self.max_sequence_length + ) + + self.layers = torch.nn.ModuleList( + [get_layer(layer_id) for layer_id in range(self.num_layers)] + ) + + # Final layer norm before output. + self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) + + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values + + @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if past_key_values is None: + if self.pre_seq_len is not None: + past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, + dtype=inputs_embeds.dtype) + else: + past_key_values = tuple([None] * len(self.layers)) + + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + + + if position_ids is None: + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + position_ids = self.get_position_ids( + input_ids, + mask_positions=mask_positions, + device=input_ids.device, + use_gmasks=use_gmasks + ) + + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( + attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) + + # [seq_len, batch, hidden_size] + # hidden_states = inputs_embeds.transpose(0, 1) + # xxx + hidden_states = inputs_embeds + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if attention_mask is None: + attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() + else: + attention_mask = attention_mask.to(hidden_states.device) + + for i, layer in enumerate(self.layers): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_past = past_key_values[i] + + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + position_ids, + attention_mask, + torch.tensor(i), + layer_past, + use_cache, + output_attentions + ) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + hidden_states = layer_ret[0] + + if use_cache: + presents = presents + (layer_ret[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + + # self.hidden_size = config.hidden_size + # self.params_dtype = torch.half + # self.vocab_size = config.vocab_size + self.max_sequence_length = config.max_sequence_length + + self.position_encoding_2d = config.position_encoding_2d + + self.transformer = ChatGLMModel(config, empty_init=empty_init) + + self.lm_head = init_method( + nn.Linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=torch.half + ) + + self.config = config + + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if attention_mask is not None: # and attention_mask.dtype == torch.bool: + right = torch.empty((*attention_mask.shape[:3], 1), dtype=torch.float32) + right[:] = -10000.0 + attention_mask = torch.cat( + [attention_mask, right], dim=3) + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = 0 + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, new_attention_mask], dim=2 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs + ) -> dict: + batch_size, seq_length = input_ids.shape + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + # only last token for input_ids if past is not None + if past is not None or past_key_values is not None: + last_token = input_ids[:, -1].unsqueeze(-1) + if attention_mask is not None: # and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] + # else: + # attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + # else: + # context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] + # if self.position_encoding_2d: + # position_ids = torch.tensor( + # [[mask_position, seq_length - context_length] for mask_position, context_length in + # zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + # else: + # position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, + # device=input_ids.device).unsqueeze(-1) + + if past is None: + past = past_key_values + return { + "input_ids": last_token, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + else: + # if attention_mask is not None and attention_mask.dtype != torch.bool: + # logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + # attention_mask = None + # if attention_mask is None: + # attention_mask = self.get_masks( + # input_ids, + # device=input_ids.device + # ) + # if position_ids is None: + # position_ids = self.get_position_ids( + # input_ids, + # device=input_ids.device, + # mask_positions=mask_positions, + # use_gmasks=use_gmasks + # ) + + return { + "input_ids": input_ids, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + @torch.no_grad() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + yield input_ids + + def quantize(self, bits: int, empty_init=False, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) + return self diff --git a/tests/script/models/chatglm-6b/pytorch_model.bin.index.json b/tests/script/models/chatglm-6b/pytorch_model.bin.index.json new file mode 100644 index 0000000..b8ada2b --- /dev/null +++ b/tests/script/models/chatglm-6b/pytorch_model.bin.index.json @@ -0,0 +1,375 @@ +{ + "metadata": { + "total_size": 13744473856 + }, + "weight_map": { + "lm_head.weight": "pytorch_model-00008-of-00008.bin", + "transformer.final_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.final_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.0.attention.dense.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.attention.dense.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.input_layernorm.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.mlp.dense_4h_to_h.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.dense.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.dense.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.input_layernorm.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.1.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.10.attention.dense.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.attention.dense.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.16.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.16.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.16.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.16.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.2.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.20.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.20.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.20.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.21.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.25.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.25.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.25.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.25.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.26.attention.dense.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.attention.dense.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.attention.query_key_value.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.attention.query_key_value.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.attention.rotary_emb.inv_freq": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.input_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.input_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.post_attention_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.post_attention_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.dense.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.dense.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.query_key_value.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.query_key_value.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.rotary_emb.inv_freq": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.input_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.input_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.post_attention_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.post_attention_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.3.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.6.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.6.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.6.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.6.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.7.attention.dense.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.attention.dense.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.dense.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.dense.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.dense.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.dense.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.word_embeddings.weight": "pytorch_model-00001-of-00008.bin" + } +} diff --git a/tests/script/models/chatglm-6b/quantization.py b/tests/script/models/chatglm-6b/quantization.py new file mode 100644 index 0000000..6f469f6 --- /dev/null +++ b/tests/script/models/chatglm-6b/quantization.py @@ -0,0 +1,201 @@ +from torch.nn import Linear +from torch.nn.parameter import Parameter + +import bz2 +import torch +import base64 +import ctypes +from transformers.utils import logging + +from typing import List +from functools import partial + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + "int4WeightCompression", + "int4WeightExtractionFloat", + "int4WeightExtractionHalf", + "int8WeightExtractionFloat", + "int8WeightExtractionHalf", + ], + ) +except Exception as exception: + kernels = None + logger.warning("Failed to load cpm_kernels:" + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features,))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): + if source_bit_width == 8: + func = kernels.int8WeightExtractionHalf + elif source_bit_width == 4: + func = kernels.int4WeightExtractionHalf + else: + assert False, "Unsupported bit-width" + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(Linear): + def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs): + super(QuantizedLinear, self).__init__(*args, **kwargs) + self.weight_bit_width = weight_bit_width + + shape = self.weight.shape + del self.weight + + if weight_tensor is None or empty_init: + self.weight = torch.empty( + shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] + ) + self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) + else: + self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() + self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) + self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) + if bias_tensor is not None: + self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False) + else: + self.bias = None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, **kwargs): + """Replace fp16 linear with quantized linear""" + + for layer in model.layers: + layer.attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()), + bias_tensor=layer.attention.query_key_value.bias, + in_features=layer.attention.query_key_value.in_features, + out_features=layer.attention.query_key_value.out_features, + bias=True, + dtype=torch.half, + device=layer.attention.query_key_value.weight.device, + empty_init=empty_init + ) + layer.attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()), + bias_tensor=layer.attention.dense.bias, + in_features=layer.attention.dense.in_features, + out_features=layer.attention.dense.out_features, + bias=True, + dtype=torch.half, + device=layer.attention.dense.weight.device, + empty_init=empty_init + ) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), + bias_tensor=layer.mlp.dense_h_to_4h.bias, + in_features=layer.mlp.dense_h_to_4h.in_features, + out_features=layer.mlp.dense_h_to_4h.out_features, + bias=True, + dtype=torch.half, + device=layer.mlp.dense_h_to_4h.weight.device, + empty_init=empty_init + ) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), + bias_tensor=layer.mlp.dense_4h_to_h.bias, + in_features=layer.mlp.dense_4h_to_h.in_features, + out_features=layer.mlp.dense_4h_to_h.out_features, + bias=True, + dtype=torch.half, + device=layer.mlp.dense_4h_to_h.weight.device, + empty_init=empty_init + ) + return model diff --git a/tests/script/models/chatglm-6b/test_modeling_chatglm.py b/tests/script/models/chatglm-6b/test_modeling_chatglm.py new file mode 100644 index 0000000..814c8bd --- /dev/null +++ b/tests/script/models/chatglm-6b/test_modeling_chatglm.py @@ -0,0 +1,245 @@ +import datetime +import math +import unittest +import torch +import random +import time + +from transformers import AutoTokenizer, AutoModel +from transformers.testing_utils import require_torch, slow, torch_device + +from torch.profiler import profile, record_function, ProfilerActivity + +def set_random_seed(seed): + import random + + random.seed(seed) + + # pytorch RNGs + import torch + + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # numpy RNG + import numpy as np + + np.random.seed(seed) + + + +def ids_tensor(shape, vocab_size): + # Creates a random int32 tensor of the shape within the vocab size + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(random.randint(0, vocab_size - 1)) + + return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() + + +def get_model_and_tokenizer(): + model = AutoModel.from_pretrained("/home/luocheng/openvino/src/plugins/intel_cpu/thirdparty/llmdnn/tests/script/models/chatglm-6b", trust_remote_code=True) + model.to(torch_device, dtype=torch.bfloat16) + print(f"torch_device={torch_device}") + model.eval() + tokenizer = AutoTokenizer.from_pretrained("/home/luocheng/openvino/src/plugins/intel_cpu/thirdparty/llmdnn/tests/script/models/chatglm-6b", trust_remote_code=True) + return model, tokenizer + + +@require_torch +class ChatGLMGenerationTest(unittest.TestCase): + def get_generation_kwargs(self): + pass + + def ntest_chat(self): + print("======================test_chat") + model, tokenizer = get_model_and_tokenizer() + prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] + history = [] + set_random_seed(42) + expected_responses = [ + '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', + '清华大学创建于 1911 年。' + ] + for (prompt, expected_response) in zip(prompts, expected_responses): + response, history = model.chat(tokenizer, prompt, history=history) + print(repr(response)) + self.assertEqual(expected_response, response) + + def ntest_stream_chat(self): + print("======================test_stream_chat") + model, tokenizer = get_model_and_tokenizer() + prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] + history = [] + expected_responses = [ + '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', + '清华大学创建于 1911 年。' + ] + set_random_seed(42) + for prompt, expected_response in zip(prompts, expected_responses): + response = "" + for idx, (response, history) in enumerate(model.stream_chat(tokenizer, prompt, history=history)): + pass + print(repr(response)) + self.assertEqual(expected_response, response) + + def ntest_generation(self): + print("======================test_generation") + model, tokenizer = get_model_and_tokenizer() + sentence = "晚上睡不着怎么办" + parameters = [(False, 2048, 1), + (False, 64, 1), + (True, 2048, 1), + (True, 64, 1), + (True, 2048, 4)] + expected_out_sentences = [ + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。\n\n3. 避免刺激性物质:避免饮用含咖啡因的饮料,如咖啡、茶和可乐,并尽可能减少饮酒。\n\n4. 放松身心:尝试进行放松的活动,如冥想、深呼吸、瑜伽或听轻柔的音乐。\n\n5. 避免在床上做其他事情:例如看电视、使用电脑或智能手机等。\n\n6. 练习放松技巧:例如渐进性肌肉松弛法、冥想或深呼吸练习。\n\n7. 寻求帮助:如果长时间都无法正常入睡,可以考虑咨询医生或专业心理医生,寻求更进一步的帮助。\n\n希望这些方法能有助于入睡。', + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。', + '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体释放褪黑素,进而导致难以入睡。建议你在睡前一小时停止使用这些设备。\n\n3. 创建舒适的睡眠环境:确保卧室安静、黑暗、凉爽,舒适的床垫和枕头,保持卧室温度适宜,这有助于让你更容易入睡。\n\n4. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽或轻松的散步,减轻压力和焦虑,让你更容易入睡。\n\n5. 避免咖啡因和酒精:咖啡因和酒精会让大脑更加兴奋,进而干扰身体入睡过程。建议在睡前几小时避免饮用这些物质。\n\n6. 做一些安静的活动:阅读一本书、听轻柔的音乐、绣或者绘画等安静的活动,有助于自己放松身心,进而更容易入睡。\n\n如果采取以上这些方法仍然无法入睡,建议咨询医生或专业的睡眠专家,获取更好的建议和帮助。', + '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体', + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 建立规律的睡眠时间表:尽量在同一时间入睡和起床,即使在周末和假期也要尽量保持一致。\n\n2. 创造舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,使用舒适的床垫和枕头等。\n\n3. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽、听轻柔的音乐等,缓解压力和紧张情绪。\n\n4. 避免刺激性物质:避免饮用咖啡、茶、可乐等含咖啡因的饮料,避免吸烟和饮酒等刺激性物质。\n\n5. 避免躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,就不要躺在床上翻来覆去,而是起床去做一些放松的活动,直到感到困倦为止。\n\n6. 练习放松技巧:如果感到焦虑或紧张,可以尝试进行一些放松技巧,如渐进性肌肉松弛、冥想等。\n\n7. 改善睡眠障碍:如果已经尝试了上述方法仍然无法入睡,可以考虑咨询医生,了解是否存在其他睡眠障碍问题,并接受相应的治疗。'] + for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): + set_random_seed(42) + inputs = tokenizer([sentence,], return_tensors="pt", padding=True) + inputs = inputs.to(torch_device) + print(inputs) + outputs = model.generate( + **inputs, + do_sample=do_sample, + max_length=max_length, + num_beams=num_beams + ) + print(outputs) + outputs = outputs.tolist()[0] + out_sentence = tokenizer.decode(outputs, skip_special_tokens=True) + print(out_sentence) + self.assertEqual(expected_output_sentence, out_sentence) + + def test_batch_generation(self): + print("======================test_batch_generation") + model, tokenizer = get_model_and_tokenizer() + sentences = [ + "你好", + "介绍一下清华大学" + ] + parameters = [(False, 2048, 1), + (False, 64, 1), + (True, 2048, 1), + (True, 64, 1), + (True, 2048, 4)] + expected_out_sentences = [ + ['你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回清华园。新中国成立后,清华学校更名为清华大学。\n\n清华大学是中国最顶尖的大学之一,在工程、科学、技术、经济、管理等领域都有很高的学术声誉和影响力。学校拥有世界一流的教学设施和科学研究平台,有多个学院和研究中心,包括工程学院、自然科学学院、人文学院、社会科学学院、经济管理学院、法学院、美术学院、医学院、器学院等。\n\n清华大学的本科生招生始于2000年,实行全面二孩政策后,本科生招生规模不断扩大。截至2022年,清华大学共有本科生近3万人,研究生近2万人,其中国际学生占比约为10%。清华大学的本科生教育注重通识教育和个性化培养,强调实践、创新、国际化和综合素质。'], + [ + '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回' + ], + [ + '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后闭校。1949 年 10 月开学复校,成为我国第一个社会主义大学生活了的高校。截至 2023 年,清华学校共管辖 2 个学院、13 个系,有本科专业 60 个,研究生专业 190 个。' + ], + [ + '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后' + ], + [ + '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,与北京大学、南开大学组建国立长沙临时大学,1938年迁至 昆明改名为国立西南联合大学,1946年迁回北京。新中国成立后,清华学校更名为清华大学。' + ] + ] + for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): + set_random_seed(42) + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + inputs = inputs.to(torch_device) + print(inputs) + outputs = model.generate( + **inputs, + do_sample=do_sample, + max_length=max_length, + num_beams=num_beams + ) + print(outputs) + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + print(batch_out_sentence) + self.assertListEqual(expected_output_sentence, batch_out_sentence) + +def mytest(use_jit = False): + model, tokenizer = get_model_and_tokenizer() + + # import intel_extension_for_pytorch as ipex + # model = ipex.optimize(model, dtype=torch.bfloat16) + + + sentence = "世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?" + parameters = [(False, 2048, 1), + #(True, 2048, 1), + #(True, 2048, 4) + ] + expected_out_sentences = [ + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。\n\n3. 避免刺激性物质:避免饮用含咖啡因的饮料,如咖啡、茶和可乐,并尽可能减少饮酒。\n\n4. 放松身心:尝试进行放松的活动,如冥想、深呼吸、瑜伽或听轻柔的音乐。\n\n5. 避免在床上做其他事情:例如看电视、使用电脑或智能手机等。\n\n6. 练习放松技巧:例如渐进性肌肉松弛法、冥想或深呼吸练习。\n\n7. 寻求帮助:如果长时间都无法正常入睡,可以考虑咨询医生或专业心理医生,寻求更进一步的帮助。\n\n希望这些方法能有助于入睡。', + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。', + '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体释放褪黑素,进而导致难以入睡。建议你在睡前一小时停止使用这些设备。\n\n3. 创建舒适的睡眠环境:确保卧室安静、黑暗、凉爽,舒适的床垫和枕头,保持卧室温度适宜,这有助于让你更容易入睡。\n\n4. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽或轻松的散步,减轻压力和焦虑,让你更容易入睡。\n\n5. 避免咖啡因和酒精:咖啡因和酒精会让大脑更加兴奋,进而干扰身体入睡过程。建议在睡前几小时避免饮用这些物质。\n\n6. 做一些安静的活动:阅读一本书、听轻柔的音乐、绣或者绘画等安静的活动,有助于自己放松身心,进而更容易入睡。\n\n如果采取以上这些方法仍然无法入睡,建议咨询医生或专业的睡眠专家,获取更好的建议和帮助。', + '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体', + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 建立规律的睡眠时间表:尽量在同一时间入睡和起床,即使在周末和假期也要尽量保持一致。\n\n2. 创造舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,使用舒适的床垫和枕头等。\n\n3. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽、听轻柔的音乐等,缓解压力和紧张情绪。\n\n4. 避免刺激性物质:避免饮用咖啡、茶、可乐等含咖啡因的饮料,避免吸烟和饮酒等刺激性物质。\n\n5. 避免躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,就不要躺在床上翻来覆去,而是起床去做一些放松的活动,直到感到困倦为止。\n\n6. 练习放松技巧:如果感到焦虑或紧张,可以尝试进行一些放松技巧,如渐进性肌肉松弛、冥想等。\n\n7. 改善睡眠障碍:如果已经尝试了上述方法仍然无法入睡,可以考虑咨询医生,了解是否存在其他睡眠障碍问题,并接受相应的治疗。'] + + jit_model_generated = False + f = open('result.torch.txt', 'w') + for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): + set_random_seed(42) + inputs = tokenizer([sentence,], return_tensors="pt", padding=True) + inputs = inputs.to(torch_device) + #print(inputs) + inputs.data['position_ids'] = inputs.data['position_ids'].to(torch.int32) + attn_mask = torch.zeros_like(inputs.data['attention_mask'], dtype=torch.float32) + inputs.data['attention_mask'] = attn_mask.masked_fill_(inputs.data['attention_mask'], -10000.0) + + if not jit_model_generated and use_jit: + print("generating jit model...") + with torch.no_grad(), torch.cpu.amp.autocast(): + model = torch.jit.trace(model, inputs) + model = torch.jit.freeze(model) + jit_model_generated = True + print("done") + + for repeat in range(1): + t0 = time.time() + + # with profile(activities=[ProfilerActivity.CPU], record_shapes=False) as prof: + # with record_function("model_inference"): + with torch.no_grad(), torch.cpu.amp.autocast(): + outputs = model.generate( + **inputs, + do_sample=do_sample, + max_length=max_length, + num_beams=num_beams + ) + t1 = time.time() + + #prof.export_chrome_trace("trace.json") + #print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + + #print(outputs) + outputs = outputs.tolist()[0] + + if repeat == 0: + out_sentence = tokenizer.decode(outputs, skip_special_tokens=True) + print(f"{out_sentence}") + print(f" #tokens={len(outputs)},do_sample={do_sample},max_length={max_length},num_beams={num_beams}") + f.write(out_sentence) + + print(f" [{repeat}] ::: {(t1-t0)*1e3/len(outputs)} ms/token") + + f.close() + +# numactl -C 0-46 python ./test_modeling_chatglm.py +if __name__ == '__main__': + mytest() + #unittest.main() \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/tokenization_chatglm.py b/tests/script/models/chatglm-6b/tokenization_chatglm.py new file mode 100644 index 0000000..69ee85c --- /dev/null +++ b/tests/script/models/chatglm-6b/tokenization_chatglm.py @@ -0,0 +1,443 @@ +"""Tokenization classes for ChatGLM.""" +from typing import List, Optional, Union +import os + +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging, PaddingStrategy +from transformers.tokenization_utils_base import EncodedInput, BatchEncoding +from typing import Dict +import sentencepiece as spm +import numpy as np + +logger = logging.get_logger(__name__) + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "THUDM/chatglm-6b": 2048, +} + + +class TextTokenizer: + def __init__(self, model_path): + self.sp = spm.SentencePieceProcessor() + self.sp.Load(model_path) + self.num_tokens = self.sp.vocab_size() + + def encode(self, text): + return self.sp.EncodeAsIds(text) + + def decode(self, ids: List[int]): + return self.sp.DecodeIds(ids) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_string(self, tokens): + return self.sp.DecodePieces(tokens) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + def __len__(self): + return self.num_tokens + + +class SPTokenizer: + def __init__( + self, + vocab_file, + num_image_tokens=20000, + max_blank_length=80, + byte_fallback=True, + ): + assert vocab_file is not None + self.vocab_file = vocab_file + self.num_image_tokens = num_image_tokens + self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] + self.max_blank_length = max_blank_length + self.byte_fallback = byte_fallback + self.text_tokenizer = TextTokenizer(vocab_file) + + def _get_text_tokenizer(self): + return self.text_tokenizer + + @staticmethod + def get_blank_token(length: int): + assert length >= 2 + return f"<|blank_{length}|>" + + @staticmethod + def get_tab_token(): + return f"<|tab|>" + + @property + def num_text_tokens(self): + return self.text_tokenizer.num_tokens + + @property + def num_tokens(self): + return self.num_image_tokens + self.num_text_tokens + + @staticmethod + def _encode_whitespaces(text: str, max_len: int = 80): + text = text.replace("\t", SPTokenizer.get_tab_token()) + for i in range(max_len, 1, -1): + text = text.replace(" " * i, SPTokenizer.get_blank_token(i)) + return text + + def _preprocess(self, text: str, linebreak=True, whitespaces=True): + if linebreak: + text = text.replace("\n", "") + if whitespaces: + text = self._encode_whitespaces(text, max_len=self.max_blank_length) + return text + + def encode( + self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True + ) -> List[int]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = "" + text + tmp = self._get_text_tokenizer().encode(text) + tokens = [x + self.num_image_tokens for x in tmp] + return tokens if add_dummy_prefix else tokens[2:] + + def postprocess(self, text): + text = text.replace("", "\n") + text = text.replace(SPTokenizer.get_tab_token(), "\t") + for i in range(2, self.max_blank_length + 1): + text = text.replace(self.get_blank_token(i), " " * i) + return text + + def decode(self, text_ids: List[int]) -> str: + ids = [int(_id) - self.num_image_tokens for _id in text_ids] + ids = [_id for _id in ids if _id >= 0] + text = self._get_text_tokenizer().decode(ids) + text = self.postprocess(text) + return text + + def decode_tokens(self, tokens: List[str]) -> str: + text = self._get_text_tokenizer().convert_tokens_to_string(tokens) + text = self.postprocess(text) + return text + + def tokenize( + self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True + ) -> List[str]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = "" + text + tokens = self._get_text_tokenizer().tokenize(text) + return tokens if add_dummy_prefix else tokens[2:] + + def __getitem__(self, x: Union[int, str]): + if isinstance(x, int): + if x < self.num_image_tokens: + return "".format(x) + else: + return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens) + elif isinstance(x, str): + if x.startswith("") and x[7:-1].isdigit(): + return int(x[7:-1]) + else: + return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens + else: + raise ValueError("The key should be str or int.") + + +class ChatGLMTokenizer(PreTrainedTokenizer): + """ + Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ + + vocab_files_names = {"vocab_file": "ice_text.model"} + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask", "position_ids"] + + def __init__( + self, + vocab_file, + do_lower_case=False, + remove_space=False, + bos_token='', + eos_token='', + end_token='', + mask_token='[MASK]', + gmask_token='[gMASK]', + padding_side="left", + pad_token="", + unk_token="", + num_image_tokens=20000, + **kwargs + ) -> None: + super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + padding_side=padding_side, + bos_token=bos_token, + eos_token=eos_token, + end_token=end_token, + mask_token=mask_token, + gmask_token=gmask_token, + pad_token=pad_token, + unk_token=unk_token, + num_image_tokens=num_image_tokens, + **kwargs + ) + + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.vocab_file = vocab_file + + self.bos_token = bos_token + self.eos_token = eos_token + self.end_token = end_token + self.mask_token = mask_token + self.gmask_token = gmask_token + + self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens) + + """ Initialisation """ + + @property + def gmask_token_id(self) -> Optional[int]: + if self.gmask_token is None: + return None + return self.convert_tokens_to_ids(self.gmask_token) + + @property + def end_token_id(self) -> Optional[int]: + """ + `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been + set. + """ + if self.end_token is None: + return None + return self.convert_tokens_to_ids(self.end_token) + + @property + def vocab_size(self): + """ Returns vocab size """ + return self.sp_tokenizer.num_tokens + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def preprocess_text(self, inputs): + if self.remove_space: + outputs = " ".join(inputs.strip().split()) + else: + outputs = inputs + + if self.do_lower_case: + outputs = outputs.lower() + + return outputs + + def _tokenize(self, text, **kwargs): + """ Returns a tokenized string. """ + text = self.preprocess_text(text) + + seq = self.sp_tokenizer.tokenize(text) + + return seq + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.sp_tokenizer.decode_tokens(tokens) + + def _decode( + self, + token_ids: Union[int, List[int]], + **kwargs + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + if len(token_ids) == 0: + return "" + if self.pad_token_id in token_ids: # remove pad + token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) + return super()._decode(token_ids, **kwargs) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.sp_tokenizer[token] + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_tokenizer[index] + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, self.vocab_files_names["vocab_file"] + ) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, "wb") as writer: + writer.write(proto_str) + + return (vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + gmask_id = self.sp_tokenizer[self.gmask_token] + eos_id = self.sp_tokenizer[self.eos_token] + token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [eos_id] + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + bos_token_id = self.sp_tokenizer[self.bos_token] + mask_token_id = self.sp_tokenizer[self.mask_token] + gmask_token_id = self.sp_tokenizer[self.gmask_token] + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if max_length is not None: + if "attention_mask" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + attention_mask = np.ones((1, seq_length, seq_length)) + attention_mask = np.tril(attention_mask) + attention_mask[:, :, :context_length] = 1 + attention_mask = np.bool_(attention_mask < 0.5) + encoded_inputs["attention_mask"] = attention_mask + + if "position_ids" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + position_ids = np.arange(seq_length, dtype=np.int64) + mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id + if mask_token in required_input: + mask_position = required_input.index(mask_token) + position_ids[context_length:] = mask_position + block_position_ids = np.concatenate( + [np.zeros(context_length, dtype=np.int64), + np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) + encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode='constant', constant_values=True) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], + pad_width=[(0, 0), (difference, 0)]) + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/tests/script/models/chatglm-6b/tokenizer_config.json b/tests/script/models/chatglm-6b/tokenizer_config.json new file mode 100644 index 0000000..f8221e0 --- /dev/null +++ b/tests/script/models/chatglm-6b/tokenizer_config.json @@ -0,0 +1,20 @@ +{ + "name_or_path": "THUDM/chatglm-6b", + "bos_token": "", + "eos_token": "", + "end_token": "", + "gmask_token": "[gMASK]", + "mask_token": "[MASK]", + "pad_token": "", + "unk_token": "", + "remove_space": false, + "do_lower_case": false, + "tokenizer_class": "ChatGLMTokenizer", + "num_image_tokens": 0, + "auto_map": { + "AutoTokenizer": [ + "tokenization_chatglm.ChatGLMTokenizer", + null + ] + } +} diff --git a/tests/script/pytest.ini b/tests/script/pytest.ini new file mode 100644 index 0000000..e68d012 --- /dev/null +++ b/tests/script/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --ignore=models \ No newline at end of file diff --git a/tests/script/test_attn_chatglm.py b/tests/script/test_attn_chatglm.py new file mode 100644 index 0000000..9cf5164 --- /dev/null +++ b/tests/script/test_attn_chatglm.py @@ -0,0 +1,436 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +# copy from chatglm-6b/modeling_chatglm.py +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + #inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] + + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + +class SelfAttention(torch.nn.Module): + def __init__(self, max_position_embeddings, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.head_size = self.hidden_size // self.num_attention_heads + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + max_position_embeddings, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + layer_past=None, + attention_mask=None + ): + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + #query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + #key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=2) + + # [b, np, sk, hn] -> [b * np, sk, hn] + #value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + # return query_layer, key_layer, value_layer + past = (key_layer, value_layer) + + # from test_mha_chatglm.py/forward + query_layer = query_layer / self.norm_factor + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) + + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.reshape(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.reshape(output_size[0] * output_size[1], output_size[3], -1) + + matmul_result = torch.zeros( + 1, 1, 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # if not (attention_mask == 0).all(): + # # if auto-regressive, skip + # attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores = attention_scores + attention_mask + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3)) + + # [b, np, sk, hn] -> [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = (output_size[0], output_size[2], output_size[1] * output_size[3]) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, past + + + def forward( + self, + qkv: torch.Tensor, # [batch, seq_len, 3 * hidden_size] + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]], + attention_mask: torch.Tensor, # [batch, 1, query_seq_len, key_seq_len] + position_ids # [batch, 2, query_seq_len] + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = qkv # self.query_key_value(hidden_states) + + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + query_layer = query_layer.to(dtype=torch.bfloat16) + key_layer = key_layer.to(dtype=torch.bfloat16) + + # [batch, seq_len, hidden_size] + attn, past = self.attention_fn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + layer_past=layer_past, + attention_mask=attention_mask + ) + + return attn, past + + +class GPTAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): + self.attn = ld.attn_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + normal_factor = 1.0 / math.sqrt(head_size) + + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + self.attn.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + dst_precision_name, max_seq_len, rotary_emb_base, rotary_pct, True) + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # attn_mask: [batch, 1, query_seq_len, key_seq_len] + # position_ids: [batch, 2, query_seq_len] + # return: + # 0: qkv [batch, seq_len, (num_heads * 3 * head_size)] + # 1: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 2: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids): + return self.attn.exec_position(qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids) + + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 96 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.5 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + ref_net = SelfAttention(MAX_POSITION_EMBEDDINGS, HIDDEN_SIZE, HEAD_NUM, 0, None) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_attn(): + inputs = [ + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + # attn: [batch, 1, query_seq_len, key_seq_len] + (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 200, 200], dtype=np.float32)), + (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 201], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = GPTAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value, attn_mask = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + past_seq_len = layer_past_key.shape[-2] + shape = list(layer_past_key.shape) + + if qkv.size(1) > 1: + seq_batch1 = torch.arange(end=qkv.size(1) - 1, dtype=torch.int32) + seq_batch1 = torch.concat((seq_batch1, seq_batch1[-1:])) + block_batch1 = torch.concat((torch.zeros(qkv.size(1) -1, dtype=torch.int32), torch.ones(1, dtype=torch.int32))) + else: + seq_batch1 = torch.tensor([3], dtype=torch.int32) + block_batch1 = torch.tensor([5], dtype=torch.int32) + seq_ids = torch.empty((qkv.size(0), 2, qkv.size(1)), dtype=torch.int32) + seq_ids[:, 0, :] = seq_batch1 + seq_ids[:, 1, :] = block_batch1 + output_ref, (key_ref, value_ref) = ref_net.forward(qkv, (layer_past_key, layer_past_value), attn_mask, seq_ids) + + shape[-2] = MAX_SEQ_LEN + shape[-1] = SIZE_PER_HEAD_ALIGN + layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) + layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) + layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key + layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value + key_ref = key_ref.to(dtype=torch.bfloat16) + output = net.forward(qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, seq_ids) + key, value = layer_past_key_padded, layer_past_value_padded + # check output + if not torch.allclose(output_ref, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{output_ref} \ncur:\n {output} ") + assert(False) + # check key + if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_attn() \ No newline at end of file diff --git a/tests/script/test_mha_chatglm.py b/tests/script/test_mha_chatglm.py new file mode 100644 index 0000000..59179e6 --- /dev/null +++ b/tests/script/test_mha_chatglm.py @@ -0,0 +1,259 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F + +# copy from chatglm-6b/modeling_chatglm.py +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + + def forward(self, + query_layer, # [b, np, sq, hn] + key_layer, # [b, np, sk, hn] + value_layer, # [b, np, sk, hn] + attention_mask # [b, np, s, s] + ): + return self.attention_fn(query_layer, key_layer, value_layer, attention_mask) + + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask + ): + query_layer = query_layer / self.norm_factor + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) + + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + matmul_result = torch.zeros( + 1, 1, 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # if not (attention_mask == 0).all(): + # # if auto-regressive, skip + # attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores = attention_scores + attention_mask + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3)) + + # [b, np, sk, hn] -> [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = (output_size[0], output_size[2], output_size[1] * output_size[3]) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, is_int8=False): + self.mha = ld.mha_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + head_size_aligned = head_size_aligned + normal_factor = 1.0 / math.sqrt(head_size) + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + dst_precision_name, max_seq_len) + + def forward(self, query, key, value, attention_mask, head_size, key_seq_len): + return self.mha.exec(query, key, value, attention_mask, head_size, key_seq_len) + + def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): + # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant + return self.mha.exec_quant(query, key, value, attention_mask, 1.0 / q_quant, 1.0 / k_quant, 1.0 / v_quant, qk_quant, requant) + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 96 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +def get_ref_model(): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_chatglm_neox(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [batch, 1, query_seq_len, key_seq_len] + (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 2, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 200, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + ref_output = ref_net.forward(q, k, v, attn_mask) + + shape = list(k.shape) + shape[-2] = MAX_POSITION_EMBEDDINGS + shape[-1] = SIZE_PER_HEAD_ALIGN + key_padded = torch.zeros(shape, dtype=torch.bfloat16) + value_padded = torch.zeros(shape, dtype=torch.bfloat16) + query_shape = list(q.shape) + query_shape[-1] = SIZE_PER_HEAD_ALIGN + query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) + key_padded[:,:,:k.shape[-2],:k.shape[-1]] = k + value_padded[:,:,:k.shape[-2],:k.shape[-1]] = v + query_padded[:,:,:,:q.shape[-1]] = q + + output = net.forward(query_padded, key_padded, value_padded, attn_mask, k.size(3), k.size(2)) + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +def todotest_gpt_neox_int8(): + low = -4 + high = 4 + range_ = high - low + q_quant = 127.0 / high + qs = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + low = -2 + high = 2 + range_ = high - low + k_quant = 127.0 / high + ks = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + low = -8 + high = 8 + range_ = high - low + v_quant = 127.0 / high + vs = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [1, MAX_POSITION_EMBEDDINGS] + inputs = [] + for i in range(len(qs)): + inputs.append((qs[i], ks[i], vs[i], np.zeros([1, ks[i].shape[-2]], dtype=np.float32))) + + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, True) + qk_quant, requant = 255.0, 10.0 + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q) + k = torch.from_numpy(k) + v = torch.from_numpy(v) + attn_mask = torch.from_numpy(attn_mask) + q = (q * q_quant).round().clamp(-128, 127).to(torch.int8) + k = (k * k_quant).round().clamp(-128, 127).to(torch.int8) + v = (v * v_quant).round().clamp(-128, 127).to(torch.int8) + ref_output = ref_net.forward(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, requant) + output = net.forward_quant(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, [requant,]) + if (torch.abs(ref_output- output) > 2).any(): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_chatglm_neox() \ No newline at end of file diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py index 184d753..ac99f99 100644 --- a/tests/script/test_mha_gpt.py +++ b/tests/script/test_mha_gpt.py @@ -146,19 +146,19 @@ def test_gpt_neox(): # q: [batch, num_heads, query_seq_len, head_size] # k: [batch, num_heads, key_seq_len, head_size] # v: [batch, num_heads, value_seq_len, head_size] - # attn: [1, MAX_POSITION_EMBEDDINGS] + # attn: [2, 1, 1, key_seq_len] (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([1, 32], dtype=np.float32)), + np.zeros([2, 1, 1, 32], dtype=np.float32)), (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([1, 200], dtype=np.float32)), + np.zeros([2, 1, 1, 200], dtype=np.float32)), (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([1, 200], dtype=np.float32)), + np.zeros([2, 1, 1, 200], dtype=np.float32)), ] ref_net = get_ref_model() net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS) @@ -169,6 +169,7 @@ def test_gpt_neox(): k = torch.from_numpy(k).to(torch.bfloat16) v = torch.from_numpy(v).to(torch.bfloat16) attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min ref_output = ref_net.forward(q, k, v, attn_mask) output = net.forward(q, k, v, attn_mask) if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): @@ -210,10 +211,10 @@ def test_gpt_neox_int8(): # q: [batch, num_heads, query_seq_len, head_size] # k: [batch, num_heads, key_seq_len, head_size] # v: [batch, num_heads, value_seq_len, head_size] - # attn: [1, MAX_POSITION_EMBEDDINGS] + # attn: [batch, 1, 1, key_seq_len] inputs = [] for i in range(len(qs)): - inputs.append((qs[i], ks[i], vs[i], np.zeros([1, ks[i].shape[-2]], dtype=np.float32))) + inputs.append((qs[i], ks[i], vs[i], np.zeros([ks[i].shape[0], 1, 1, ks[i].shape[-2]], dtype=np.float32))) ref_net = get_ref_model() net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, True) @@ -238,4 +239,5 @@ def test_gpt_neox_int8(): return if __name__ == "__main__": + test_gpt_neox() test_gpt_neox_int8() \ No newline at end of file diff --git a/tests/script/test_rotary_pastkv_chatglm.py b/tests/script/test_rotary_pastkv_chatglm.py new file mode 100644 index 0000000..278bfa6 --- /dev/null +++ b/tests/script/test_rotary_pastkv_chatglm.py @@ -0,0 +1,352 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +# copy from chatglm-6b/modeling_chatglm.py +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + #inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] + + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + +class SelfAttention(torch.nn.Module): + def __init__(self, max_position_embeddings, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + max_position_embeddings, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + layer_past=None + ): + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + #query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + #key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=2) + + # [b, np, sk, hn] -> [b * np, sk, hn] + #value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + return query_layer, key_layer, value_layer + + def forward( + self, + qkv: torch.Tensor, # [batch, seq_len, 3 * hidden_size] + position_ids, # [batch, 2, query_seq_len] + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = qkv # self.query_key_value(hidden_states) + + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + query_layer = query_layer.to(dtype=torch.bfloat16) + key_layer = key_layer.to(dtype=torch.bfloat16) + + # [batch, seq_len, hidden_size] + query_layer, key_layer, value_layer = self.attention_fn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + layer_past=layer_past, + ) + + return query_layer, key_layer, value_layer + + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): + self.emd = ld.emb_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + self.emd.create(num_heads, head_size, head_size_aligned, qkv_precision_name, + dst_precision_name, max_seq_len, rotary_emb_base, rotary_pct, True) + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # return: + # 0: (k, v): ([batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned], [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned]) + # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] + # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, position_ids): + self.emd.exec_position(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, position_ids) + + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 96 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.5 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + ref_net = SelfAttention(MAX_POSITION_EMBEDDINGS, HIDDEN_SIZE, HEAD_NUM, 0, None) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_chatglm(): + inputs = [ + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32)), + (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + past_seq_len = layer_past_key.shape[-2] + shape = list(layer_past_key.shape) + + if qkv.size(1) > 1: + seq_batch1 = torch.arange(end=qkv.size(1) - 1, dtype=torch.int32) + seq_batch1 = torch.concat((seq_batch1, seq_batch1[-1:])) + block_batch1 = torch.concat((torch.zeros(qkv.size(1) -1, dtype=torch.int32), torch.ones(1, dtype=torch.int32))) + else: + seq_batch1 = torch.tensor([3], dtype=torch.int32) + block_batch1 = torch.tensor([5], dtype=torch.int32) + seq_ids = torch.empty((qkv.size(0), 2, qkv.size(1)), dtype=torch.int32) + seq_ids[:, 0, :] = seq_batch1 + seq_ids[:, 1, :] = block_batch1 + query_ref, key_ref, value_ref = ref_net.forward(qkv, seq_ids, (layer_past_key, layer_past_value)) + + shape[-2] = MAX_SEQ_LEN + shape[-1] = SIZE_PER_HEAD_ALIGN + layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) + layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) + query_shape = list(shape) + query_shape[2] = qkv.shape[1] + query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) + layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key + layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + net.forward(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, seq_ids) + output, query, key, value = (layer_past_key_padded, layer_past_value_padded), query_padded, layer_past_key_padded, layer_past_value_padded + # check query + if not torch.allclose(query_ref, query[:,:,:,:query_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at query index {i} ref:\n{query_ref} \ncur:\n {query} ") + assert(False) + # check key + if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_chatglm() \ No newline at end of file From 0590c6d81b618685686b1875fefe2dd9ae4018da Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 12/54] rotary embbeding supports contiguous pastkv --- include/llm_emb_gpt.hpp | 15 +++-- src/emb_gpt_avx512.cpp | 53 +++++++++++---- tests/script/ext/attn_gpt.cpp | 47 +++++++------ tests/script/ext/emb_gpt.cpp | 78 +++++++++++++--------- tests/script/test_rotary_pastkv_chatglm.py | 33 +++++++-- 5 files changed, 148 insertions(+), 78 deletions(-) diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp index 2d3c09c..7bfd183 100644 --- a/include/llm_emb_gpt.hpp +++ b/include/llm_emb_gpt.hpp @@ -23,17 +23,20 @@ class emb_gpt { data_type_t dst_precision; size_t rotary_emb_base; float rotary_pct; - bool use_position2d; + bool use_position2d; // chatglm true, other false }; struct exec_param { size_t batch; size_t query_seq_len; size_t past_seq_len; - uint8_t* qkv; - uint8_t* query_dst; - uint8_t** layer_past_key_padded; - uint8_t** layer_past_value_padded; - int* position2d_ids; // shape: [batch, 2, query_seq_len] + uint8_t* qkv; // shape: [batch, query_seq_len, 3 * hidden size] + uint8_t* query_dst; // rotary embbeding dst + uint8_t** layer_past_key_src; // past key src + uint8_t** layer_past_value_src; // past value src + uint8_t** layer_past_key_dst; // past key dst, if layer_past_key_src!=layer_past_key_dst, will copy layer_past_key_src to layer_past_key_dst + uint8_t** layer_past_value_dst; // past value dst, if layer_past_value!=layer_past_value_dst, will copy layer_past_value to layer_past_value_dst + int* position2d_ids; // shape: [batch, 2, query_seq_len] + size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer }; emb_gpt(); diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index 5d48e16..c957d93 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -23,10 +23,12 @@ struct emb_gpt_impl_avx512 : public emb_gpt::impl { void exec(const emb_gpt::exec_param& param) override; void initRotery(size_t max_seq_len); + void memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, + size_t batch, size_t past_seq_len, size_t head_stride_in_kv); void applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len); + size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv); void applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids); + size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv); emb_gpt::create_param _create_param; size_t _head_num = 32; @@ -111,12 +113,33 @@ void emb_gpt_impl_avx512::initRotery(size_t max_seq_len) { } } +void emb_gpt_impl_avx512::memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, + size_t batch, size_t past_seq_len, size_t head_stride_in_kv) { + parallel_for3d(batch, _head_num, past_seq_len, [&](size_t b, size_t h, size_t s) { + auto k_dst_batch = pastk_dst[b]; + auto v_dst_batch = pastv_dst[b]; + auto k_src_batch = pastk_src[b]; + auto v_src_batch = pastv_src[b]; + auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto k_src_seq = k_src_batch + s * _size_per_head_aligned * _output_type_size; + auto v_src_seq = v_src_batch + s * _size_per_head_aligned * _output_type_size; + auto* k_src_f = k_src_seq + h * past_seq_len * _size_per_head_aligned * _output_type_size; + auto* k_dst_f = k_dst_seq + h * head_stride_in_kv * _output_type_size; + auto* v_src_f = v_src_seq + h * past_seq_len * _size_per_head_aligned * _output_type_size; + auto* v_dst_f = v_dst_seq + h * head_stride_in_kv * _output_type_size; + + memcpy(k_dst_f, k_src_f, _output_type_size * _size_per_head); + memcpy(v_dst_f, v_src_f, _output_type_size * _size_per_head); + }); +} + // q_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] // kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len) { + size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv) { auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; auto* cos_cached = _cos_cached.get() + past_seq_len * _rotary_ndims; auto* sin_cached = _sin_cached.get() + past_seq_len * _rotary_ndims; @@ -137,14 +160,14 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src auto* q_src_f = reinterpret_cast(q_src_seq + h * _size_per_head * 3 * _input_type_size); auto* k_src_f = reinterpret_cast(k_src_seq + h * _size_per_head * 3 * _input_type_size); auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); - auto* k_dst_f = reinterpret_cast(k_dst_seq + h * _max_seq_len * _size_per_head_aligned * _output_type_size); + auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); rotary_avx512(_rotary_ndims, cos_cached + s * _rotary_ndims, sin_cached + s * _rotary_ndims, q_src_f, k_src_f, q_dst_f, k_dst_f); // q, k concat memcpy(reinterpret_cast(q_dst_f) + _rotary_ndims * _output_type_size, reinterpret_cast(q_src_f) + _rotary_ndims * _input_type_size, _output_type_size * (_size_per_head - _rotary_ndims)); memcpy(reinterpret_cast(k_dst_f) + _rotary_ndims * _output_type_size, reinterpret_cast(k_src_f) + _rotary_ndims * _input_type_size, _output_type_size * (_size_per_head - _rotary_ndims)); // v concat - memcpy(static_cast(v_dst_seq) + h * _max_seq_len * _size_per_head_aligned * _output_type_size, + memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, static_cast(v_src_seq) + h * _size_per_head * 3 * _input_type_size, _size_per_head * _output_type_size); }); @@ -156,7 +179,7 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] // position2d_ids: [batch, 2, q_seq_len] void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids) { + size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv) { auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; auto* cos_cached = _cos_cached.get(); auto* sin_cached = _sin_cached.get(); @@ -179,7 +202,7 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, auto* q_src_f = reinterpret_cast(q_src_seq + h * _size_per_head * 3 * _input_type_size); auto* k_src_f = reinterpret_cast(k_src_seq + h * _size_per_head * 3 * _input_type_size); auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); - auto* k_dst_f = reinterpret_cast(k_dst_seq + h * _max_seq_len * _size_per_head_aligned * _output_type_size); + auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); rotary_avx512(_rotary_ndims, cos_cached + pos_batch[s] * _rotary_ndims, sin_cached + pos_batch[s] * _rotary_ndims, q_src_f, k_src_f, q_dst_f, k_dst_f); rotary_avx512(_rotary_ndims, cos_cached + block_batch[s] * _rotary_ndims, sin_cached + block_batch[s] * _rotary_ndims, q_src_f + _rotary_ndims, @@ -188,7 +211,7 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, k_dst_f + _rotary_ndims); // v concat - memcpy(static_cast(v_dst_seq) + h * _max_seq_len * _size_per_head_aligned * _output_type_size, + memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, static_cast(v_src_seq) + h * _size_per_head * 3 * _input_type_size, _size_per_head * _output_type_size); }); @@ -202,11 +225,17 @@ void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { auto key = qkv + _size_per_head * _input_type_size; // qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) auto value = qkv + 2 * _size_per_head * _input_type_size; // qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) auto query_dst = param.query_dst; - auto key_dst = param.layer_past_key_padded; - auto value_dst = param.layer_past_value_padded; + auto key_dst = param.layer_past_key_dst; + auto value_dst = param.layer_past_value_dst; auto batch = param.batch; auto query_seq_len = param.query_seq_len; auto past_seq_len = param.past_seq_len; + auto head_stride_in_kv = param.head_stride_in_kv; + + // past kv src != dst, copy src to dst first + if (param.layer_past_key_src[0] != param.layer_past_key_dst[0] && past_seq_len) + memcpyPastKV(param.layer_past_key_src, param.layer_past_value_src, param.layer_past_key_dst, param.layer_past_value_dst, batch, past_seq_len, head_stride_in_kv); + // transpose + rotary embbeding: // transpose: [batch, seq_len, num_attention_heads, 3 * head_size] --> // 3 [batch, num_attention_heads, seq_len, head_size] @@ -225,9 +254,9 @@ void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] if (_use_position2d) { - applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, param.position2d_ids); + applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, param.position2d_ids, head_stride_in_kv); } else { - applyRotaryPosEmbMemcpy(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len); + applyRotaryPosEmbMemcpy(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, head_stride_in_kv); } } } diff --git a/tests/script/ext/attn_gpt.cpp b/tests/script/ext/attn_gpt.cpp index 4b4ea8c..5cc4e46 100644 --- a/tests/script/ext/attn_gpt.cpp +++ b/tests/script/ext/attn_gpt.cpp @@ -35,8 +35,8 @@ class attn_gpt { size_t past_seq_len; bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. uint8_t* qkv; - uint8_t** layer_past_key_padded; - uint8_t** layer_past_value_padded; + uint8_t** layer_past_key_dst; + uint8_t** layer_past_value_dst; int* position2d_ids; // shape: [batch, 2, query_seq_len] float* attention_mask; // attention mask, attention_mask[0] shape: // [batch, 1, 1, key_seq_len], when is_causal_in_attention is false @@ -106,9 +106,12 @@ void attn_gpt::exec(const attn_gpt::exec_param& param) { emb_param.past_seq_len = param.past_seq_len; emb_param.qkv = param.qkv; emb_param.query_dst = _query_dst.get(); - emb_param.layer_past_key_padded = param.layer_past_key_padded; - emb_param.layer_past_value_padded = param.layer_past_value_padded; + emb_param.layer_past_key_src = param.layer_past_key_dst; + emb_param.layer_past_value_src = param.layer_past_value_dst; + emb_param.layer_past_key_dst = param.layer_past_key_dst; + emb_param.layer_past_value_dst = param.layer_past_value_dst; emb_param.position2d_ids = param.position2d_ids; + emb_param.head_stride_in_kv = param.head_stride_in_kv; _emb_gpt->exec(emb_param); llmdnn::mha_gpt::exec_param mha_param; @@ -120,8 +123,8 @@ void attn_gpt::exec(const attn_gpt::exec_param& param) { mha_param.head_stride_in_kv = param.head_stride_in_kv; mha_param.is_causal_in_attention = param.is_causal_in_attention; mha_param.attention_mask = param.attention_mask; - mha_param.k = emb_param.layer_past_key_padded; - mha_param.v = emb_param.layer_past_value_padded; + mha_param.k = emb_param.layer_past_key_dst; + mha_param.v = emb_param.layer_past_value_dst; _mha_gpt->exec(mha_param); } @@ -173,35 +176,35 @@ void regclass_attn_gpt(pybind11::module m) { :param num_heads: heads number. :type num_heads: int )"); - cls.def("exec_position", [] (attn_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_padded, - const torch::Tensor& layer_past_value_padded, int64_t past_seq_len, const torch::Tensor& attn_mask, const torch::Tensor& position2d_ids) { + cls.def("exec_position", [] (attn_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_dst, + const torch::Tensor& layer_past_value_dst, int64_t past_seq_len, const torch::Tensor& attn_mask, const torch::Tensor& position2d_ids) { // qkv: [batch, seq_len, (num_heads * 3 * head_size)] // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] // past_seq_len: past_seq_len==layer_past.shape[-2] // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] - AT_ASSERT(qkv.dim() == 3 && layer_past_key_padded.dim() == 4 && layer_past_value_padded.dim() == 4 && attn_mask.dim() == 4 && - qkv.size(0) == layer_past_key_padded.size(0) && - layer_past_key_padded.dim() == layer_past_value_padded.dim()); + AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && attn_mask.dim() == 4 && + qkv.size(0) == layer_past_key_dst.size(0) && + layer_past_key_dst.dim() == layer_past_value_dst.dim()); auto batch = qkv.size(0); - auto num_heads = layer_past_key_padded.size(1); + auto num_heads = layer_past_key_dst.size(1); auto query_seq_len = qkv.size(1); auto head_size = qkv.size(2) / 3 / num_heads; - auto head_size_aligned = layer_past_key_padded.size(3); - auto max_seq_len = layer_past_key_padded.size(2); - AT_ASSERT(past_seq_len <= layer_past_key_padded.size(2) && head_size <= layer_past_key_padded.size(3) && - query_seq_len <= layer_past_key_padded.size(2)); + auto head_size_aligned = layer_past_key_dst.size(3); + auto max_seq_len = layer_past_key_dst.size(2); + AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && + query_seq_len <= layer_past_key_dst.size(2)); attn_gpt::exec_param param; param.batch = batch; param.query_seq_len = query_seq_len; param.past_seq_len = past_seq_len; param.qkv = reinterpret_cast(qkv.data_ptr()); - param.layer_past_key_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_value_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_key_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); for (int i = 0; i < batch; i++) { - param.layer_past_key_padded[i] = reinterpret_cast(layer_past_key_padded[i].data_ptr()); - param.layer_past_value_padded[i] = reinterpret_cast(layer_past_value_padded[i].data_ptr()); + param.layer_past_key_dst[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); + param.layer_past_value_dst[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); } param.position2d_ids = reinterpret_cast(position2d_ids.data_ptr()); @@ -218,8 +221,8 @@ void regclass_attn_gpt(pybind11::module m) { return out; }, py::arg("qkv"), - py::arg("layer_past_key_padded"), - py::arg("layer_past_value_padded"), + py::arg("layer_past_key_dst"), + py::arg("layer_past_value_dst"), py::arg("past_seq_len"), py::arg("attn_mask"), py::arg("position2d_ids"), diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp index 66ec8c1..1010f8c 100644 --- a/tests/script/ext/emb_gpt.cpp +++ b/tests/script/ext/emb_gpt.cpp @@ -59,25 +59,27 @@ void regclass_emb_gpt(pybind11::module m) { :type num_heads: int )"); // torch::List - cls.def("exec", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_padded, - const torch::Tensor& layer_past_value_padded, const torch::Tensor& query_padded, int64_t past_seq_len) { + cls.def("exec", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_dst, + const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len) { // qkv: [batch, seq_len, (num_heads * 3 * head_size)] // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] // past_seq_len: past_seq_len==layer_past.shape[-2] // query_padded: [batch, num_heads, query_seq_len, head_size_aligned] // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] - AT_ASSERT(qkv.dim() == 3 && layer_past_key_padded.dim() == 4 && layer_past_value_padded.dim() == 4 && - qkv.size(0) == layer_past_key_padded.size(0) && - layer_past_key_padded.dim() == layer_past_value_padded.dim()); + AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && + qkv.size(0) == layer_past_key_dst.size(0) && + layer_past_key_dst.dim() == layer_past_value_dst.dim()); AT_ASSERT(query_padded.dim() == 4 && query_padded.size(0) == qkv.size(0) && - query_padded.size(1) == layer_past_key_padded.size(1) && query_padded.size(2) == qkv.size(1) && - query_padded.size(3) == layer_past_key_padded.size(3)); + query_padded.size(1) == layer_past_key_dst.size(1) && query_padded.size(2) == qkv.size(1) && + query_padded.size(3) == layer_past_key_dst.size(3)); auto batch = qkv.size(0); - auto num_heads = layer_past_key_padded.size(1); + auto num_heads = layer_past_key_dst.size(1); auto query_seq_len = qkv.size(1); auto head_size = qkv.size(2) / 3 / num_heads; - AT_ASSERT(past_seq_len <= layer_past_key_padded.size(2) && head_size <= layer_past_key_padded.size(3) && - query_seq_len <= layer_past_key_padded.size(2)); + auto head_size_aligned = layer_past_key_dst.size(3); + auto max_seq_len = layer_past_key_dst.size(2); + AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && + query_seq_len <= layer_past_key_dst.size(2)); llmdnn::emb_gpt::exec_param param; param.batch = batch; @@ -85,11 +87,14 @@ void regclass_emb_gpt(pybind11::module m) { param.past_seq_len = past_seq_len; param.qkv = reinterpret_cast(qkv.data_ptr()); param.query_dst = reinterpret_cast(query_padded.data_ptr()); - param.layer_past_key_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_value_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_key_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_key_dst = param.layer_past_key_src; + param.layer_past_value_dst = param.layer_past_value_src; + param.head_stride_in_kv = max_seq_len * head_size_aligned; for (int i = 0; i < batch; i++) { - param.layer_past_key_padded[i] = reinterpret_cast(layer_past_key_padded[i].data_ptr()); - param.layer_past_value_padded[i] = reinterpret_cast(layer_past_value_padded[i].data_ptr()); + param.layer_past_key_src[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); + param.layer_past_value_src[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); } self.exec(param); @@ -98,8 +103,8 @@ void regclass_emb_gpt(pybind11::module m) { // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); }, py::arg("qkv"), - py::arg("layer_past_key_padded"), - py::arg("layer_past_value_padded"), + py::arg("layer_past_key_dst"), + py::arg("layer_past_value_dst"), py::arg("query_padded"), py::arg("past_seq_len"), R"( @@ -108,25 +113,27 @@ void regclass_emb_gpt(pybind11::module m) { :param num_heads: heads number. :type num_heads: int )"); - cls.def("exec_position", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_padded, - const torch::Tensor& layer_past_value_padded, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor position2d_ids) { + cls.def("exec_position", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_src, const torch::Tensor& layer_past_value_src, + const torch::Tensor& layer_past_key_dst, const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor position2d_ids) { // qkv: [batch, seq_len, (num_heads * 3 * head_size)] // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] // past_seq_len: past_seq_len==layer_past.shape[-2] // query_padded: [batch, num_heads, query_seq_len, head_size_aligned] // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] - AT_ASSERT(qkv.dim() == 3 && layer_past_key_padded.dim() == 4 && layer_past_value_padded.dim() == 4 && - qkv.size(0) == layer_past_key_padded.size(0) && - layer_past_key_padded.dim() == layer_past_value_padded.dim()); + AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && + qkv.size(0) == layer_past_key_dst.size(0) && + layer_past_key_dst.dim() == layer_past_value_dst.dim()); AT_ASSERT(query_padded.dim() == 4 && query_padded.size(0) == qkv.size(0) && - query_padded.size(1) == layer_past_key_padded.size(1) && query_padded.size(2) == qkv.size(1) && - query_padded.size(3) == layer_past_key_padded.size(3)); + query_padded.size(1) == layer_past_key_dst.size(1) && query_padded.size(2) == qkv.size(1) && + query_padded.size(3) == layer_past_key_dst.size(3)); auto batch = qkv.size(0); - auto num_heads = layer_past_key_padded.size(1); + auto num_heads = layer_past_key_dst.size(1); auto query_seq_len = qkv.size(1); auto head_size = qkv.size(2) / 3 / num_heads; - AT_ASSERT(past_seq_len <= layer_past_key_padded.size(2) && head_size <= layer_past_key_padded.size(3) && - query_seq_len <= layer_past_key_padded.size(2)); + auto head_size_aligned = layer_past_key_dst.size(3); + auto max_seq_len = layer_past_key_dst.size(2); + AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && + query_seq_len <= layer_past_key_dst.size(2)); llmdnn::emb_gpt::exec_param param; param.batch = batch; @@ -134,11 +141,16 @@ void regclass_emb_gpt(pybind11::module m) { param.past_seq_len = past_seq_len; param.qkv = reinterpret_cast(qkv.data_ptr()); param.query_dst = reinterpret_cast(query_padded.data_ptr()); - param.layer_past_key_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_value_padded = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_key_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_key_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.head_stride_in_kv = max_seq_len * head_size_aligned; for (int i = 0; i < batch; i++) { - param.layer_past_key_padded[i] = reinterpret_cast(layer_past_key_padded[i].data_ptr()); - param.layer_past_value_padded[i] = reinterpret_cast(layer_past_value_padded[i].data_ptr()); + param.layer_past_key_src[i] = past_seq_len == 0 ? nullptr : reinterpret_cast(layer_past_key_src[i].data_ptr()); + param.layer_past_value_src[i] = past_seq_len == 0 ? nullptr : reinterpret_cast(layer_past_value_src[i].data_ptr()); + param.layer_past_key_dst[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); + param.layer_past_value_dst[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); } param.position2d_ids = reinterpret_cast(position2d_ids.data_ptr()); @@ -148,8 +160,10 @@ void regclass_emb_gpt(pybind11::module m) { // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); }, py::arg("qkv"), - py::arg("layer_past_key_padded"), - py::arg("layer_past_value_padded"), + py::arg("layer_past_key_src"), + py::arg("layer_past_value_src"), + py::arg("layer_past_key_dst"), + py::arg("layer_past_value_dst"), py::arg("query_padded"), py::arg("past_seq_len"), py::arg("position2d_ids"), diff --git a/tests/script/test_rotary_pastkv_chatglm.py b/tests/script/test_rotary_pastkv_chatglm.py index 278bfa6..73ea72d 100644 --- a/tests/script/test_rotary_pastkv_chatglm.py +++ b/tests/script/test_rotary_pastkv_chatglm.py @@ -268,8 +268,8 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, position_ids): - self.emd.exec_position(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, position_ids) + def forward(self, qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids): + self.emd.exec_position(qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids) HEAD_NUM = 32 @@ -298,6 +298,7 @@ def test_chatglm(): ] ref_net = get_ref_model() net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + net_seq = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) with torch.cpu.amp.autocast(): for (i, input) in enumerate(inputs): qkv, layer_past_key, layer_past_value = input @@ -318,7 +319,29 @@ def test_chatglm(): seq_ids[:, 0, :] = seq_batch1 seq_ids[:, 1, :] = block_batch1 query_ref, key_ref, value_ref = ref_net.forward(qkv, seq_ids, (layer_past_key, layer_past_value)) + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + + # no prealloc past kv + layer_past_key_seq = torch.zeros(key_ref.shape, dtype=torch.bfloat16) + layer_past_value_seq = torch.zeros(value_ref.shape, dtype=torch.bfloat16) + query_seq = torch.zeros(query_ref.shape, dtype=torch.bfloat16) + net_seq.forward(qkv, layer_past_key, layer_past_value, layer_past_key_seq, layer_past_value_seq, query_seq, past_seq_len, seq_ids) + query, key, value = query_seq, layer_past_key_seq, layer_past_value_seq + # check query + if not torch.allclose(query_ref, query, rtol=0.001, atol=0.01): + print(f"error at sequence query index {i} ref:\n{query_ref} \ncur:\n {query} ") + #assert(False) + # check key + if not torch.allclose(key_ref, key, rtol=0.001, atol=0.01): + print(f"error at sequence key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value, rtol=0.001, atol=0.01): + print(f"error at sequence value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + # prealloc past kv shape[-2] = MAX_SEQ_LEN shape[-1] = SIZE_PER_HEAD_ALIGN layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) @@ -328,10 +351,8 @@ def test_chatglm(): query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value - query_ref = query_ref.to(dtype=torch.bfloat16) - key_ref = key_ref.to(dtype=torch.bfloat16) - net.forward(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, seq_ids) - output, query, key, value = (layer_past_key_padded, layer_past_value_padded), query_padded, layer_past_key_padded, layer_past_value_padded + net.forward(qkv, layer_past_key_padded, layer_past_value_padded, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, seq_ids) + query, key, value = query_padded, layer_past_key_padded, layer_past_value_padded # check query if not torch.allclose(query_ref, query[:,:,:,:query_ref.shape[-1]], rtol=0.001, atol=0.01): print(f"error at query index {i} ref:\n{query_ref} \ncur:\n {query} ") From 67aabdcdcbb80146beb1a8dd632c9d63f92a8b47 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 13/54] fc support int8 weight compression --- include/llm_fc.hpp | 5 +++-- src/fc_kernel_amx.cpp | 12 ++---------- src/fc_kernel_amx.hpp | 1 - src/fc_kernel_api.cpp | 6 ------ 4 files changed, 5 insertions(+), 19 deletions(-) diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp index d503d8a..ac77fd4 100644 --- a/include/llm_fc.hpp +++ b/include/llm_fc.hpp @@ -32,6 +32,9 @@ struct fc_create_param { data_type_t dt_c; bool b_is_trans; postops_types postops_type; + // for weight compression + float q; + float dq; }; struct fc_kernel; @@ -58,7 +61,5 @@ void fc_kernel_execute(const fc_kernel* mm, /// weight compression /// compute weight min/max once, set q, dq for each fc_kernel instance void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); -/// set q, dq for each fc_kernel instance, must call before first fc_kernel_execute -void fc_kernel_bf16w8_set_q_dq(const fc_kernel* mm, float q, float dq); } diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index ba5c41f..8f6d6eb 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -90,6 +90,8 @@ bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { m->bf16xbf16 = std::make_shared>(true, param->b_is_trans); } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_s8) { m->bf16xi8 = std::make_shared>(true, param->b_is_trans); + m->bf16xi8->quant_scale_B = param->q; + m->bf16xi8->dequant_scale_B = param->dq; } else { std::cout << "fc_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; goto ERR; @@ -303,14 +305,4 @@ void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, *dq = max / 127; } -/// set q, dq for each fc_kernel instance -void fc_kernel_bf16w8_set_q_dq_amx(const fc_kernel* mm, float q, float dq) { - if (!mm || !mm->bf16xi8) { - std::cout << "fc_kernel_bf16w8_set_q_dq: created kernel is not int8 weight.\n"; - return; - } - mm->bf16xi8->quant_scale_B = q; - mm->bf16xi8->dequant_scale_B = dq; -} - } \ No newline at end of file diff --git a/src/fc_kernel_amx.hpp b/src/fc_kernel_amx.hpp index 5c5467f..93bf47f 100644 --- a/src/fc_kernel_amx.hpp +++ b/src/fc_kernel_amx.hpp @@ -14,6 +14,5 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias); void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); -void fc_kernel_bf16w8_set_q_dq_amx(const fc_kernel* mm, float q, float dq); } \ No newline at end of file diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp index a063f7e..d7b3e9c 100644 --- a/src/fc_kernel_api.cpp +++ b/src/fc_kernel_api.cpp @@ -25,7 +25,6 @@ static decltype(&fc_kernel_create) fc_kernel_create_ptr = fc_kernel_create_amx; static decltype(&fc_kernel_destroy) fc_kernel_destroy_ptr = fc_kernel_destroy_amx; static decltype(&fc_kernel_execute) fc_kernel_execute_ptr = fc_kernel_execute_amx; static decltype(&fc_kernel_bf16w8_get_q_dq) fc_kernel_bf16w8_get_q_dq_ptr = fc_kernel_bf16w8_get_q_dq_amx; -static decltype(&fc_kernel_bf16w8_set_q_dq) fc_kernel_bf16w8_set_q_dq_ptr = fc_kernel_bf16w8_set_q_dq_amx; // interface bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { @@ -45,9 +44,4 @@ void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, flo fc_kernel_bf16w8_get_q_dq_ptr(K, N, stride, ptr, q, dq); } -/// set q, dq for each fc_kernel instance -void fc_kernel_bf16w8_set_q_dq(const fc_kernel* mm, float q, float dq) { - fc_kernel_bf16w8_set_q_dq_ptr(mm, q, dq); -} - } \ No newline at end of file From f7723dc3d10a242d13f98ab5e2f955c09e62f452 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 14/54] support K<32(pad zero to 32) --- src/common/tensor2d.hpp | 14 ++++++++ src/mm_kernel_common_amx.hpp | 15 +++++++- tests/script/models/chatglm-6b/.gitattributes | 34 ------------------- tests/src/test_mm_kernel_amx.cpp | 2 ++ 4 files changed, 30 insertions(+), 35 deletions(-) delete mode 100644 tests/script/models/chatglm-6b/.gitattributes diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index 0c4a8f8..971d1b0 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -72,6 +72,20 @@ struct tensor2D { } return ret; } + tensor2D clone_with_padzero(int dim0, int dim1) { + tensor2D ret; + ret.resize(dim0, dim1, force_compact); + assert(dim0 >= dims[0] && dim1 >= dims[1]); + for(int i = 0; i < dims[0]; i++) { + memcpy(&ret(i, 0), &(*this)(i, 0), dims[1] * sizeof(T)); + memset(reinterpret_cast(&ret(i, 0) + dims[1]), 0, ret.stride - dims[1] * sizeof(T)); + } + if (dims[1] == dim1) { + memset(reinterpret_cast(ret.data + dims[0] * ret.padded_dim1), 0, (dim0 - dims[0]) * ret.stride); + } + + return ret; + } void resize(int d0, int d1, bool _force_compact = false, bool is_const=false) { own = true; force_compact = _force_compact; diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index a8968e8..3291a03 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -1990,9 +1990,22 @@ struct Matmul { int n0, int n1, PP ppkernel, bool skip_repack = false) { - auto matB = getSubMatB(_matB, n0, n1, transposeB); int M = matA.dims[0]; int K = matA.dims[1]; + if (K < kStep) { + int B0, B1; + if (transposeB) { + B0 = _matB.dims[0]; + B1 = kStep; + } else { + B0 = kStep; + B1 = _matB.dims[1]; + } + matA = matA.clone_with_padzero(M, kStep); + _matB = _matB.clone_with_padzero(B0, B1); + K = kStep; + } + auto matB = getSubMatB(_matB, n0, n1, transposeB); int N = matB.dims[transposeB ? 0 : 1]; assert(K == matB.dims[transposeB ? 1 : 0]); // Due to the fact that we load a full tile at tails of K dimension diff --git a/tests/script/models/chatglm-6b/.gitattributes b/tests/script/models/chatglm-6b/.gitattributes deleted file mode 100644 index c7d9f33..0000000 --- a/tests/script/models/chatglm-6b/.gitattributes +++ /dev/null @@ -1,34 +0,0 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/tests/src/test_mm_kernel_amx.cpp b/tests/src/test_mm_kernel_amx.cpp index 61fdeb6..ce0fcb2 100644 --- a/tests/src/test_mm_kernel_amx.cpp +++ b/tests/src/test_mm_kernel_amx.cpp @@ -116,6 +116,8 @@ const std::vector> types = { const std::vector shapes = { // normal {256, 48, 448}, + // k < 32 + {256, 48, 15}, // k tail {256, 48, 449}, // M tail == unroll 8 From ae5fc0cf3c947b8b715d4713db73608d5d87988b Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 29 Jun 2023 01:51:25 +0800 Subject: [PATCH 15/54] import && rename --- CMakeLists.txt | 10 +++++----- README.md | 9 +++++++++ src/CMakeLists.txt | 10 ++++++---- tests/CMakeLists.txt | 8 ++++---- tests/script/build.sh | 4 ++-- tests/script/ext/CMakeLists.txt | 16 ++++++++-------- tests/script/ext/setup.py | 8 ++++---- 7 files changed, 38 insertions(+), 27 deletions(-) create mode 100644 README.md diff --git a/CMakeLists.txt b/CMakeLists.txt index 45b4f49..d6d772c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,10 +7,10 @@ cmake_minimum_required(VERSION 3.13) project(root) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -option(LLMDNN_BUILD_TESTS "Build with tests" ON) +option(CPU_EXTENSIONS_BUILD_TESTS "Build with tests" ON) message(INFO "--------------------------------") -message(STATUS "Build with tests: ${LLMDNN_BUILD_TESTS}") +message(STATUS "Build with tests: ${CPU_EXTENSIONS_BUILD_TESTS}") message(INFO "--------------------------------") set(CMAKE_CXX_STANDARD 17) @@ -26,8 +26,8 @@ if(MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") endif() elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) - if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.3") - message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 11.3.") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.2") + message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 11.2.") endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids -flax-vector-conversions") elseif(OV_COMPILER_IS_CLANG) @@ -51,6 +51,6 @@ if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) endif() add_subdirectory(src) -if (LLMDNN_BUILD_TESTS) +if (CPU_EXTENSIONS_BUILD_TESTS) add_subdirectory(tests) endif() diff --git a/README.md b/README.md new file mode 100644 index 0000000..0129645 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# About CPU_Extensions +CPU_Extensions is a compute library containing processor optimized kernels code. + +# Unit tests for CPU_Extensions +## Tests for kernels +Tests for kernels are written in gtest under tests\src, use ./cpu_extensions_tests to run it. + +## Tests for complex features +Some features have many steps and the reference could not be easily written using gtest. For these features can use python to generate the reference. The directory tests\script contains these test, please refer [test in python](./tests/script/README.md). \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3b970cf..9c24750 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,10 +3,12 @@ # cmake_minimum_required(VERSION 3.13) -project(llmdnn) +project(cpu_extensions) file(GLOB_RECURSE SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) -add_library(llmdnn STATIC ${SOURCE_FILES}) -target_include_directories(llmdnn PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PUBLIC $) +add_library(cpu_extensions STATIC ${SOURCE_FILES}) +target_include_directories(cpu_extensions PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + PUBLIC $) +set_target_properties(cpu_extensions PROPERTIES + POSITION_INDEPENDENT_CODE ON) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ceb0369..690e748 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,7 +3,7 @@ # cmake_minimum_required(VERSION 3.13) -project(llmdnn_tests) +project(cpu_extensions_tests) enable_testing() @@ -20,11 +20,11 @@ endif() set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -include_directories(${LLMDNN_HEADERS_DIR}) +include_directories(${CPU_EXTENSIONS_HEADERS_DIR}) file(GLOB_RECURSE TEST_SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) find_package(OpenMP REQUIRED) -add_executable(llmdnn_tests ${TEST_SOURCE_FILES}) -target_link_libraries(llmdnn_tests llmdnn gtest_main stdc++ OpenMP::OpenMP_CXX) +add_executable(cpu_extensions_tests ${TEST_SOURCE_FILES}) +target_link_libraries(cpu_extensions_tests cpu_extensions gtest_main stdc++ OpenMP::OpenMP_CXX) diff --git a/tests/script/build.sh b/tests/script/build.sh index ab29766..b6bc751 100755 --- a/tests/script/build.sh +++ b/tests/script/build.sh @@ -1,8 +1,8 @@ #!/bin/bash pip uninstall -y llmdnn -cd ../../../../../../../build/ || exit -make -j 20 llmdnn +cd ../../build/ || exit +make -j 20 cd - || exit cd ext || exit python setup.py clean --all install diff --git a/tests/script/ext/CMakeLists.txt b/tests/script/ext/CMakeLists.txt index d2eed61..41a0951 100644 --- a/tests/script/ext/CMakeLists.txt +++ b/tests/script/ext/CMakeLists.txt @@ -4,7 +4,7 @@ cmake_minimum_required (VERSION 3.13) -project(llmdnn_ext LANGUAGES CXX) +project(cpu_extensions_ext LANGUAGES CXX) # ref:https://stackoverflow.com/questions/68401650/how-can-i-make-a-pytorch-extension-with-cmake execute_process(COMMAND python3 -c "import torch;print(torch.utils.cmake_prefix_path+'/Torch/',end='')" OUTPUT_VARIABLE TORCH_CMAKE_PATH) @@ -14,16 +14,16 @@ find_package(Python REQUIRED COMPONENTS Development) find_package(Torch REQUIRED) find_package(OpenMP REQUIRED) -add_library(llmdnn_ext SHARED +add_library(cpu_extensions_ext SHARED mha_gpt.cpp module.cpp ../../src/test_common.cpp ) -set_target_properties(llmdnn_ext PROPERTIES - OUTPUT_NAME "llmdnn_ext" +set_target_properties(cpu_extensions_ext PROPERTIES + OUTPUT_NAME "cpu_extensions_ext" POSITION_INDEPENDENT_CODE ON LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../) -install(TARGETS llmdnn_ext DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) -target_compile_features(llmdnn_ext PRIVATE cxx_std_14) -target_include_directories(llmdnn_ext PRIVATE ../../src) -target_link_libraries(llmdnn_ext PRIVATE ${TORCH_LIBRARIES} Python::Python llmdnn stdc++ OpenMP::OpenMP_CXX) +install(TARGETS cpu_extensions_ext DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) +target_compile_features(cpu_extensions_ext PRIVATE cxx_std_14) +target_include_directories(cpu_extensions_ext PRIVATE ../../src) +target_link_libraries(cpu_extensions_ext PRIVATE ${TORCH_LIBRARIES} Python::Python cpu_extensions stdc++ OpenMP::OpenMP_CXX) diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py index dadcf5d..2780c94 100644 --- a/tests/script/ext/setup.py +++ b/tests/script/ext/setup.py @@ -18,9 +18,9 @@ debug = True if os.environ['DEBUG_EXT'] == '1' else False extra_args = ['-fopenmp', '-march=native'] -llmdnn_lib_dir = f'{os.getcwd()}/../../../../../../../../bin/intel64/Release' +cpu_extensions_lib_dir = f'{os.getcwd()}/../../../build/lib' if debug: - llmdnn_lib_dir = f'{os.getcwd()}/../../../../../../../../bin/intel64/Debug' + cpu_extensions_lib_dir = f'{os.getcwd()}/../../../build/lib' extra_args += ['-g', '-O0'] print('install debug version') else: @@ -40,8 +40,8 @@ f'{os.getcwd()}/../../../include', f'{os.getcwd()}/../../../src'], library_dirs=[f'{sys.prefix}/lib', - llmdnn_lib_dir], - libraries=['llmdnn', + cpu_extensions_lib_dir], + libraries=['cpu_extensions', 'stdc++']), ], cmdclass={'build_ext': cpp_extension.BuildExtension} From 9725e595a40e0651120bf67e1144392fd958468b Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 30 Jun 2023 17:21:42 +0800 Subject: [PATCH 16/54] workaround MatmulVector K<=6*32 --- .gitignore | 7 ++++++- src/mha_gpt_amx.cpp | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index d1f16b8..5e92026 100644 --- a/.gitignore +++ b/.gitignore @@ -44,8 +44,13 @@ ## Local build/**/* +**/build/**/* out/* lib/* bin/* test/test_runner -.vs \ No newline at end of file +.vs +.cache +__pycache__ +dist +*.egg-info \ No newline at end of file diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 6047290..fd9971c 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -118,7 +118,7 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { auto& gemAvB_ops = gemAvB_BF16xBF16; auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; auto& qKVGemm_ops = qKVGemm_BF16xBF16; - bool is_vector = param.query_seq_len == 1; + bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 32 && _create_param.head_size <= 32 * 6; size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; size_t head_stride_in_attn = _create_param.head_size; @@ -246,7 +246,7 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { auto& gemAvB_ops = gemAvB_i8xi8; auto& qKtrGemm_ops = qKtrGemm_i8xi8; auto& qKVGemm_ops = qKVGemm_u8xi8; - bool is_vector = param.query_seq_len == 1; + bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 64 && _create_param.head_size <= 64 * 6; // dequant param auto mul_scales = _create_param.normal_factor * param.q_dequant * param.k_dequant; // prepare for per channel From aac4116289a14feb93ca5b2e74b1d02960df1b71 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 30 Jun 2023 21:17:43 +0800 Subject: [PATCH 17/54] export cmake target --- CMakeLists.txt | 3 --- src/CMakeLists.txt | 40 +++++++++++++++++++++++++++++++++++----- tests/CMakeLists.txt | 13 ++++++++----- 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d6d772c..f99a3e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,9 +42,6 @@ elseif(CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") endif() -include_directories(${PROJECT_SOURCE_DIR}/include) -include_directories(${PROJECT_SOURCE_DIR}/src/) - if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9c24750..2369c0b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,10 +5,40 @@ cmake_minimum_required(VERSION 3.13) project(cpu_extensions) -file(GLOB_RECURSE SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) +file(GLOB_RECURSE ${PROJECT_NAME}_SOURCE_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) -add_library(cpu_extensions STATIC ${SOURCE_FILES}) -target_include_directories(cpu_extensions PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PUBLIC $) -set_target_properties(cpu_extensions PROPERTIES +add_library(${PROJECT_NAME} STATIC ${${PROJECT_NAME}_SOURCE_FILES}) +set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + PUBLIC $ + $/${CMAKE_INSTALL_INCLUDEDIR}>) + +set(CMAKE_DST ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) +# header files +include(GNUInstallDirs) +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../include/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +# library file +install(TARGETS ${PROJECT_NAME} + EXPORT ${PROJECT_NAME}Targets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include) + +# config file +install(EXPORT ${PROJECT_NAME}Targets + FILE ${PROJECT_NAME}Config.cmake + NAMESPACE ${PROJECT_NAME}:: + DESTINATION ${CMAKE_DST}) + +# version file +include(CMakePackageConfigHelpers) +write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake + VERSION 1.0.0 + COMPATIBILITY AnyNewerVersion) +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake + DESTINATION ${CMAKE_DST}) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 690e748..8ed51d5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -15,16 +15,19 @@ if (NOT TARGET gtest_main) GIT_TAG release-1.11.0 GIT_SHALLOW TRUE GIT_PROGRESS TRUE) - FetchContent_MakeAvailable(googletest) + FetchContent_GetProperties(googletest) + if(NOT googletest_POPULATED) + FetchContent_Populate(googletest) + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() endif() set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -include_directories(${CPU_EXTENSIONS_HEADERS_DIR}) - -file(GLOB_RECURSE TEST_SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) +file(GLOB_RECURSE TEST_SOURCE_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) find_package(OpenMP REQUIRED) add_executable(cpu_extensions_tests ${TEST_SOURCE_FILES}) -target_link_libraries(cpu_extensions_tests cpu_extensions gtest_main stdc++ OpenMP::OpenMP_CXX) +target_include_directories(cpu_extensions_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) +target_link_libraries(cpu_extensions_tests cpu_extensions gtest_main stdc++ OpenMP::OpenMP_CXX) \ No newline at end of file From 323fa5f7a119f6309d249bd75f22d403d48d5789 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 3 Jul 2023 17:21:27 +0800 Subject: [PATCH 18/54] rotary use external cos/sin loopup table --- include/llm_emb_gpt.hpp | 6 +- src/emb_gpt_avx512.cpp | 91 +++++-------------- tests/script/ext/attn_gpt.cpp | 27 +++--- tests/script/ext/emb_gpt.cpp | 24 ++--- .../models/chatglm-6b/modeling_chatglm.py | 15 ++- tests/script/test_attn_chatglm.py | 20 +++- tests/script/test_rotary_pastkv.py | 16 +++- tests/script/test_rotary_pastkv_chatglm.py | 20 +++- 8 files changed, 118 insertions(+), 101 deletions(-) diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp index 7bfd183..95cdd07 100644 --- a/include/llm_emb_gpt.hpp +++ b/include/llm_emb_gpt.hpp @@ -17,12 +17,10 @@ class emb_gpt { size_t num_heads; size_t head_size; size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv - size_t max_seq_len; // max seq length for computing the size of matmul tmp result // supported (qkv, dst): (bf16, bf16) data_type_t qkv_precision; data_type_t dst_precision; - size_t rotary_emb_base; - float rotary_pct; + size_t rotary_dims; bool use_position2d; // chatglm true, other false }; struct exec_param { @@ -35,6 +33,8 @@ class emb_gpt { uint8_t** layer_past_value_src; // past value src uint8_t** layer_past_key_dst; // past key dst, if layer_past_key_src!=layer_past_key_dst, will copy layer_past_key_src to layer_past_key_dst uint8_t** layer_past_value_dst; // past value dst, if layer_past_value!=layer_past_value_dst, will copy layer_past_value to layer_past_value_dst + float* cos; // cos table + float* sin; // sin table int* position2d_ids; // shape: [batch, 2, query_seq_len] size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer }; diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index c957d93..2485764 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -22,26 +22,20 @@ struct emb_gpt_impl_avx512 : public emb_gpt::impl { bool create(const emb_gpt::create_param& param) override; void exec(const emb_gpt::exec_param& param) override; - void initRotery(size_t max_seq_len); void memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, size_t batch, size_t past_seq_len, size_t head_stride_in_kv); void applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv); + size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv, float* cos, float* sin); void applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv); + size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv, float* cos, float* sin); emb_gpt::create_param _create_param; size_t _head_num = 32; size_t _size_per_head = 80; size_t _hidden_size = 32 * 80; - size_t _rotary_emb_base = 10000; - float _rotary_pct = 0.25; - size_t _max_seq_len = 400; + size_t _rotary_dim = 20; // aligned to cache line size_t _size_per_head_aligned = 80; - int _rotary_ndims = 0; - std::shared_ptr _cos_cached; - std::shared_ptr _sin_cached; int64_t _input_type_size = 1; int64_t _output_type_size = 1; bool _use_position2d = false; @@ -63,56 +57,17 @@ bool emb_gpt_impl_avx512::create(const emb_gpt::create_param& param) { _size_per_head = param.head_size; _size_per_head_aligned = param.head_size_aligned; _hidden_size = param.head_size * param.num_heads; - _rotary_emb_base = param.rotary_emb_base; - _rotary_pct = param.rotary_pct; - _max_seq_len = param.max_seq_len; + _rotary_dim = param.rotary_dims; _input_type_size = sizeof(ov::bfloat16); _output_type_size = sizeof(ov::bfloat16); if (param.dst_precision == dnnl_s8) _output_type_size = sizeof(int8_t); _use_position2d = param.use_position2d; - if (_use_position2d) { - _rotary_ndims = static_cast(_size_per_head / 2); - } else { - _rotary_ndims = static_cast(_size_per_head * _rotary_pct); - } - initRotery(_max_seq_len); return true; } -void emb_gpt_impl_avx512::initRotery(size_t max_seq_len) { - std::vector inv_freq; - for (int i = 0; i < _rotary_ndims; i += 2) { - inv_freq.push_back(1.0f / (powf(_rotary_emb_base, static_cast(i) / _rotary_ndims))); - } - std::vector t; - for (size_t i = 0; i < max_seq_len * 2; i++) { - t.push_back(static_cast(i)); - } - auto width = _rotary_ndims / 2 * 2; - auto height = max_seq_len * 2; - auto capacity = height * width * sizeof(float); - _cos_cached = std::shared_ptr( - reinterpret_cast(aligned_alloc(64, capacity)), - [](void * p) { ::free(p); }); - _sin_cached = std::shared_ptr( - reinterpret_cast(aligned_alloc(64, capacity)), - [](void * p) { ::free(p); }); - - auto* cos_p = _cos_cached.get(); - auto* sin_p = _sin_cached.get(); - for (size_t i = 0; i < height; i++) { - for (int j = 0; j < width / 2; j++) { - cos_p[i * width + j] = cosf(t[i] * inv_freq[j]); - cos_p[i * width + j + width / 2] = cosf(t[i] * inv_freq[j]); - sin_p[i * width + j] = sinf(t[i] * inv_freq[j]); - sin_p[i * width + j + width / 2] = sinf(t[i] * inv_freq[j]); - } - } -} - void emb_gpt_impl_avx512::memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, size_t batch, size_t past_seq_len, size_t head_stride_in_kv) { parallel_for3d(batch, _head_num, past_seq_len, [&](size_t b, size_t h, size_t s) { @@ -139,10 +94,10 @@ void emb_gpt_impl_avx512::memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, // kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv) { + size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv, float* cos, float* sin) { auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; - auto* cos_cached = _cos_cached.get() + past_seq_len * _rotary_ndims; - auto* sin_cached = _sin_cached.get() + past_seq_len * _rotary_ndims; + auto* cos_cached = cos + past_seq_len * _rotary_dim; + auto* sin_cached = sin + past_seq_len * _rotary_dim; parallel_for3d(batch, _head_num, q_seq_len, [&](size_t b, size_t h, size_t s) { // q, k rotary encoding auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; @@ -161,11 +116,11 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src auto* k_src_f = reinterpret_cast(k_src_seq + h * _size_per_head * 3 * _input_type_size); auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); - rotary_avx512(_rotary_ndims, cos_cached + s * _rotary_ndims, sin_cached + s * _rotary_ndims, q_src_f, k_src_f, q_dst_f, k_dst_f); + rotary_avx512(_rotary_dim, cos_cached + s * _rotary_dim, sin_cached + s * _rotary_dim, q_src_f, k_src_f, q_dst_f, k_dst_f); // q, k concat - memcpy(reinterpret_cast(q_dst_f) + _rotary_ndims * _output_type_size, reinterpret_cast(q_src_f) + _rotary_ndims * _input_type_size, _output_type_size * (_size_per_head - _rotary_ndims)); - memcpy(reinterpret_cast(k_dst_f) + _rotary_ndims * _output_type_size, reinterpret_cast(k_src_f) + _rotary_ndims * _input_type_size, _output_type_size * (_size_per_head - _rotary_ndims)); + memcpy(reinterpret_cast(q_dst_f) + _rotary_dim * _output_type_size, reinterpret_cast(q_src_f) + _rotary_dim * _input_type_size, _output_type_size * (_size_per_head - _rotary_dim)); + memcpy(reinterpret_cast(k_dst_f) + _rotary_dim * _output_type_size, reinterpret_cast(k_src_f) + _rotary_dim * _input_type_size, _output_type_size * (_size_per_head - _rotary_dim)); // v concat memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, static_cast(v_src_seq) + h * _size_per_head * 3 * _input_type_size, @@ -179,10 +134,10 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] // position2d_ids: [batch, 2, q_seq_len] void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv) { + size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv, float* cos, float* sin) { auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; - auto* cos_cached = _cos_cached.get(); - auto* sin_cached = _sin_cached.get(); + auto* cos_cached = cos; + auto* sin_cached = sin; parallel_for3d(batch, _head_num, q_seq_len, [&](size_t b, size_t h, size_t s) { // q, k rotary encoding auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; @@ -203,12 +158,12 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, auto* k_src_f = reinterpret_cast(k_src_seq + h * _size_per_head * 3 * _input_type_size); auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); - rotary_avx512(_rotary_ndims, cos_cached + pos_batch[s] * _rotary_ndims, sin_cached + pos_batch[s] * _rotary_ndims, q_src_f, k_src_f, q_dst_f, k_dst_f); - rotary_avx512(_rotary_ndims, cos_cached + block_batch[s] * _rotary_ndims, sin_cached + block_batch[s] * _rotary_ndims, - q_src_f + _rotary_ndims, - k_src_f + _rotary_ndims, - q_dst_f + _rotary_ndims, - k_dst_f + _rotary_ndims); + rotary_avx512(_rotary_dim, cos_cached + pos_batch[s] * _rotary_dim, sin_cached + pos_batch[s] * _rotary_dim, q_src_f, k_src_f, q_dst_f, k_dst_f); + rotary_avx512(_rotary_dim, cos_cached + block_batch[s] * _rotary_dim, sin_cached + block_batch[s] * _rotary_dim, + q_src_f + _rotary_dim, + k_src_f + _rotary_dim, + q_dst_f + _rotary_dim, + k_dst_f + _rotary_dim); // v concat memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, @@ -233,7 +188,7 @@ void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { auto head_stride_in_kv = param.head_stride_in_kv; // past kv src != dst, copy src to dst first - if (param.layer_past_key_src[0] != param.layer_past_key_dst[0] && past_seq_len) + if (param.layer_past_key_src && param.layer_past_key_src[0] != param.layer_past_key_dst[0] && past_seq_len) memcpyPastKV(param.layer_past_key_src, param.layer_past_value_src, param.layer_past_key_dst, param.layer_past_value_dst, batch, past_seq_len, head_stride_in_kv); // transpose + rotary embbeding: @@ -254,9 +209,11 @@ void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] if (_use_position2d) { - applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, param.position2d_ids, head_stride_in_kv); + applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, + param.position2d_ids, head_stride_in_kv, param.cos, param.sin); } else { - applyRotaryPosEmbMemcpy(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, head_stride_in_kv); + applyRotaryPosEmbMemcpy(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, head_stride_in_kv, + param.cos, param.sin); } } } diff --git a/tests/script/ext/attn_gpt.cpp b/tests/script/ext/attn_gpt.cpp index 5cc4e46..e4e7683 100644 --- a/tests/script/ext/attn_gpt.cpp +++ b/tests/script/ext/attn_gpt.cpp @@ -24,9 +24,8 @@ class attn_gpt { // supported (qkv, dst): (bf16, bf16) llmdnn::data_type_t qkv_precision; llmdnn::data_type_t dst_precision; - size_t rotary_emb_base; + size_t rotary_dims; float normal_factor; - float rotary_pct; bool use_position2d; }; struct exec_param { @@ -43,6 +42,8 @@ class attn_gpt { // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true uint8_t* attn_output; size_t head_stride_in_kv; + float* cos; + float* sin; }; attn_gpt(); @@ -70,9 +71,7 @@ bool attn_gpt::create(const attn_gpt::create_param& param) { emb_param.head_size_aligned = param.head_size_aligned; emb_param.qkv_precision = param.qkv_precision; emb_param.dst_precision = param.dst_precision; - emb_param.max_seq_len = param.max_seq_len; - emb_param.rotary_emb_base = param.rotary_emb_base; - emb_param.rotary_pct = param.rotary_pct; + emb_param.rotary_dims = param.rotary_dims; emb_param.use_position2d = param.use_position2d; if (!_emb_gpt->create(emb_param)) @@ -112,6 +111,8 @@ void attn_gpt::exec(const attn_gpt::exec_param& param) { emb_param.layer_past_value_dst = param.layer_past_value_dst; emb_param.position2d_ids = param.position2d_ids; emb_param.head_stride_in_kv = param.head_stride_in_kv; + emb_param.cos = param.cos; + emb_param.sin = param.sin; _emb_gpt->exec(emb_param); llmdnn::mha_gpt::exec_param mha_param; @@ -139,8 +140,7 @@ void regclass_attn_gpt(pybind11::module m) { const std::string qkv_precision_name, const std::string dst_precision_name, const size_t max_seq_len, - const size_t rotary_emb_base, - float rotary_pct, + const size_t rotary_dims, bool use_position2d) { attn_gpt::create_param param; param.num_heads = num_heads; @@ -150,8 +150,7 @@ void regclass_attn_gpt(pybind11::module m) { param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); param.max_seq_len = max_seq_len; - param.rotary_emb_base = rotary_emb_base; - param.rotary_pct = rotary_pct; + param.rotary_dims = rotary_dims; param.use_position2d = use_position2d; if (param.qkv_precision == llmdnn::dnnl_data_type_undef) throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); @@ -167,8 +166,7 @@ void regclass_attn_gpt(pybind11::module m) { py::arg("qkv_precision_name"), py::arg("dst_precision_name"), py::arg("max_seq_len"), - py::arg("rotary_emb_base"), - py::arg("rotary_pct"), + py::arg("rotary_dims"), py::arg("use_position2d") = false, R"( Create emb @@ -177,7 +175,8 @@ void regclass_attn_gpt(pybind11::module m) { :type num_heads: int )"); cls.def("exec_position", [] (attn_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_dst, - const torch::Tensor& layer_past_value_dst, int64_t past_seq_len, const torch::Tensor& attn_mask, const torch::Tensor& position2d_ids) { + const torch::Tensor& layer_past_value_dst, int64_t past_seq_len, const torch::Tensor& attn_mask, + const torch::Tensor& position2d_ids, const torch::Tensor& cos, const torch::Tensor& sin) { // qkv: [batch, seq_len, (num_heads * 3 * head_size)] // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] // past_seq_len: past_seq_len==layer_past.shape[-2] @@ -213,6 +212,8 @@ void regclass_attn_gpt(pybind11::module m) { param.head_stride_in_kv = max_seq_len * head_size_aligned; auto out = qkv.new_empty({batch, query_seq_len, num_heads * head_size}); param.attn_output = reinterpret_cast(out.data_ptr()); + param.cos = reinterpret_cast(cos.data_ptr()); + param.sin = reinterpret_cast(sin.data_ptr()); self.exec(param); @@ -226,6 +227,8 @@ void regclass_attn_gpt(pybind11::module m) { py::arg("past_seq_len"), py::arg("attn_mask"), py::arg("position2d_ids"), + py::arg("cos"), + py::arg("sin"), R"( exec emb diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp index 1010f8c..409cc0a 100644 --- a/tests/script/ext/emb_gpt.cpp +++ b/tests/script/ext/emb_gpt.cpp @@ -22,9 +22,7 @@ void regclass_emb_gpt(pybind11::module m) { const size_t head_size_aligned, const std::string qkv_precision_name, const std::string dst_precision_name, - const size_t max_seq_len, - const size_t rotary_emb_base, - float rotary_pct, + const size_t rotary_dims, bool use_position2d) { llmdnn::emb_gpt::create_param param; param.num_heads = num_heads; @@ -32,9 +30,7 @@ void regclass_emb_gpt(pybind11::module m) { param.head_size_aligned = head_size_aligned; param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); - param.max_seq_len = max_seq_len; - param.rotary_emb_base = rotary_emb_base; - param.rotary_pct = rotary_pct; + param.rotary_dims = rotary_dims; param.use_position2d = use_position2d; if (param.qkv_precision == llmdnn::dnnl_data_type_undef) throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); @@ -48,9 +44,7 @@ void regclass_emb_gpt(pybind11::module m) { py::arg("head_size_aligned"), py::arg("qkv_precision_name"), py::arg("dst_precision_name"), - py::arg("max_seq_len"), - py::arg("rotary_emb_base"), - py::arg("rotary_pct"), + py::arg("rotary_dims"), py::arg("use_position2d") = false, R"( Create emb @@ -60,7 +54,7 @@ void regclass_emb_gpt(pybind11::module m) { )"); // torch::List cls.def("exec", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_dst, - const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len) { + const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor& cos, const torch::Tensor& sin) { // qkv: [batch, seq_len, (num_heads * 3 * head_size)] // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] // past_seq_len: past_seq_len==layer_past.shape[-2] @@ -92,6 +86,8 @@ void regclass_emb_gpt(pybind11::module m) { param.layer_past_key_dst = param.layer_past_key_src; param.layer_past_value_dst = param.layer_past_value_src; param.head_stride_in_kv = max_seq_len * head_size_aligned; + param.cos = reinterpret_cast(cos.data_ptr()); + param.sin = reinterpret_cast(sin.data_ptr()); for (int i = 0; i < batch; i++) { param.layer_past_key_src[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); param.layer_past_value_src[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); @@ -107,6 +103,8 @@ void regclass_emb_gpt(pybind11::module m) { py::arg("layer_past_value_dst"), py::arg("query_padded"), py::arg("past_seq_len"), + py::arg("cos"), + py::arg("sin"), R"( exec emb @@ -114,7 +112,7 @@ void regclass_emb_gpt(pybind11::module m) { :type num_heads: int )"); cls.def("exec_position", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_src, const torch::Tensor& layer_past_value_src, - const torch::Tensor& layer_past_key_dst, const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor position2d_ids) { + const torch::Tensor& layer_past_key_dst, const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor position2d_ids, const torch::Tensor& cos, const torch::Tensor& sin) { // qkv: [batch, seq_len, (num_heads * 3 * head_size)] // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] // past_seq_len: past_seq_len==layer_past.shape[-2] @@ -146,6 +144,8 @@ void regclass_emb_gpt(pybind11::module m) { param.layer_past_key_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.layer_past_value_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.head_stride_in_kv = max_seq_len * head_size_aligned; + param.cos = reinterpret_cast(cos.data_ptr()); + param.sin = reinterpret_cast(sin.data_ptr()); for (int i = 0; i < batch; i++) { param.layer_past_key_src[i] = past_seq_len == 0 ? nullptr : reinterpret_cast(layer_past_key_src[i].data_ptr()); param.layer_past_value_src[i] = past_seq_len == 0 ? nullptr : reinterpret_cast(layer_past_value_src[i].data_ptr()); @@ -167,6 +167,8 @@ void regclass_emb_gpt(pybind11::module m) { py::arg("query_padded"), py::arg("past_seq_len"), py::arg("position2d_ids"), + py::arg("cos"), + py::arg("sin"), R"( exec emb diff --git a/tests/script/models/chatglm-6b/modeling_chatglm.py b/tests/script/models/chatglm-6b/modeling_chatglm.py index b220192..5d851ff 100644 --- a/tests/script/models/chatglm-6b/modeling_chatglm.py +++ b/tests/script/models/chatglm-6b/modeling_chatglm.py @@ -419,11 +419,22 @@ def __init__(self, hidden_size, num_attention_heads, self.head_size_aligned = (head_size + 31) // 32 * 32 self.max_sequence_length = max_sequence_length normal_factor = 1.0 / math.sqrt(head_size) + rotary_ndims = int(head_size * 0.5) self.attn.create(num_attention_heads, head_size, self.head_size_aligned, - normal_factor, 'bf16', 'bf16', max_sequence_length, 10000, 0.5, True) + normal_factor, 'bf16', 'bf16', max_sequence_length, rotary_ndims, True) self.layer_past_key_padded = None self.layer_past_value_padded = None self.past_seq_len = 0 + inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_sequence_length + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] @staticmethod def attention_mask_func(attention_scores, attention_mask): @@ -474,7 +485,7 @@ def forward( self.layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) if layer_past is None: self.past_seq_len = 0 - context_layer = self.attn.exec_position(mixed_raw_layer, self.layer_past_key_padded, self.layer_past_value_padded, self.past_seq_len, attention_mask, position_ids) + context_layer = self.attn.exec_position(mixed_raw_layer, self.layer_past_key_padded, self.layer_past_value_padded, self.past_seq_len, attention_mask, position_ids, self.cos_cached, self.sin_cached) present = (self.layer_past_key_padded, self.layer_past_value_padded) attention_probs = None self.past_seq_len += mixed_raw_layer.size(1) diff --git a/tests/script/test_attn_chatglm.py b/tests/script/test_attn_chatglm.py index 9cf5164..2caf02a 100644 --- a/tests/script/test_attn_chatglm.py +++ b/tests/script/test_attn_chatglm.py @@ -339,8 +339,24 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi qkv_precision_name = 's8' if is_int8 else 'bf16' dst_precision_name = 's8' if is_int8 else 'bf16' + rotary_ndims = int(head_size * rotary_pct) self.attn.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, - dst_precision_name, max_seq_len, rotary_emb_base, rotary_pct, True) + dst_precision_name, max_seq_len, rotary_ndims, True) + + inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + #inv_freq = inv_freq.half() + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] # qkv: [batch, seq_len, (num_heads * 3 * head_size)] # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] @@ -352,7 +368,7 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi # 1: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] # 2: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids): - return self.attn.exec_position(qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids) + return self.attn.exec_position(qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids, self.cos_cached, self.sin_cached) HEAD_NUM = 32 diff --git a/tests/script/test_rotary_pastkv.py b/tests/script/test_rotary_pastkv.py index f8ee6e2..402720f 100644 --- a/tests/script/test_rotary_pastkv.py +++ b/tests/script/test_rotary_pastkv.py @@ -115,8 +115,20 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi qkv_precision_name = 's8' if is_int8 else 'bf16' dst_precision_name = 's8' if is_int8 else 'bf16' + rotary_ndims = int(head_size * rotary_pct) self.emd.create(num_heads, head_size, head_size_aligned, qkv_precision_name, - dst_precision_name, max_seq_len, rotary_emb_base, rotary_pct) + dst_precision_name, rotary_ndims) + + inv_freq = 1.0 / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] # qkv: [batch, seq_len, (num_heads * 3 * head_size)] # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] @@ -127,7 +139,7 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len): - self.emd.exec(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len) + self.emd.exec(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, self.cos_cached, self.sin_cached) HEAD_NUM = 32 diff --git a/tests/script/test_rotary_pastkv_chatglm.py b/tests/script/test_rotary_pastkv_chatglm.py index 73ea72d..612c7d4 100644 --- a/tests/script/test_rotary_pastkv_chatglm.py +++ b/tests/script/test_rotary_pastkv_chatglm.py @@ -257,8 +257,24 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi qkv_precision_name = 's8' if is_int8 else 'bf16' dst_precision_name = 's8' if is_int8 else 'bf16' + rotary_ndims = int(head_size * rotary_pct) self.emd.create(num_heads, head_size, head_size_aligned, qkv_precision_name, - dst_precision_name, max_seq_len, rotary_emb_base, rotary_pct, True) + dst_precision_name, rotary_ndims, True) + + inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + #inv_freq = inv_freq.half() + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] # qkv: [batch, seq_len, (num_heads * 3 * head_size)] # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] @@ -269,7 +285,7 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] def forward(self, qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids): - self.emd.exec_position(qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids) + self.emd.exec_position(qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids, self.cos_cached, self.sin_cached) HEAD_NUM = 32 From 6d5971cc8cda2764a4c6dab03be6fcdf49aba973 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 3 Jul 2023 22:28:02 +0800 Subject: [PATCH 19/54] qkv of emb changes to q, k, v --- include/llm_emb_gpt.hpp | 11 +++++-- src/emb_gpt_avx512.cpp | 55 +++++++++++++++++------------------ tests/script/ext/attn_gpt.cpp | 19 ++++++++++-- tests/script/ext/emb_gpt.cpp | 10 +++++-- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp index 95cdd07..dfc4947 100644 --- a/include/llm_emb_gpt.hpp +++ b/include/llm_emb_gpt.hpp @@ -27,14 +27,19 @@ class emb_gpt { size_t batch; size_t query_seq_len; size_t past_seq_len; - uint8_t* qkv; // shape: [batch, query_seq_len, 3 * hidden size] + uint8_t* q; // shape: [batch, query_seq_len, hidden size], inner stride is ldq + uint8_t* k; // shape: [batch, query_seq_len, hidden size], inner stride is ldk + uint8_t* v; // shape: [batch, query_seq_len, hidden size], inner stride is ldv + size_t ldq; // inner stride of q + size_t ldk; // inner stride of k + size_t ldv; // inner stride of v uint8_t* query_dst; // rotary embbeding dst uint8_t** layer_past_key_src; // past key src uint8_t** layer_past_value_src; // past value src uint8_t** layer_past_key_dst; // past key dst, if layer_past_key_src!=layer_past_key_dst, will copy layer_past_key_src to layer_past_key_dst uint8_t** layer_past_value_dst; // past value dst, if layer_past_value!=layer_past_value_dst, will copy layer_past_value to layer_past_value_dst - float* cos; // cos table - float* sin; // sin table + float* cos; // cos lookup table, shape: [max_seq_len, rotary_dims] + float* sin; // sin lookup table, shape: [max_seq_len, rotary_dims] int* position2d_ids; // shape: [batch, 2, query_seq_len] size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer }; diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index 2485764..b290da2 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -24,9 +24,9 @@ struct emb_gpt_impl_avx512 : public emb_gpt::impl { void memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, size_t batch, size_t past_seq_len, size_t head_stride_in_kv); - void applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + void applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv, float* cos, float* sin); - void applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + void applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv, float* cos, float* sin); emb_gpt::create_param _create_param; @@ -93,7 +93,7 @@ void emb_gpt_impl_avx512::memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] // kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] -void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, +void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv, float* cos, float* sin) { auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; auto* cos_cached = cos + past_seq_len * _rotary_dim; @@ -103,17 +103,17 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; auto k_dst_batch = k_dst[b] + key_offset; auto v_dst_batch = v_dst[b] + key_offset; - auto q_src_batch = q_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; - auto k_src_batch = k_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; - auto v_src_batch = v_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; + auto q_src_batch = q_src + b * _head_num * ldq * q_seq_len * _input_type_size; + auto k_src_batch = k_src + b * _head_num * ldk * q_seq_len * _input_type_size; + auto v_src_batch = v_src + b * _head_num * ldv * q_seq_len * _input_type_size; auto q_dst_seq = q_dst_batch + s * _size_per_head_aligned * _output_type_size; auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto q_src_seq = q_src_batch + s * _hidden_size * 3 * _input_type_size; - auto k_src_seq = k_src_batch + s * _hidden_size * 3 * _input_type_size; - auto v_src_seq = v_src_batch + s * _hidden_size * 3 * _input_type_size; - auto* q_src_f = reinterpret_cast(q_src_seq + h * _size_per_head * 3 * _input_type_size); - auto* k_src_f = reinterpret_cast(k_src_seq + h * _size_per_head * 3 * _input_type_size); + auto q_src_seq = q_src_batch + s * _head_num * ldq * _input_type_size; + auto k_src_seq = k_src_batch + s * _head_num * ldk * _input_type_size; + auto v_src_seq = v_src_batch + s * _head_num * ldv * _input_type_size; + auto* q_src_f = reinterpret_cast(q_src_seq + h * ldq * _input_type_size); + auto* k_src_f = reinterpret_cast(k_src_seq + h * ldk * _input_type_size); auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); rotary_avx512(_rotary_dim, cos_cached + s * _rotary_dim, sin_cached + s * _rotary_dim, q_src_f, k_src_f, q_dst_f, k_dst_f); @@ -123,7 +123,7 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src memcpy(reinterpret_cast(k_dst_f) + _rotary_dim * _output_type_size, reinterpret_cast(k_src_f) + _rotary_dim * _input_type_size, _output_type_size * (_size_per_head - _rotary_dim)); // v concat memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, - static_cast(v_src_seq) + h * _size_per_head * 3 * _input_type_size, + static_cast(v_src_seq) + h * ldv * _input_type_size, _size_per_head * _output_type_size); }); } @@ -133,7 +133,7 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src // kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] // position2d_ids: [batch, 2, q_seq_len] -void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, +void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv, float* cos, float* sin) { auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; auto* cos_cached = cos; @@ -145,17 +145,17 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, auto v_dst_batch = v_dst[b] + key_offset; auto pos_batch = position2d_ids + b * 2 * q_seq_len; auto block_batch = pos_batch + q_seq_len; - auto q_src_batch = q_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; - auto k_src_batch = k_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; - auto v_src_batch = v_src + b * _hidden_size * 3 * q_seq_len * _input_type_size; + auto q_src_batch = q_src + b * _head_num * ldq * q_seq_len * _input_type_size; + auto k_src_batch = k_src + b * _head_num * ldk * q_seq_len * _input_type_size; + auto v_src_batch = v_src + b * _head_num * ldv * q_seq_len * _input_type_size; auto q_dst_seq = q_dst_batch + s * _size_per_head_aligned * _output_type_size; auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto q_src_seq = q_src_batch + s * _hidden_size * 3 * _input_type_size; - auto k_src_seq = k_src_batch + s * _hidden_size * 3 * _input_type_size; - auto v_src_seq = v_src_batch + s * _hidden_size * 3 * _input_type_size; - auto* q_src_f = reinterpret_cast(q_src_seq + h * _size_per_head * 3 * _input_type_size); - auto* k_src_f = reinterpret_cast(k_src_seq + h * _size_per_head * 3 * _input_type_size); + auto q_src_seq = q_src_batch + s * _head_num * ldq * _input_type_size; + auto k_src_seq = k_src_batch + s * _head_num * ldk * _input_type_size; + auto v_src_seq = v_src_batch + s * _head_num * ldv * _input_type_size; + auto* q_src_f = reinterpret_cast(q_src_seq + h * ldq * _input_type_size); + auto* k_src_f = reinterpret_cast(k_src_seq + h * ldk * _input_type_size); auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); rotary_avx512(_rotary_dim, cos_cached + pos_batch[s] * _rotary_dim, sin_cached + pos_batch[s] * _rotary_dim, q_src_f, k_src_f, q_dst_f, k_dst_f); @@ -167,7 +167,7 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, // v concat memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, - static_cast(v_src_seq) + h * _size_per_head * 3 * _input_type_size, + static_cast(v_src_seq) + h * ldv* _input_type_size, _size_per_head * _output_type_size); }); } @@ -175,10 +175,9 @@ void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { // [batch, seq_len, (num_heads * 3 * head_size)] // --> [batch, seq_len, num_heads, 3 * head_size] - auto* qkv = param.qkv; - auto query = qkv; // qkv[..., : self.head_size].permute(0, 2, 1, 3) - auto key = qkv + _size_per_head * _input_type_size; // qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) - auto value = qkv + 2 * _size_per_head * _input_type_size; // qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + auto query = param.q; + auto key = param.k; + auto value = param.v; auto query_dst = param.query_dst; auto key_dst = param.layer_past_key_dst; auto value_dst = param.layer_past_value_dst; @@ -209,10 +208,10 @@ void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] if (_use_position2d) { - applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, + applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, param.ldq, param.ldk, param.ldv, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, param.position2d_ids, head_stride_in_kv, param.cos, param.sin); } else { - applyRotaryPosEmbMemcpy(query, key, value, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, head_stride_in_kv, + applyRotaryPosEmbMemcpy(query, key, value, param.ldq, param.ldk, param.ldv, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, head_stride_in_kv, param.cos, param.sin); } } diff --git a/tests/script/ext/attn_gpt.cpp b/tests/script/ext/attn_gpt.cpp index e4e7683..a87a3a1 100644 --- a/tests/script/ext/attn_gpt.cpp +++ b/tests/script/ext/attn_gpt.cpp @@ -33,7 +33,12 @@ class attn_gpt { size_t query_seq_len; size_t past_seq_len; bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. - uint8_t* qkv; + uint8_t* q; // shape: [batch, query_seq_len, hidden size], inner stride is ldq + uint8_t* k; // shape: [batch, query_seq_len, hidden size], inner stride is ldk + uint8_t* v; // shape: [batch, query_seq_len, hidden size], inner stride is ldv + size_t ldq; // inner stride of q + size_t ldk; // inner stride of k + size_t ldv; // inner stride of v uint8_t** layer_past_key_dst; uint8_t** layer_past_value_dst; int* position2d_ids; // shape: [batch, 2, query_seq_len] @@ -103,7 +108,12 @@ void attn_gpt::exec(const attn_gpt::exec_param& param) { emb_param.batch = param.batch; emb_param.query_seq_len = param.query_seq_len; emb_param.past_seq_len = param.past_seq_len; - emb_param.qkv = param.qkv; + emb_param.q = param.q; + emb_param.k = param.k; + emb_param.v = param.v; + emb_param.ldq = param.ldq; + emb_param.ldk = param.ldk; + emb_param.ldv = param.ldv; emb_param.query_dst = _query_dst.get(); emb_param.layer_past_key_src = param.layer_past_key_dst; emb_param.layer_past_value_src = param.layer_past_value_dst; @@ -198,7 +208,10 @@ void regclass_attn_gpt(pybind11::module m) { param.batch = batch; param.query_seq_len = query_seq_len; param.past_seq_len = past_seq_len; - param.qkv = reinterpret_cast(qkv.data_ptr()); + param.q = reinterpret_cast(qkv.data_ptr()); + param.k = param.q + head_size * sizeof(ov::bfloat16); + param.v = param.k + head_size * sizeof(ov::bfloat16); + param.ldq = param.ldk = param.ldv = head_size * 3; param.layer_past_key_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.layer_past_value_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); for (int i = 0; i < batch; i++) { diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp index 409cc0a..cbf4aa8 100644 --- a/tests/script/ext/emb_gpt.cpp +++ b/tests/script/ext/emb_gpt.cpp @@ -79,7 +79,10 @@ void regclass_emb_gpt(pybind11::module m) { param.batch = batch; param.query_seq_len = query_seq_len; param.past_seq_len = past_seq_len; - param.qkv = reinterpret_cast(qkv.data_ptr()); + param.q = reinterpret_cast(qkv.data_ptr()); + param.k = param.q + head_size * sizeof(ov::bfloat16); + param.v = param.k + head_size * sizeof(ov::bfloat16); + param.ldq = param.ldk = param.ldv = head_size * 3; param.query_dst = reinterpret_cast(query_padded.data_ptr()); param.layer_past_key_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.layer_past_value_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); @@ -137,7 +140,10 @@ void regclass_emb_gpt(pybind11::module m) { param.batch = batch; param.query_seq_len = query_seq_len; param.past_seq_len = past_seq_len; - param.qkv = reinterpret_cast(qkv.data_ptr()); + param.q = reinterpret_cast(qkv.data_ptr()); + param.k = param.q + head_size * sizeof(ov::bfloat16); + param.v = param.k + head_size * sizeof(ov::bfloat16); + param.ldq = param.ldk = param.ldv = head_size * 3; param.query_dst = reinterpret_cast(query_padded.data_ptr()); param.layer_past_key_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); param.layer_past_value_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); From a741a8b73b9119999a6eda0fd102082346155f00 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Tue, 4 Jul 2023 17:08:13 +0800 Subject: [PATCH 20/54] gelu tanh support --- include/llm_fc.hpp | 15 +- src/common/tensor2d.hpp | 2 - src/common/utility.hpp | 10 + src/fc_kernel_amx.cpp | 64 ++++- src/gelu_kernel_avx512.hpp | 386 ++++++++++++++++++++++++++ src/mm_kernel_common_amx.hpp | 98 +------ tests/src/test_common.hpp | 2 + tests/src/test_fc_kernel_amx.cpp | 16 ++ tests/src/test_gelu_kernel_avx512.cpp | 97 +++++++ 9 files changed, 592 insertions(+), 98 deletions(-) create mode 100644 src/gelu_kernel_avx512.hpp create mode 100644 tests/src/test_gelu_kernel_avx512.cpp diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp index ac77fd4..1272c1b 100644 --- a/include/llm_fc.hpp +++ b/include/llm_fc.hpp @@ -12,8 +12,10 @@ typedef enum { NONE = 0, DEQUANT = 1 << 0, BIAS = 1 << 1, - GELU = 1 << 2, - QUANT = 1 << 3, + GELU_ERF = 1 << 2, + GELU_TANH = 1 << 3, + QUANT = 1 << 4, + GELU = GELU_ERF, // default is ERF BIAS_GELU = BIAS | GELU, DEQUANT_BIAS_GELU = DEQUANT | BIAS_GELU, @@ -23,7 +25,14 @@ typedef enum { DEQUANT_QUANT = DEQUANT | QUANT, DEQUANT_GELU = DEQUANT | GELU, - DEQUANT_BIAS = DEQUANT | BIAS + DEQUANT_BIAS = DEQUANT | BIAS, + + BIAS_GELU_TANH = BIAS | GELU_TANH, + DEQUANT_BIAS_GELU_TANH = DEQUANT | BIAS_GELU_TANH, + DEQUANT_BIAS_GELU_TANH_QUANT = DEQUANT_BIAS_GELU_TANH | QUANT, + DEQUANT_GELU_TANH_QUANT = DEQUANT | GELU_TANH | QUANT, + + DEQUANT_GELU_TANH = DEQUANT | GELU_TANH, } postops_types; struct fc_create_param { diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index 971d1b0..ce83671 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -94,8 +94,6 @@ struct tensor2D { stride = d1 * sizeof(T); if ((stride % 64) && (!force_compact)) { auto stride_fix = rndup(stride, 64); - std::cout << "\tWarnning: stride " << stride << " is not aligned to cache line, will increase to " << stride_fix - << " (" << stride_fix/64 << " cache lines)\n"; stride = stride_fix; } padded_dim1 = stride / sizeof(T); diff --git a/src/common/utility.hpp b/src/common/utility.hpp index 5eee4fb..4d81a19 100644 --- a/src/common/utility.hpp +++ b/src/common/utility.hpp @@ -12,6 +12,16 @@ #include #include "llm_types.hpp" +#ifndef OV_DECL_ALIGNED +# ifdef __GNUC__ +# define OV_DECL_ALIGNED(x) __attribute__ ((aligned (x))) +# elif defined _MSC_VER +# define OV_DECL_ALIGNED(x) __declspec(align(x)) +# else +# define OV_DECL_ALIGNED(x) +# endif +#endif // OV_DECL_ALIGNED + namespace llmdnn { inline size_t get_precision_size(data_type_t type) { diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index 8f6d6eb..9b91e89 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -38,13 +38,13 @@ struct fc_kernel { using supported_key = std::tuple; using supported_value = std::pair; static std::map supported_postops = { - { { dnnl_s8, dnnl_s8, dnnl_s8 }, { DEQUANT | QUANT, BIAS | GELU } }, - { { dnnl_s8, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU } }, - { { dnnl_s8, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU } }, - { { dnnl_bf16, dnnl_bf16, dnnl_bf16 }, { 0, BIAS | GELU } }, - { { dnnl_bf16, dnnl_bf16, dnnl_f32 }, { 0, BIAS | GELU } }, - { { dnnl_bf16, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU } }, - { { dnnl_bf16, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU } }, + { { dnnl_s8, dnnl_s8, dnnl_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_s8, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_s8, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_bf16, dnnl_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_bf16, dnnl_f32 }, { 0, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, }; static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b, data_type_t dt_c) { @@ -135,6 +135,11 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ppkernel.set_deq_scale(dq); ppkernel.set_q_scale(q); (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c); ppkernel.set_deq_scale(dq); @@ -147,6 +152,11 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ppkernel.set_deq_scale(dq); ppkernel.set_q_scale(q); (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); ppkernel.set_deq_scale(dq); @@ -161,6 +171,10 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* amx_kernel::PP::BiasGeluStore ppkernel(c); ppkernel.set_deq_scale(dq); (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c); ppkernel.set_deq_scale(dq); @@ -171,6 +185,10 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* amx_kernel::PP::BiasGeluStore ppkernel(c, bias); ppkernel.set_deq_scale(dq); (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); ppkernel.set_deq_scale(dq); @@ -184,6 +202,10 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* amx_kernel::PP::BiasGeluStore ppkernel(c); ppkernel.set_deq_scale(dq); (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c); ppkernel.set_deq_scale(dq); @@ -194,6 +216,10 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* amx_kernel::PP::BiasGeluStore ppkernel(c, bias); ppkernel.set_deq_scale(dq); (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); ppkernel.set_deq_scale(dq); @@ -217,6 +243,9 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c); (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); @@ -225,6 +254,9 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); @@ -236,6 +268,9 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c); (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); @@ -244,6 +279,9 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); @@ -260,6 +298,9 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c); (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); @@ -268,6 +309,9 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); @@ -279,6 +323,9 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c); (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); @@ -287,6 +334,9 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); } else { amx_kernel::PP::BiasGeluStore ppkernel(c, bias); (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); diff --git a/src/gelu_kernel_avx512.hpp b/src/gelu_kernel_avx512.hpp new file mode 100644 index 0000000..c416ad5 --- /dev/null +++ b/src/gelu_kernel_avx512.hpp @@ -0,0 +1,386 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "common/utility.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + + // gelu_erf_minimax_approx_compute_vector_fwd in oneDNN + // x*0.5*(1+erf(x/sqrt(2))) = x*0.5*(1 + x*Polynomial(x^2)) + inline __m512 gelu_erf_minmax_approx_avx512(__m512 & x) { + auto x2 = _mm512_mul_ps(x, x); // x^2 + + auto x_positive = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(x), _mm512_set1_epi32(0x7FFFFFFF))); // clear sign mask + auto x_half = _mm512_mul_ps(x, _mm512_set1_ps(0.5f)); + + auto poly = _mm512_castsi512_ps(_mm512_set1_epi32(0x1f1c83fd)); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa3198977))); // poly * x^2 + xxx + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x268a7927))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa998c963))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x2c67ddb2))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xaf013b2c))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x315d4a4f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb3969b11))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x35a776e9))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb79b0914))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3970b255))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbb1b7399))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3ca3621f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbe082bc7))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3f4c4228))); + + // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2) + poly = _mm512_fmadd_ps(poly, x, _mm512_set1_ps(1.0f)); + // x*0.5*(1 + x*Polynomial(x^2)) + poly = _mm512_mul_ps(poly, x_half); + + // combine: + // zone_id + // 1 -inf; -saturation_lbound : 0.0f + // 2 -saturation_lbound; -linear_ubound : x*0.5*(1 + x*Polynomial(x^2)) + // 3 -linear_ubound, linear_ubound : x*0.5 + // 4 linear_ubound : saturation_lbound : x*0.5*(1 + x*Polynomial(x^2)) + // 5 saturation_lbound: +inf : x + constexpr int neg_saturation_lbound = 0xc0a00000; + constexpr int linear_ubound = 0x33800000; + constexpr int saturation_lbound = 0x40a00000; + + auto mask_x_not_zone1 = _mm512_cmpnlt_ps_mask(x, _mm512_castsi512_ps(_mm512_set1_epi32(neg_saturation_lbound))); + x = _mm512_maskz_mov_ps(mask_x_not_zone1, x); + + auto mask_x_in_zone5 = _mm512_cmpnle_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(saturation_lbound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone5, x); + + auto mask_x_in_zone3 = _mm512_cmple_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(linear_ubound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone3, x_half); + return poly; + } + + // gelu_tanh_compute_vector_fwd in oneDNN + inline __m512 gelu_tanh_avx512(__m512& x) { + // compute G(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x * x) + auto x2 = _mm512_mul_ps(x, x); + auto y = _mm512_fmadd_ps(x2, (__m512)_mm512_set1_epi32(0x3d372713), _mm512_set1_ps(1.0f)); + y = _mm512_mul_ps(y, x); + y = _mm512_mul_ps(y, (__m512)_mm512_set1_epi32(0x3f4c422a)); + + // compute tanh(G(x)) + // We split the positive domain in 33 intervals: + // a) [0; linear_ubound]: in this interval tanh(x) = x + // b) [linear_ubound; 0x1.8p-12]: This interval spans part of a + // half binade + // c) [0x1.8p-12; 0x1.0p-11], ..., [0x1.8p2; 0x1.0p3]: + // one interval for each half binade, there are 29 of those + // d) [0x1.0p3; saturation_ubound]: + // This interval spans part of a half binade + // e) [0x1.205966p3; saturation_ubound]: in this interval, tanh(x) = 1 + // For b-d, we need 31 polynomials and will do a table lookup for those. + // To simplify the logic, we will also put a) in the table. + + // The polynomials are of degree 6, so we need to gather 7 coefficients. + // - sse4.1: we do it the naive way using vextract/vinsert. + // Here we will extract the indices in gpr only once and + // reuse them as there are only 4 of them. + // - avx: we do the same as for sse4.1 but use half of the 64-bits + // registers to store the idx of second half of YMM and half for + // responding XMM. Halfway through the copy we exchange Xmm and + // higher half of Ymm and we get the expected result. + // - avx2: we use vpermps and blend for each coefficient. + // This needs an extra vmm to store the mask + // - avx512: because the table fits in 2 registers, we can use vpermi2d. + + // because tanh(x) = -tanh(-x), we extract sign to make x postive + // and reapply sign at the end + auto y_positive = _mm512_and_ps(y, (__m512)(_mm512_set1_epi32(0x7fffffff))); + + // We compute the indices for the table lookup + auto indices = _mm512_sub_epi32((__m512i)y_positive, _mm512_set1_epi32(0x39800000)); + + indices = _mm512_and_epi32(indices, _mm512_set1_epi32(0xffc00000)); + indices = _mm512_srli_epi32(indices, 22); + // we do the argument reduction + auto y_shift = _mm512_and_ps(y_positive, (__m512)_mm512_set1_epi32(0xffc00000)); + + y_shift = _mm512_sub_ps(y_positive, y_shift); + + static uint32_t OV_DECL_ALIGNED(64) tanh_pol_table[] = { + // coefficients of degree 0 + 0x00000000, + 0x39bfffff, + 0x39ffffff, + 0x3a3ffffe, + 0x3a7ffffb, + 0x3abffff7, + 0x3affffeb, + 0x3b3fffdc, + 0x3b7fffab, + 0x3bbfff70, + 0x3bfffeab, + 0x3c3ffdc0, + 0x3c7ffaab, + 0x3cbff701, + 0x3cffeaad, + 0x3d3fdc08, + 0x3d7faacd, + 0x3dbf7081, + 0x3dfeacc9, + 0x3e3dc7fd, + 0x3e7acbf5, + 0x3eb77a9f, + 0x3eec9a9f, + 0x3f22991f, + 0x3f42f7d6, + 0x3f67b7cc, + 0x3f76ca83, + 0x3f7ebbe9, + 0x3f7fd40c, + 0x3f7fff32, + 0x3f7ffffc, + 0x3f800000, + // coefficients of degree 1 + 0x3f800000, + 0x3f800018, + 0x3f7fffe8, + 0x3f7fffda, + 0x3f7fffdc, + 0x3f7fffdc, + 0x3f7fffac, + 0x3f7fff70, + 0x3f7ffeec, + 0x3f7ffdc0, + 0x3f7ffbed, + 0x3f7ff704, + 0x3f7feff5, + 0x3f7fdbca, + 0x3f7fbfff, + 0x3f7f7041, + 0x3f7f009b, + 0x3f7dc36c, + 0x3f7c0aa8, + 0x3f7734b8, + 0x3f70a4de, + 0x3f5f1fd8, + 0x3f495493, + 0x3f18b9ec, + 0x3ed706cb, + 0x3e390b06, + 0x3d90b11f, + 0x3c21a053, + 0x3aaf7fdb, + 0x37ccc1a3, + 0x355c6733, + 0x00000000, + // coefficients of degree 2 + 0x00000000, + 0xbe4e0ff1, + 0x3d25b1b1, + 0x3d6b6dab, + 0x3c9fb1d5, + 0xbabff06f, + 0x3c07b3f6, + 0xbb3fc1bc, + 0x3a9f5921, + 0xbbbf06f2, + 0xbbb0f402, + 0xbc47db9e, + 0xbc73d5e7, + 0xbca25bda, + 0xbcfca780, + 0xbd40e07c, + 0xbd7dab03, + 0xbdbe4a0f, + 0xbdfb14a5, + 0xbe36cc8d, + 0xbe6bd102, + 0xbe9fe7c5, + 0xbeba0f10, + 0xbec206a8, + 0xbea3c388, + 0xbe277d62, + 0xbd8b7960, + 0xbc209f49, + 0xbaad44ca, + 0xb7c6eeac, + 0xb663aa41, + 0x00000000, + // coefficients of degree 3 + 0x00000000, + 0x45b3ae96, + 0xc414eb20, + 0xc450e02e, + 0xc3152b4e, + 0xbead2f56, + 0xc2162e02, + 0xbeb4bd5a, + 0xc11a59a4, + 0xbed2f507, + 0xc020d32c, + 0x3dd0f506, + 0xbf2a75e2, + 0xbff950e3, + 0xbed47334, + 0xbe809b8c, + 0xbeb64532, + 0xbe961a5b, + 0xbe9b63ac, + 0xbea0d4b2, + 0xbe828a77, + 0xbe378612, + 0xbdc20908, + 0x3d2d3957, + 0x3dd46e89, + 0x3db3f629, + 0x3d2c5e7b, + 0x3bd20403, + 0x3a59dfae, + 0x3770af45, + 0x372cc014, + 0x00000000, + // coefficients of degree 4 + 0x00000000, + 0xcc981a1b, + 0x4a7edd3d, + 0x4ab1007c, + 0x48fedd9c, + 0x41a557b5, + 0x477ee32a, + 0x422557f5, + 0x45ff3ce4, + 0x42a55641, + 0x446e0867, + 0xc33dc19a, + 0x42915214, + 0x43af4fad, + 0x4110fe88, + 0xc1099b75, + 0x3fc8a8dc, + 0xbfbeaef5, + 0xbe365aad, + 0x3f4d9652, + 0x3ddfa08f, + 0x3e34e9b8, + 0x3e2d07a6, + 0x3dc63567, + 0x3cdaeb78, + 0xbcd17537, + 0xbc92829c, + 0xbb43ab99, + 0xb9b471dd, + 0xb6baad5a, + 0xb78bafc7, + 0x00000000, + // coefficients of degree 5 + 0x00000000, + 0x52f688d5, + 0xd0505c72, + 0xd08f98e3, + 0xce505cc9, + 0xc7162b8a, + 0xcc5061d6, + 0xc7162bdf, + 0xca50b37f, + 0xc7162a3a, + 0xc8422086, + 0x471a714e, + 0xc5ece1f1, + 0xc70e3d90, + 0xc3eba94a, + 0x43e0c424, + 0xc21f4552, + 0x42217cc8, + 0x405e7dc4, + 0xc10dd401, + 0x3e96b602, + 0xbd1a6d2f, + 0xbd393883, + 0xbd674682, + 0xbd310016, + 0xb961e269, + 0x3ba32495, + 0x3a7680d5, + 0x38b3173c, + 0x35a9deea, + 0x375c3f2a, + 0x00000000, + // coefficients of degree 6 + 0x00000000, + 0xd8995ed1, + 0x558285ea, + 0x55b2cd69, + 0x53028625, + 0x4bc9991f, + 0x5082898a, + 0x4b4999b3, + 0x4e02c07c, + 0x4ac99764, + 0x4b72c822, + 0xca40c0e1, + 0x489413e4, + 0x49b12224, + 0x46134c4e, + 0xc60c2d57, + 0x43c83910, + 0xc3c872d1, + 0xc186bc9e, + 0x42325bc3, + 0xbf2ffa4a, + 0x3d9a203c, + 0xbc545a43, + 0xbae08fee, + 0x3c80225d, + 0x3b1fd1df, + 0xba36b9d1, + 0xb91de544, + 0xb71f100f, + 0xb408e2ed, + 0xb685fec8, + 0x00000000, + }; + auto pol = _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 6), indices, _mm512_load_ps(tanh_pol_table + 32 * 6 + 16)); + + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 5), indices, _mm512_load_ps(tanh_pol_table + 32 * 5 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 4), indices, _mm512_load_ps(tanh_pol_table + 32 * 4 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 3), indices, _mm512_load_ps(tanh_pol_table + 32 * 3 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 2), indices, _mm512_load_ps(tanh_pol_table + 32 * 2 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 1), indices, _mm512_load_ps(tanh_pol_table + 32 * 1 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 0), indices, _mm512_load_ps(tanh_pol_table + 32 * 0 + 16))); + + // we restore src with cleared sign, and keep sign + auto sign = _mm512_and_ps(y, (__m512)_mm512_set1_epi32(0x80000000)); + + // Now we blend the results + // [saturation_ubound; +inf[ : we return +/- 1 + auto dst = (__m512)_mm512_set1_epi32(0x3f800000); + // [linear_ubound; saturation_lbound] : we return +/- P(x) + auto mask = (__m512)_mm512_set1_epi32(0x41102cb3); + auto mask16 = _mm512_cmp_ps_mask(mask, y_positive, _CMP_GT_OS); + dst = _mm512_mask_blend_ps(mask16, dst, pol); + // [0; linear_ubound] : we return x + mask = (__m512)_mm512_set1_epi32(0x39ddb3d7); + mask16 = _mm512_cmp_ps_mask(mask, y_positive, _CMP_GT_OS); + dst = _mm512_mask_blend_ps(mask16, dst, y_positive); + + // We reapply the sign and return + dst = _mm512_xor_ps(dst, sign); + + // compute 0.5 * x * (1 + tanh(G(x))) + dst = _mm512_add_ps(dst, (__m512)_mm512_set1_epi32(0x3f800000)); + dst = _mm512_mul_ps(dst, (__m512)_mm512_set1_epi32(0x3f000000)); + dst = _mm512_mul_ps(dst, x); + return dst; + } +} \ No newline at end of file diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index 3291a03..0bb22ce 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -7,6 +7,8 @@ #include "common/bf16.hpp" #include "common/tensor2d.hpp" #include "utility_kernel_amx.hpp" +#include "gelu_kernel_avx512.hpp" +#include "llm_fc.hpp" #ifdef _WIN32 #include @@ -18,6 +20,8 @@ #include "numa.h" #endif +using namespace llmdnn; + namespace amx_kernel { namespace functional { @@ -818,57 +822,6 @@ namespace functional { _mm512_storeu_epi32(dst + 15*16, rf); } - // gelu_erf_minimax_approx_compute_vector_fwd in oneDNN - // x*0.5*(1+erf(x/sqrt(2))) = x*0.5*(1 + x*Polynomial(x^2)) - inline __m512 gelu_erf_minmax_approx(__m512 & x) { - auto x2 = _mm512_mul_ps(x, x); // x^2 - - auto x_positive = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(x), _mm512_set1_epi32(0x7FFFFFFF))); // clear sign mask - auto x_half = _mm512_mul_ps(x, _mm512_set1_ps(0.5f)); - - auto poly = _mm512_castsi512_ps(_mm512_set1_epi32(0x1f1c83fd)); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa3198977))); // poly * x^2 + xxx - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x268a7927))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa998c963))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x2c67ddb2))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xaf013b2c))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x315d4a4f))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb3969b11))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x35a776e9))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb79b0914))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3970b255))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbb1b7399))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3ca3621f))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbe082bc7))); - poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3f4c4228))); - - // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2) - poly = _mm512_fmadd_ps(poly, x, _mm512_set1_ps(1.0f)); - // x*0.5*(1 + x*Polynomial(x^2)) - poly = _mm512_mul_ps(poly, x_half); - - // combine: - // zone_id - // 1 -inf; -saturation_lbound : 0.0f - // 2 -saturation_lbound; -linear_ubound : x*0.5*(1 + x*Polynomial(x^2)) - // 3 -linear_ubound, linear_ubound : x*0.5 - // 4 linear_ubound : saturation_lbound : x*0.5*(1 + x*Polynomial(x^2)) - // 5 saturation_lbound: +inf : x - constexpr int neg_saturation_lbound = 0xc0a00000; - constexpr int linear_ubound = 0x33800000; - constexpr int saturation_lbound = 0x40a00000; - - auto mask_x_not_zone1 = _mm512_cmpnlt_ps_mask(x, _mm512_castsi512_ps(_mm512_set1_epi32(neg_saturation_lbound))); - x = _mm512_maskz_mov_ps(mask_x_not_zone1, x); - - auto mask_x_in_zone5 = _mm512_cmpnle_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(saturation_lbound))); - poly = _mm512_mask_mov_ps(poly, mask_x_in_zone5, x); - - auto mask_x_in_zone3 = _mm512_cmple_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(linear_ubound))); - poly = _mm512_mask_mov_ps(poly, mask_x_in_zone3, x_half); - return poly; - } - inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows) { #define FROM_B(i) ((1<<4)|(i)) static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), @@ -1279,23 +1232,7 @@ namespace PP { template<> struct is_f32i32 : std::true_type {}; - enum Steps { - NONE = 0, - DEQUANT = 1<<0, - BIAS = 1<<1, - GELU = 1<<2, - QUANT = 1<<3, - - BIAS_GELU = BIAS | GELU, - DEQUANT_BIAS_GELU = DEQUANT | BIAS_GELU, - DEQUANT_BIAS_GELU_QUANT = DEQUANT_BIAS_GELU | QUANT, - DEQUANT_BIAS_QUANT = DEQUANT | BIAS | QUANT, - DEQUANT_GELU_QUANT = DEQUANT | GELU | QUANT, - DEQUANT_QUANT = DEQUANT | QUANT, - - DEQUANT_GELU = DEQUANT | GELU, - DEQUANT_BIAS = DEQUANT | BIAS - }; + using Steps = postops_types; template struct BiasGeluStore { @@ -1403,8 +1340,12 @@ namespace PP { r1 = _mm512_add_ps(r1, bias1); } if (steps & GELU) { - r0 = functional::gelu_erf_minmax_approx(r0); - r1 = functional::gelu_erf_minmax_approx(r1); + r0 = gelu_erf_minmax_approx_avx512(r0); + r1 = gelu_erf_minmax_approx_avx512(r1); + } + if (steps & GELU_TANH) { + r0 = gelu_tanh_avx512(r0); + r1 = gelu_tanh_avx512(r1); } // quantize & store @@ -1945,7 +1886,7 @@ struct Matmul { m = M - 16; } _tile_zero(0); - if (tmmN == 1) { + if (tmmN == 1) { _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); } if (tmmN == 2) { @@ -2453,18 +2394,3 @@ struct GemAvB { }; } // namespace amx - -inline std::ostream & operator<<(std::ostream & os, const amx_kernel::PP::Steps & steps) { - os << "amx_kernel::PP::Steps::"; - if (steps == amx_kernel::PP::Steps::NONE) - os << "NONE"; - if (steps & amx_kernel::PP::Steps::DEQUANT) - os << "_DEQUANT"; - if (steps & amx_kernel::PP::Steps::BIAS) - os << "_BIAS"; - if (steps & amx_kernel::PP::Steps::GELU) - os << "_GELU"; - if (steps & amx_kernel::PP::Steps::QUANT) - os << "_QUANT"; - return os; -} diff --git a/tests/src/test_common.hpp b/tests/src/test_common.hpp index 1071158..f9acf1e 100644 --- a/tests/src/test_common.hpp +++ b/tests/src/test_common.hpp @@ -212,6 +212,8 @@ inline std::ostream & operator<<(std::ostream & os, const llmdnn::postops_types os << "_BIAS"; if (steps & llmdnn::GELU) os << "_GELU"; + if (steps & llmdnn::GELU_TANH) + os << "_GELU_TANH"; if (steps & llmdnn::QUANT) os << "_QUANT"; return os; diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp index 308bc02..a27bbbe 100644 --- a/tests/src/test_fc_kernel_amx.cpp +++ b/tests/src/test_fc_kernel_amx.cpp @@ -46,6 +46,7 @@ class FCKernelTest : public TestWithParam { result << "A_" << dtype_to_str(dt_a) << "_B_" << dtype_to_str(dt_b) << "_C_" << dtype_to_str(dt_c) << (is_transpose ? "_transpose" : "") << "_postops_" << postops_type << "_M_" << M << "_N_" << N << "_K_" << K; + printf("%s\n", result.str().c_str()); return result.str(); } @@ -116,6 +117,11 @@ class FCKernelTest : public TestWithParam { return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); }; } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); float thresh = 0.0001f; @@ -165,22 +171,32 @@ const std::vector types = { { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_QUANT }, { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_GELU_QUANT }, { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_GELU_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_GELU_TANH_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT }, { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS }, { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_GELU }, { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_GELU_TANH }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU_TANH }, { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT }, { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS }, { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_GELU }, { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_GELU_TANH }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU_TANH }, { dnnl_bf16, dnnl_bf16, dnnl_bf16, NONE }, { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS }, { dnnl_bf16, dnnl_bf16, dnnl_bf16, GELU }, { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS_GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, GELU_TANH }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS_GELU_TANH }, { dnnl_bf16, dnnl_bf16, dnnl_f32, NONE }, { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS }, { dnnl_bf16, dnnl_bf16, dnnl_f32, GELU }, { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS_GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, GELU_TANH }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS_GELU_TANH }, // TODO: support weight compression // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT }, // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_BIAS }, diff --git a/tests/src/test_gelu_kernel_avx512.cpp b/tests/src/test_gelu_kernel_avx512.cpp new file mode 100644 index 0000000..70a9a27 --- /dev/null +++ b/tests/src/test_gelu_kernel_avx512.cpp @@ -0,0 +1,97 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "gelu_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using GeluTestParamSet = std::tuple< + std::string // data type + >; + +class GeluTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + std::string types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << types; + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + void gelu_ref(float* s, float* d, int n) { + if (_types == "tanh") { + for (int i = 0; i < n; i++) { + auto x = s[i]; + d[i] = 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + } + } else { + for (int i = 0; i < n; i++) { + auto x = s[i]; + d[i] = 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); + } + } + } + + void test(float thresh) { + tensor2D src(1, 16, true); + tensor2D dst(1, 16, true); + tensor2D ref(1, 16, true); + for (int i = 0; i < 16; i++) { + src[i] = std::sqrt(i) - 2; + } + __m512 s = _mm512_loadu_ps(src.data); + __m512 d; + if (_types == "tanh") + d = gelu_tanh_avx512(s); + else + d = gelu_erf_minmax_approx_avx512(s); + _mm512_storeu_ps(dst.data, d); + gelu_ref(src.data, ref.data, 16); + for (int i = 0; i < src.dims[1]; i++) { + float r = ref[i]; + float c = dst[i]; + if (std::abs(r - c) > thresh) { + FAIL() << " cur is not equal, pos: " << i << " opt: " << c << " ref: " << r; + } + } + } + + std::string _types; +}; + +TEST_P(GeluTest, Gelu) { + test(0.01f); +} + +const std::vector types = { + "tanh", "erf" +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Gelu, GeluTest, + ::testing::Combine(ValuesIn(types)), + GeluTest::getTestCaseName); From eb427a61e3f8f4be007d94132cfd58c5c333484b Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 6 Jul 2023 18:41:18 +0800 Subject: [PATCH 21/54] add git commit id to package --- CMakeLists.txt | 12 ++++++++++++ src/CMakeLists.txt | 2 +- tests/src/test_fc_kernel_amx.cpp | 1 - 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f99a3e7..7d2217c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,3 +51,15 @@ add_subdirectory(src) if (CPU_EXTENSIONS_BUILD_TESTS) add_subdirectory(tests) endif() + +# Get the latest commit hash +execute_process( + COMMAND git rev-parse HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + OUTPUT_VARIABLE GIT_HASH + OUTPUT_STRIP_TRAILING_WHITESPACE + ) +file(WRITE ${CMAKE_BINARY_DIR}/git-state.txt ${GIT_HASH}) +install(FILES + ${CMAKE_BINARY_DIR}/git-state.txt + DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2369c0b..93c0e79 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,7 +14,7 @@ target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} PUBLIC $ $/${CMAKE_INSTALL_INCLUDEDIR}>) -set(CMAKE_DST ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) +set(CMAKE_DST lib/cmake/${PROJECT_NAME}) # header files include(GNUInstallDirs) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../include/ diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp index a27bbbe..0046541 100644 --- a/tests/src/test_fc_kernel_amx.cpp +++ b/tests/src/test_fc_kernel_amx.cpp @@ -46,7 +46,6 @@ class FCKernelTest : public TestWithParam { result << "A_" << dtype_to_str(dt_a) << "_B_" << dtype_to_str(dt_b) << "_C_" << dtype_to_str(dt_c) << (is_transpose ? "_transpose" : "") << "_postops_" << postops_type << "_M_" << M << "_N_" << N << "_K_" << K; - printf("%s\n", result.str().c_str()); return result.str(); } From 16ea5655dcddf449064c7e9146923a1b7c07c106 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 6 Jul 2023 20:17:57 +0800 Subject: [PATCH 22/54] add rotary avx2 kernel --- src/rotary_kernel_avx2.hpp | 108 +++++++++++++++++++++++++ src/utility_kernel_avx2.hpp | 59 ++++++++++++++ tests/src/test_rotary_kernel_avx2.cpp | 109 ++++++++++++++++++++++++++ 3 files changed, 276 insertions(+) create mode 100644 src/rotary_kernel_avx2.hpp create mode 100644 src/utility_kernel_avx2.hpp create mode 100644 tests/src/test_rotary_kernel_avx2.cpp diff --git a/src/rotary_kernel_avx2.hpp b/src/rotary_kernel_avx2.hpp new file mode 100644 index 0000000..8602a04 --- /dev/null +++ b/src/rotary_kernel_avx2.hpp @@ -0,0 +1,108 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx2.hpp" + +namespace llmdnn { + inline void rotary_avx2(size_t N, float* cos, float* sin, float* q_src, float* k_src, float* q_dst, float* k_dst) { + auto half = N / 2; + // for (size_t i = 0; i < half; i++) { + // q_dst[i] = q_src[i] * cos[i] - q_src[i + half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] - k_src[i + half] * sin[i]; + // } + // for (size_t i = half; i < N; i++) { + // q_dst[i] = q_src[i] * cos[i] + q_src[i - half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] + k_src[i - half] * sin[i]; + // } + size_t tail = half % 8; + auto x_mask = get_mask(tail); + size_t i; + for (i = 0; i < half - tail; i += 8) { + auto q_f = _mm256_loadu_ps(q_src + i + half); + auto k_f = _mm256_loadu_ps(k_src + i + half); + auto cos_f = _mm256_loadu_ps(cos + i); + auto sin_f = _mm256_loadu_ps(sin + i); + auto q_dst_f = _mm256_mul_ps(q_f, sin_f); + auto k_dst_f = _mm256_mul_ps(k_f, sin_f); + + q_f = _mm256_loadu_ps(q_src + i); + k_f = _mm256_loadu_ps(k_src + i); + + q_dst_f = _mm256_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm256_fmsub_ps(k_f, cos_f, k_dst_f); + + _mm256_storeu_ps(q_dst + i, q_dst_f); + _mm256_storeu_ps(k_dst + i, k_dst_f); + } + if (tail) { + auto q_f = _mm256_maskload_ps(q_src + i + half, x_mask); + auto k_f = _mm256_maskload_ps(k_src + i + half, x_mask); + auto cos_f = _mm256_maskload_ps(cos + i, x_mask); + auto sin_f = _mm256_maskload_ps(sin + i, x_mask); + auto q_dst_f = _mm256_mul_ps(q_f, sin_f); + auto k_dst_f = _mm256_mul_ps(k_f, sin_f); + + q_f = _mm256_maskload_ps(q_src + i, x_mask); + k_f = _mm256_maskload_ps(k_src + i, x_mask); + + q_dst_f = _mm256_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm256_fmsub_ps(k_f, cos_f, k_dst_f); + + _mm256_maskstore_ps(q_dst + i, x_mask, q_dst_f); + _mm256_maskstore_ps(k_dst + i, x_mask, k_dst_f); + } + // second half + q_src += half; + k_src += half; + cos += half; + sin += half; + q_dst += half; + k_dst += half; + for (i = 0; i < half - tail; i += 8) { + auto q_f = _mm256_loadu_ps(q_src + i - half); + auto k_f = _mm256_loadu_ps(k_src + i - half); + auto cos_f = _mm256_loadu_ps(cos + i); + auto sin_f = _mm256_loadu_ps(sin + i); + auto q_dst_f = _mm256_mul_ps(q_f, sin_f); + auto k_dst_f = _mm256_mul_ps(k_f, sin_f); + + q_f = _mm256_loadu_ps(q_src + i); + k_f = _mm256_loadu_ps(k_src + i); + + q_dst_f = _mm256_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm256_fmadd_ps(k_f, cos_f, k_dst_f); + + _mm256_storeu_ps(q_dst + i, q_dst_f); + _mm256_storeu_ps(k_dst + i, k_dst_f); + } + if (tail) { + auto q_f = _mm256_maskload_ps(q_src + i - half, x_mask); + auto k_f = _mm256_maskload_ps(k_src + i - half, x_mask); + auto cos_f = _mm256_maskload_ps(cos + i, x_mask); + auto sin_f = _mm256_maskload_ps(sin + i, x_mask); + auto q_dst_f = _mm256_mul_ps(q_f, sin_f); + auto k_dst_f = _mm256_mul_ps(k_f, sin_f); + + q_f = _mm256_maskload_ps(q_src + i, x_mask); + k_f = _mm256_maskload_ps(k_src + i, x_mask); + + q_dst_f = _mm256_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm256_fmadd_ps(k_f, cos_f, k_dst_f); + + _mm256_maskstore_ps(q_dst + i, x_mask, q_dst_f); + _mm256_maskstore_ps(k_dst + i, x_mask, k_dst_f); + } + } +} \ No newline at end of file diff --git a/src/utility_kernel_avx2.hpp b/src/utility_kernel_avx2.hpp new file mode 100644 index 0000000..3c1775e --- /dev/null +++ b/src/utility_kernel_avx2.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "common/bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +namespace llmdnn { + +#pragma once + +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +inline __m256i get_mask(int N7) { + static __m256i mask[] = { + _mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0, 0), + _mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0,-1), + _mm256_set_epi32( 0, 0, 0, 0, 0, 0,-1,-1), + _mm256_set_epi32( 0, 0, 0, 0, 0,-1,-1,-1), + _mm256_set_epi32( 0, 0, 0, 0,-1,-1,-1,-1), + _mm256_set_epi32( 0, 0, 0,-1,-1,-1,-1,-1), + _mm256_set_epi32( 0, 0,-1,-1,-1,-1,-1,-1), + _mm256_set_epi32( 0,-1,-1,-1,-1,-1,-1,-1), + _mm256_set_epi32(-1,-1,-1,-1,-1,-1,-1,-1), + }; + return _mm256_loadu_si256(&mask[N7]); +} + +// https://stackoverflow.com/questions/23189488/horizontal-sum-of-32-bit-floats-in-256-bit-avx-vector +static inline float _mm256_reduce_add_ps(__m256 x) { + /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + return _mm_cvtss_f32(x32); +} + +} \ No newline at end of file diff --git a/tests/src/test_rotary_kernel_avx2.cpp b/tests/src/test_rotary_kernel_avx2.cpp new file mode 100644 index 0000000..0ab1fca --- /dev/null +++ b/tests/src/test_rotary_kernel_avx2.cpp @@ -0,0 +1,109 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "rotary_kernel_avx2.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using RotaryTestAVX2ParamSet = std::tuple< + data_type_t // data type + >; + +class RotaryTestAVX2 : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void rotary_emb(size_t rotaryNdims, float* cos, float* sin, T* q_src, T* k_src, T* q_dst, T* k_dst) { + auto halfRotaryNdims = rotaryNdims / 2; + for (size_t i = 0; i < halfRotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] - q_src[i + halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] - k_src[i + halfRotaryNdims] * sin[i]; + } + for (size_t i = halfRotaryNdims; i < rotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] + q_src[i - halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] + k_src[i - halfRotaryNdims] * sin[i]; + } + } + + template + void test(float thresh) { + for (int n = 6; n < 129; n += 2) { + tensor2D q_src(1, n, true); + tensor2D k_src(1, n, true); + tensor2D q_dst(1, n, true); + tensor2D k_dst(1, n, true); + tensor2D q_dst_ref(1, n, true); + tensor2D k_dst_ref(1, n, true); + tensor2D cos(1, n, true); + tensor2D sin(1, n, true); + for (int i = 0; i < n; i++) { + q_src[i] = i % 19 - 10; + k_src[i] = i % 19 - 9; + cos[i] = i % 19 - 8; + sin[i] = i % 19 - 7; + } + rotary_emb(n, cos.data, sin.data, q_src.data, k_src.data, q_dst_ref.data, k_dst_ref.data); + rotary_avx2(n, cos.data, sin.data, q_src.data, k_src.data, q_dst.data, k_dst.data); + for (int i = 0; i < n; i++) { + float q = q_dst[i]; + float q_ref = q_dst_ref[i]; + float k = k_dst[i]; + float k_ref = k_dst_ref[i]; + if (std::abs(q - q_ref) > thresh) { + FAIL() << " q is not equal, N: " << n << " pos: " << i << " opt: " << q << " ref: " << q_ref; + } + if (std::abs(k - k_ref) > thresh) { + FAIL() << " k is not equal, N: " << n << " pos: " << i << " opt: " << k << " ref: " << k_ref; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(RotaryTestAVX2, rotary) { + if (_types == dnnl_s8) { + ASSERT_TRUE(false); + } else { + test(0.01f); + } +} + +const std::vector types = { + dnnl_f32 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Rotary, RotaryTestAVX2, + ::testing::Combine(ValuesIn(types)), + RotaryTestAVX2::getTestCaseName); From 73f335da679d25ecbaa3ccfa3bd6ca29fe6b38c8 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 7 Jul 2023 01:04:59 +0800 Subject: [PATCH 23/54] add bloom mha support --- include/llm_mha_gpt.hpp | 2 + src/mha_gpt_amx.cpp | 73 +++++++--- src/utility_kernel_avx512.hpp | 22 +++ tests/script/ext/attn_gpt.cpp | 12 +- tests/script/ext/emb_gpt.cpp | 6 +- tests/script/ext/mha_gpt.cpp | 21 +-- tests/script/test_mha_bloom.py | 237 +++++++++++++++++++++++++++++++ tests/script/test_mha_chatglm.py | 2 +- tests/script/test_mha_gpt.py | 2 +- 9 files changed, 341 insertions(+), 36 deletions(-) create mode 100644 tests/script/test_mha_bloom.py diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp index befa348..83ea8be 100644 --- a/include/llm_mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -50,6 +50,7 @@ class mha_gpt { // supported (qkv, dst): (bf16, bf16), (s8, s8) data_type_t qkv_precision; data_type_t dst_precision; + bool is_bloom; // for bloom mha }; struct exec_param { size_t batch; @@ -66,6 +67,7 @@ class mha_gpt { // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true uint8_t* attn_output; // output, compact, shape: [batch, query_seq_len, num_heads * head_size] size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer + float* alibi; // only is_bloom is true will use // expected quant schema: // q,k,v use per tensor quant, attn_output may use per tensor/channel quant float q_dequant; diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index fd9971c..5667216 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -64,7 +64,7 @@ bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { if (_create_param.qkv_precision == dnnl_s8) { qKtrGemm_i8xi8.resize(numThreads); for (size_t i = 0; i < numThreads; i++) { - qKtrGemm_i8xi8[i] = std::make_shared>(false, true); + qKtrGemm_i8xi8[i] = std::make_shared>(false, !param.is_bloom); } qKVGemm_u8xi8.resize(numThreads); for (size_t i = 0; i < numThreads; i++) { @@ -85,7 +85,7 @@ bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { } qKtrGemm_BF16xBF16.resize(numThreads); for (size_t i = 0; i < numThreads; i++) { - qKtrGemm_BF16xBF16[i] = std::make_shared>(false, true); + qKtrGemm_BF16xBF16[i] = std::make_shared>(false, !param.is_bloom); } qKVGemm_BF16xBF16.resize(numThreads); for (size_t i = 0; i < numThreads; i++) { @@ -109,16 +109,17 @@ bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { uint8_t* pQIn0 = param.q; - auto& pKIn0 = param.k; - auto& attn_masks = param.attention_mask; - auto& pVIn0 = param.v; + auto pKIn0 = param.k; + auto attn_masks = param.attention_mask; + auto pVIn0 = param.v; uint8_t* pout = param.attn_output; + auto alibi = param.alibi; auto outPrcSize = get_precision_size(_create_param.qkv_precision); auto& gemAvB_ops = gemAvB_BF16xBF16; auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; auto& qKVGemm_ops = qKVGemm_BF16xBF16; - bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 32 && _create_param.head_size <= 32 * 6; + bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 32 && _create_param.head_size <= 32 * 6 && !_create_param.is_bloom; size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; size_t head_stride_in_attn = _create_param.head_size; @@ -185,10 +186,15 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); amx_kernel::PP::BiasGeluStore pp(matQK); - (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); + if (!_create_param.is_bloom) { + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); + } else { + tensor2D matK(_create_param.head_size, param.key_seq_len, reinterpret_cast(pKIn0_aux), param.key_seq_len * sizeof(ov::bfloat16)); + (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); + } prev_k = pKIn0_aux; auto pMatMul0Out = bufferMatMul0Out_local; @@ -198,7 +204,14 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { for (int m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + if (!_create_param.is_bloom) + mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + else + // alibi shape: [batch, head_num, 1, key_seq_len] + mul_add2_f32_avx512(src, src, _create_param.normal_factor, + alibi + i0 * _create_param.num_heads * param.key_seq_len + i1 * param.key_seq_len, + pAddIn1_aux + (m + seq_start) * param.key_seq_len, + param.key_seq_len); softmax_avx512(dst, src, param.key_seq_len, nullptr); } } else { @@ -208,7 +221,13 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { for (int m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); + if (!_create_param.is_bloom) + mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); + else + mul_add2_f32_avx512(src, src, _create_param.normal_factor, + alibi + i0 * _create_param.num_heads * param.key_seq_len + i1 * param.key_seq_len, + pAddIn1_aux, + valid_softmax_items); softmax_avx512(dst, src, valid_softmax_items, nullptr); // attn_scores = torch.where(causal_mask, attn_scores, mask_value) if (param.key_seq_len > valid_softmax_items) { @@ -237,16 +256,17 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { uint8_t* pQIn0 = param.q; - auto& pKIn0 = param.k; - auto& attn_masks = param.attention_mask; - auto& pVIn0 = param.v; + auto pKIn0 = param.k; + auto attn_masks = param.attention_mask; + auto pVIn0 = param.v; uint8_t* pout = param.attn_output; + auto alibi = param.alibi; auto outPrcSize = get_precision_size(_create_param.dst_precision); auto& gemAvB_ops = gemAvB_i8xi8; auto& qKtrGemm_ops = qKtrGemm_i8xi8; auto& qKVGemm_ops = qKVGemm_u8xi8; - bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 64 && _create_param.head_size <= 64 * 6; + bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 64 && _create_param.head_size <= 64 * 6 && !_create_param.is_bloom; // dequant param auto mul_scales = _create_param.normal_factor * param.q_dequant * param.k_dequant; // prepare for per channel @@ -324,10 +344,15 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); amx_kernel::PP::BiasGeluStore pp(matQK); - (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); + if (!_create_param.is_bloom) { + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); + } else { + tensor2D matK(_create_param.head_size, param.key_seq_len, reinterpret_cast(pKIn0_aux), param.key_seq_len * sizeof(int8_t)); + (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); + } prev_k = pKIn0_aux; auto pMatMul0Out = bufferMatMul0Out_local; @@ -338,6 +363,14 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + if (!_create_param.is_bloom) + mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + else + // alibi shape: [batch, head_num, 1, key_seq_len] + mul_add2_f32_avx512(src, src, mul_scales, + alibi + i0 * _create_param.num_heads * param.key_seq_len + i1 * param.key_seq_len, + pAddIn1_aux + (m + seq_start) * param.key_seq_len, + param.key_seq_len); softmax_avx512(dst, src, param.key_seq_len, param.qk_quant); } } else { @@ -347,7 +380,13 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { for (int m = 0; m < seq_cout; m++) { float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); - mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); + if (!_create_param.is_bloom) + mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); + else + mul_add2_f32_avx512(src, src, mul_scales, + alibi + i0 * _create_param.num_heads * param.key_seq_len + i1 * param.key_seq_len, + pAddIn1_aux, + valid_softmax_items); softmax_avx512(dst, src, valid_softmax_items, param.qk_quant); // attn_scores = torch.where(causal_mask, attn_scores, mask_value) if (param.key_seq_len > valid_softmax_items) { diff --git a/src/utility_kernel_avx512.hpp b/src/utility_kernel_avx512.hpp index e47f5c1..109888c 100644 --- a/src/utility_kernel_avx512.hpp +++ b/src/utility_kernel_avx512.hpp @@ -111,4 +111,26 @@ inline void mul_add_f32_avx512(float* dst, float* src, float mul, float* add, in } } +inline void mul_add2_f32_avx512(float* dst, float* src, float mul, float* add1, float* add2, int ele_num) { + auto mul_f = _mm512_set1_ps(mul); + int i; + auto tail = ele_num % 16; + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + for (i = 0; i < ele_num - tail; i += 16) { + auto a_f = _mm512_loadu_ps(src); + auto add1_f = _mm512_loadu_ps(add1); + auto add2_f = _mm512_loadu_ps(add2); + _mm512_storeu_ps(dst, _mm512_add_ps(_mm512_fmadd_ps(a_f, mul_f, add1_f), add2_f)); + src += 16; + dst += 16; + add1 += 16; + add2 += 16; + } + if (tail) { + auto a_f = _mm512_maskz_loadu_ps(msk, src); + auto add1_f = _mm512_maskz_loadu_ps(msk, add1); + auto add2_f = _mm512_maskz_loadu_ps(msk, add2); + _mm512_mask_storeu_ps(dst, msk, _mm512_add_ps(_mm512_fmadd_ps(a_f, mul_f, add1_f), add2_f)); + } +} } \ No newline at end of file diff --git a/tests/script/ext/attn_gpt.cpp b/tests/script/ext/attn_gpt.cpp index a87a3a1..8b7511e 100644 --- a/tests/script/ext/attn_gpt.cpp +++ b/tests/script/ext/attn_gpt.cpp @@ -70,7 +70,7 @@ attn_gpt::attn_gpt(): _emb_gpt(std::make_shared()), bool attn_gpt::create(const attn_gpt::create_param& param) { _create_param = param; - llmdnn::emb_gpt::create_param emb_param; + llmdnn::emb_gpt::create_param emb_param = {0}; emb_param.num_heads = param.num_heads; emb_param.head_size = param.head_size; emb_param.head_size_aligned = param.head_size_aligned; @@ -82,7 +82,7 @@ bool attn_gpt::create(const attn_gpt::create_param& param) { if (!_emb_gpt->create(emb_param)) return false; - llmdnn::mha_gpt::create_param mha_param; + llmdnn::mha_gpt::create_param mha_param = {0}; mha_param.num_heads = param.num_heads; mha_param.head_size = param.head_size; mha_param.head_size_aligned = param.head_size_aligned; @@ -104,7 +104,7 @@ void attn_gpt::exec(const attn_gpt::exec_param& param) { _query_cached_batch = param.batch; } - llmdnn::emb_gpt::exec_param emb_param; + llmdnn::emb_gpt::exec_param emb_param = {0}; emb_param.batch = param.batch; emb_param.query_seq_len = param.query_seq_len; emb_param.past_seq_len = param.past_seq_len; @@ -125,7 +125,7 @@ void attn_gpt::exec(const attn_gpt::exec_param& param) { emb_param.sin = param.sin; _emb_gpt->exec(emb_param); - llmdnn::mha_gpt::exec_param mha_param; + llmdnn::mha_gpt::exec_param mha_param = {0}; mha_param.batch = param.batch; mha_param.query_seq_len = param.query_seq_len; mha_param.key_seq_len = param.query_seq_len + param.past_seq_len; @@ -152,7 +152,7 @@ void regclass_attn_gpt(pybind11::module m) { const size_t max_seq_len, const size_t rotary_dims, bool use_position2d) { - attn_gpt::create_param param; + attn_gpt::create_param param = {0}; param.num_heads = num_heads; param.head_size = head_size; param.head_size_aligned = head_size_aligned; @@ -204,7 +204,7 @@ void regclass_attn_gpt(pybind11::module m) { AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && query_seq_len <= layer_past_key_dst.size(2)); - attn_gpt::exec_param param; + attn_gpt::exec_param param = {0}; param.batch = batch; param.query_seq_len = query_seq_len; param.past_seq_len = past_seq_len; diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp index cbf4aa8..669ac44 100644 --- a/tests/script/ext/emb_gpt.cpp +++ b/tests/script/ext/emb_gpt.cpp @@ -24,7 +24,7 @@ void regclass_emb_gpt(pybind11::module m) { const std::string dst_precision_name, const size_t rotary_dims, bool use_position2d) { - llmdnn::emb_gpt::create_param param; + llmdnn::emb_gpt::create_param param = {0}; param.num_heads = num_heads; param.head_size = head_size; param.head_size_aligned = head_size_aligned; @@ -75,7 +75,7 @@ void regclass_emb_gpt(pybind11::module m) { AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && query_seq_len <= layer_past_key_dst.size(2)); - llmdnn::emb_gpt::exec_param param; + llmdnn::emb_gpt::exec_param param = {0}; param.batch = batch; param.query_seq_len = query_seq_len; param.past_seq_len = past_seq_len; @@ -136,7 +136,7 @@ void regclass_emb_gpt(pybind11::module m) { AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && query_seq_len <= layer_past_key_dst.size(2)); - llmdnn::emb_gpt::exec_param param; + llmdnn::emb_gpt::exec_param param = {0}; param.batch = batch; param.query_seq_len = query_seq_len; param.past_seq_len = past_seq_len; diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index 704a1fb..c3ee7ff 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -21,8 +21,9 @@ void regclass_mha_gpt(pybind11::module m) { const float normal_factor, const std::string qkv_precision_name, const std::string dst_precision_name, - const size_t max_seq_len) { - llmdnn::mha_gpt::create_param param; + const size_t max_seq_len, + bool is_bloom) { + llmdnn::mha_gpt::create_param param = {0}; param.num_heads = num_heads; param.head_size = head_size; param.head_size_aligned = head_size_aligned; @@ -30,6 +31,7 @@ void regclass_mha_gpt(pybind11::module m) { param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); param.max_seq_len = max_seq_len; + param.is_bloom = is_bloom; if (param.qkv_precision == llmdnn::dnnl_data_type_undef) throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); if (param.dst_precision == llmdnn::dnnl_data_type_undef) @@ -44,15 +46,16 @@ void regclass_mha_gpt(pybind11::module m) { py::arg("qkv_precision_name"), py::arg("dst_precision_name"), py::arg("max_seq_len"), + py::arg("is_bloom") = false, R"( Create mha :param num_heads: heads number. :type num_heads: int )"); - cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask, int64_t head_size, int64_t key_seq_len) { + cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& alibi, const torch::Tensor& attn_mask, int64_t head_size, int64_t key_seq_len) { // q: [batch, num_heads, query_seq_len, head_size_aligned] - // k: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // k: [batch, num_heads, head_size_aligned, max_seq_len] valid in max_seq_len: key_seq_len // v: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] // out: [batch, query_seq_len, num_heads * head_size] @@ -61,14 +64,14 @@ void regclass_mha_gpt(pybind11::module m) { auto num_heads = q.size(1); auto query_seq_len = q.size(2); auto head_size_aligned = q.size(3); - auto max_seq_len = k.size(2); + auto max_seq_len = v.size(2); auto attn_len = attn_mask.size(3); AT_ASSERT(max_seq_len == v.size(2) && batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && num_heads == k.size(1) && num_heads == v.size(1) && - head_size_aligned == k.size(3) && head_size_aligned == v.size(3)); + head_size_aligned == v.size(3)); - llmdnn::mha_gpt::exec_param param; + llmdnn::mha_gpt::exec_param param = {0}; param.batch = batch; param.query_seq_len = query_seq_len; param.key_seq_len = key_seq_len == 0 ? max_seq_len : key_seq_len; @@ -86,6 +89,7 @@ void regclass_mha_gpt(pybind11::module m) { param.k[i] = reinterpret_cast(k[i].data_ptr()); param.v[i] = reinterpret_cast(v[i].data_ptr()); } + param.alibi = alibi.data_ptr(); self.exec(param); return out; @@ -93,6 +97,7 @@ void regclass_mha_gpt(pybind11::module m) { py::arg("q"), py::arg("k"), py::arg("v"), + py::arg("alibi"), py::arg("attn_mask"), py::arg("head_size") = 0, py::arg("key_seq_len") = 0, @@ -121,7 +126,7 @@ void regclass_mha_gpt(pybind11::module m) { num_heads == k.size(1) && num_heads == v.size(1) && head_size_aligned == k.size(3) && head_size_aligned == v.size(3)); - llmdnn::mha_gpt::exec_param param; + llmdnn::mha_gpt::exec_param param = {0}; param.batch = batch; param.query_seq_len = query_seq_len; param.key_seq_len = key_seq_len == 0 ? max_seq_len : key_seq_len; diff --git a/tests/script/test_mha_bloom.py b/tests/script/test_mha_bloom.py new file mode 100644 index 0000000..0e98148 --- /dev/null +++ b/tests/script/test_mha_bloom.py @@ -0,0 +1,237 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F + +# copy from transformers/models/bloom/modeling_bloom.py +class BloomAttention(nn.Module): + def __init__(self, head_dim:int, num_heads:int): + super().__init__() + + # self.pretraining_tp = config.pretraining_tp + # self.slow_but_exact = config.slow_but_exact + + # self.hidden_size = config.hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + # self.split_size = self.hidden_size + # self.hidden_dropout = config.hidden_dropout + + # if self.head_dim * self.num_heads != self.hidden_size: + # raise ValueError( + # f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + # f" {self.num_heads})." + # ) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(head_dim) + self.beta = 1.0 + + # self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) + # self.dense = nn.Linear(self.hidden_size, self.hidden_size) + # self.attention_dropout = nn.Dropout(config.attention_dropout) + + # def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # """ + # Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + # storage as `fused_qkv` + + # Args: + # fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + # Returns: + # query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + # value: [batch_size, seq_length, num_heads, head_dim] + # """ + # batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + # fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + # return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + + def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + """ + Merge heads together over the last dimenstion + + Args: + x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim] + + Returns: + torch.tensor: [batch_size, seq_length, num_heads * head_dim] + """ + # What we want to achieve is: + # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim + batch_size_and_num_heads, seq_length, _ = x.shape + batch_size = batch_size_and_num_heads // self.num_heads + + # First view to decompose the batch size + # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim + x = x.view(batch_size, self.num_heads, seq_length, self.head_dim) + + # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim + x = x.permute(0, 2, 1, 3) + + # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim + return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim) + + def forward( + self, + query_layer: torch.Tensor, # [batch * head_num, q_len, head_size] + key_layer: torch.Tensor, # [batch * head_num, head_size, q_len+kv_len] + value_layer: torch.Tensor, # [batch * head_num, q_len+kv_len, head_size] + alibi: torch.Tensor, # [batch * head_num, 1, q_len+kv_len] + attention_mask: torch.Tensor, # [batch * head_num, q_len, q_len+kv_len] + ): + # fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # # 3 x [batch_size, seq_length, num_heads, head_dim] + # (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, _, q_length, _ = query_layer.shape + + query_layer = query_layer.reshape(-1, q_length, self.head_dim) + key_layer = key_layer.reshape(-1, key_layer.size(2), key_layer.size(3)) + value_layer = value_layer.reshape(-1, value_layer.size(2), value_layer.size(3)) + # if layer_past is not None: + # past_key, past_value = layer_past + # # concatenate along seq_length dimension: + # # - key: [batch_size * self.num_heads, head_dim, kv_length] + # # - value: [batch_size * self.num_heads, kv_length, head_dim] + # key_layer = torch.cat((past_key, key_layer), dim=2) + # value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + + # if use_cache is True: + # present = (key_layer, value_layer) + # else: + # present = None + + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + matmul_result = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + #attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) + attn_weights = attention_scores + attention_mask + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs_reshaped, value_layer) + + # change view [batch_size, num_heads, q_length, head_dim] + context_layer = self._merge_heads(context_layer) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + # if self.pretraining_tp > 1 and self.slow_but_exact: + # slices = self.hidden_size / self.pretraining_tp + # output_tensor = torch.zeros_like(context_layer) + # for i in range(self.pretraining_tp): + # output_tensor = output_tensor + F.linear( + # context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + # self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + # ) + # else: + # output_tensor = self.dense(context_layer) + + # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + # outputs = (output_tensor, present) + # if output_attentions: + # outputs += (attention_probs,) + + return context_layer + +class BloomAttentionExt: + def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, is_int8=False): + self.mha = ld.mha_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + head_size_aligned = head_size + normal_factor = 1.0 / math.sqrt(head_size) + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + dst_precision_name, max_seq_len, True) + + def forward(self, query, key, value, alibi, attention_mask): + return self.mha.exec(query, key, value, alibi, attention_mask) + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +def get_ref_model(): + ref_net = BloomAttention(SIZE_PER_HEAD, HEAD_NUM) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_bloom(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, head_size, key_seq_len] + # v: [batch, num_heads, value_seq_len, head_size] + # alibi: [batch, num_heads, 1, key_seq_len] + # attn: [2, 1, 1, key_seq_len] + (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, SIZE_PER_HEAD, 32]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 1, 32]).astype(np.float32), + np.zeros([2, 1, 2, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, SIZE_PER_HEAD, 200]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 1, 200]).astype(np.float32), + np.zeros([2, 1, 200, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, SIZE_PER_HEAD, 200]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 1, 200]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = BloomAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, alibi, attn_mask = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + alibi = torch.from_numpy(alibi) # to(torch.bfloat16) + alibi = alibi.view(-1, alibi.size(2), alibi.size(3)) + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + ref_output = ref_net.forward(q, k, v, alibi, attn_mask) + output = net.forward(q, k, v, alibi, attn_mask) + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_bloom() diff --git a/tests/script/test_mha_chatglm.py b/tests/script/test_mha_chatglm.py index 59179e6..dd33799 100644 --- a/tests/script/test_mha_chatglm.py +++ b/tests/script/test_mha_chatglm.py @@ -122,7 +122,7 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi dst_precision_name, max_seq_len) def forward(self, query, key, value, attention_mask, head_size, key_seq_len): - return self.mha.exec(query, key, value, attention_mask, head_size, key_seq_len) + return self.mha.exec(query, key, value, torch.tensor(1.0), attention_mask, head_size, key_seq_len) def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py index ac99f99..8419c8b 100644 --- a/tests/script/test_mha_gpt.py +++ b/tests/script/test_mha_gpt.py @@ -119,7 +119,7 @@ def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, is dst_precision_name, max_seq_len) def forward(self, query, key, value, attention_mask): - return self.mha.exec(query, key, value, attention_mask) + return self.mha.exec(query, key, value, torch.tensor(1.0), attention_mask) def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant From 48e01f405d2becab4afdeb0596fe308186c08e12 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 10 Jul 2023 18:19:41 +0800 Subject: [PATCH 24/54] use target_compile_options to change cxx flags --- CMakeLists.txt | 12 +++--------- src/CMakeLists.txt | 1 + tests/CMakeLists.txt | 3 ++- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7d2217c..831a44f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,27 +19,21 @@ if(MSVC) if(MSVC_VERSION VERSION_LESS 1928) message(FATAL_ERROR "Insufficient msvc compiler version, current ${MSVC_VERSION}, minimum 1928.") endif() - # Force to always compile with W4 - if(CMAKE_CXX_FLAGS MATCHES "/W[0-4]") - string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") - else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") - endif() elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.2") message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 11.2.") endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids -flax-vector-conversions") + set(EXTRA_CXX_FLAGS -march=sapphirerapids -flax-vector-conversions) elseif(OV_COMPILER_IS_CLANG) if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "12") message(FATAL_ERROR "Insufficient clang compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 12.") endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids -flax-vector-conversions") + set(EXTRA_CXX_FLAGS -march=sapphirerapids -flax-vector-conversions) elseif(CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "2023.0") message(FATAL_ERROR "Insufficient intel compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 2023.0.") endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") + set(EXTRA_CXX_FLAGS -march=sapphirerapids) endif() if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 93c0e79..e9cc194 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,6 +13,7 @@ set_target_properties(${PROJECT_NAME} PROPERTIES target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} PUBLIC $ $/${CMAKE_INSTALL_INCLUDEDIR}>) +target_compile_options(${PROJECT_NAME} PRIVATE ${EXTRA_CXX_FLAGS}) set(CMAKE_DST lib/cmake/${PROJECT_NAME}) # header files diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8ed51d5..e43a428 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -30,4 +30,5 @@ find_package(OpenMP REQUIRED) add_executable(cpu_extensions_tests ${TEST_SOURCE_FILES}) target_include_directories(cpu_extensions_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) -target_link_libraries(cpu_extensions_tests cpu_extensions gtest_main stdc++ OpenMP::OpenMP_CXX) \ No newline at end of file +target_link_libraries(cpu_extensions_tests cpu_extensions gtest_main stdc++ OpenMP::OpenMP_CXX) +target_compile_options(cpu_extensions_tests PRIVATE ${EXTRA_CXX_FLAGS}) From f906cdd33b82307eda961b16d1800fc7efbd67c6 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Wed, 12 Jul 2023 05:56:08 +0800 Subject: [PATCH 25/54] mha uses tensor to support different strides --- include/llm_mha_gpt.hpp | 17 +-- include/llm_plain_tensor.hpp | 224 +++++++++++++++++++++++++++++++ src/CMakeLists.txt | 1 - src/common/llm_plain_tensor.cpp | 169 +++++++++++++++++++++++ src/mha_gpt_amx.cpp | 119 +++++++++------- tests/script/ext/attn_gpt.cpp | 12 +- tests/script/ext/mha_gpt.cpp | 41 +++--- tests/script/ext/setup.py | 2 +- tests/script/test_mha_bloom.py | 2 +- tests/script/test_mha_chatglm.py | 4 +- tests/script/test_mha_gpt.py | 4 +- 11 files changed, 498 insertions(+), 97 deletions(-) create mode 100644 include/llm_plain_tensor.hpp create mode 100644 src/common/llm_plain_tensor.cpp diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp index 83ea8be..c905fbe 100644 --- a/include/llm_mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -8,6 +8,7 @@ #include #include #include "llm_types.hpp" +#include "llm_plain_tensor.hpp" namespace llmdnn { @@ -44,7 +45,6 @@ class mha_gpt { struct create_param { size_t num_heads; size_t head_size; - size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv size_t max_seq_len; // max seq length for computing the size of matmul tmp result float normal_factor; // supported (qkv, dst): (bf16, bf16), (s8, s8) @@ -57,17 +57,14 @@ class mha_gpt { size_t query_seq_len; size_t key_seq_len; bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. - uint8_t* q; // q buffer, compact, shape: [batch, num_heads, query_seq_len, head_size] - uint8_t** k; // k buffer, k[N] stands different batch which may be discreted - // k[0] shape: [batch, num_heads, key_seq_len, head_size] - uint8_t** v; // v buffer, v[N] stands different batch which may be discreted - // v[0] shape: [batch, num_heads, value_seq_len, head_size] - float* attention_mask; // attention mask, attention_mask[0] shape: + plain_tensor q; // q buffer, shape: [batch, num_heads, query_seq_len, head_size] + plain_tensor k; // k buffer, shape: [batch, num_heads, key_seq_len, head_size] + plain_tensor v; // v buffer, shape: [batch, num_heads, value_seq_len, head_size] + plain_tensor attention_mask; // attention mask, shape: // [batch, 1, 1, key_seq_len], when is_causal_in_attention is false // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true - uint8_t* attn_output; // output, compact, shape: [batch, query_seq_len, num_heads * head_size] - size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer - float* alibi; // only is_bloom is true will use + plain_tensor attn_output; // output, compact, shape: [batch, query_seq_len, num_heads * head_size] + plain_tensor alibi; // only is_bloom is true will use, shape: [batch, num_heads, 1, key_seq_len] // expected quant schema: // q,k,v use per tensor quant, attn_output may use per tensor/channel quant float q_dequant; diff --git a/include/llm_plain_tensor.hpp b/include/llm_plain_tensor.hpp new file mode 100644 index 0000000..bfc0faa --- /dev/null +++ b/include/llm_plain_tensor.hpp @@ -0,0 +1,224 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace llmdnn { + +#define PLAINTENSOR_RANK_MAX 8 +struct plain_tensor_base { + size_t m_strides[PLAINTENSOR_RANK_MAX]; + size_t m_dims[PLAINTENSOR_RANK_MAX]; + size_t m_rank; + + std::shared_ptr m_ptr; + size_t m_capacity = 0; + size_t m_element_size = 1; + + uint8_t* batched_ptr_buff[8]; + std::vector batched_ptr_backup; + + operator bool() { + return static_cast(m_ptr); + } + + size_t size(int i) { + assert(i < m_rank); + return m_dims[i]; + } + size_t stride(int i) { + assert(i < m_rank); + return m_strides[i]; + } +}; + +struct plain_tensor : public plain_tensor_base { + plain_tensor(); + ~plain_tensor(); + + struct tensor_index { + int start; + int end; + int step; + int count; + // select all + tensor_index() { + start = 0; + end = INT_MAX; + step = 1; + } + bool slice_with_squeeze() { + return end == INT_MIN; + } + // tensor_index(start) : select 1 element (with squeeze) + // tensor_index(start, end, step) : select a range w/o squeeze + tensor_index(int start, int end = INT_MIN, int step = 1) : start(start), end(end), step(step) {} + + void regularize(int size) { + if (start < 0) + start += size; + assert(start >= 0 && start < size); + if (end != INT_MIN) { + if (end < 0) + end += size; + if (end > size) + end = size; + assert(end >= 0 && end <= size); + count = (end - start + step - 1) / step; + } else { + count = 1; + } + } + }; + + plain_tensor index(const std::initializer_list& indices) const; + + // slice: return a sub-view (w/o ownership/refcount to original data) + plain_tensor slice(int axis, int start, int end) const; + + bool is_dense() const; + + /* + suppose current shape is [a0,a1,...,am] + and target shape is [b0,b1,...,bn] + reshape is only valid when (a0*a1*...*am) == (b0*b1*...*bn) <======= (A) + + uniform a tensor's shape into groups from last to first, the dimension is merged + into current group if the subtensor in the group is still dense after merge. + otherwise a new group is formed. + + then reshape is performed on group basis, the check (A) is performed on group bases. + which means any reshape inside the group is OK, but not across the group boundary. + + this can be done in one-loop, while group is forming, and checks are performed. + + simplified form is when whole tensor is dense + */ + plain_tensor reshape(const std::initializer_list& target_shape) const; + + plain_tensor permute(const std::initializer_list& order) const; + + template + void resize(const std::vector& new_dims, DT* data = nullptr) { + resize(new_dims, data, sizeof(DT)); + } + + void resize(const std::vector& new_dims, void* data, size_t element_size); + + template + DT* data() const { + return reinterpret_cast(m_ptr.get()); + } + + template + DT& at(const std::initializer_list& index) const { + size_t off = 0; + auto it = index.begin(); + for (auto& stride : m_strides) { + auto coordinate = (it != index.end()) ? (*it++) : 0; + off += stride * coordinate; + } + return *reinterpret_cast(reinterpret_cast(m_ptr.get()) + off); + } + + template + DT& operator()(const std::initializer_list& index) const { + return at
(index); + } + + void assert_dims(const std::initializer_list& expect_dims) const; + uint8_t** get_batched_ptrs() { + uint8_t** ret_ptrs = batched_ptr_buff; + auto batch_size = m_dims[0]; + if (batch_size > sizeof(batched_ptr_buff) / sizeof(batched_ptr_buff[0])) { + batched_ptr_backup.resize(batch_size); + ret_ptrs = &batched_ptr_backup[0]; + } + for (size_t b = 0; b < batch_size; b++) { + ret_ptrs[b] = &at({b}); + } + return ret_ptrs; + } + + template + std::string repr(int max_total_lines = 16, int lines_per_row = 1) const { + std::stringstream ss; + ss << typeid(DT).name() << " shape=["; + const char* sep = ""; + size_t sz = 1; + for (size_t i = 0; i < m_rank; i++) { + ss << sep << m_dims[i]; + sz *= m_dims[i]; + sep = ","; + } + ss << "] strides=["; + sep = ""; + for (size_t i = 0; i < m_rank; i++) { + ss << sep << m_strides[i]; + sep = ","; + } + ss << "] {"; + if (m_rank > 1) + ss << "\n"; + auto last_dim_size = m_dims[m_rank - 1]; + int row_id = 0; + int cur_row_lines_left = lines_per_row; + int cur_line_elecnt = 0; + int cur_row_elecnt = 0; + size_t i; + auto* p = reinterpret_cast(m_ptr.get()); + for (i = 0; i < sz && max_total_lines > 0; i++) { + if ((i % last_dim_size) == 0) { + ss << row_id << ":\t\t"; + row_id++; + cur_row_lines_left = lines_per_row; + } + + // display current element if we still have buget + if (cur_row_lines_left > 0) { + ss << p[i] << ","; + cur_line_elecnt++; + cur_row_elecnt++; + if ((cur_line_elecnt % 16) == 15 || (cur_row_elecnt == last_dim_size)) { + max_total_lines--; + cur_row_lines_left--; + if (cur_row_lines_left == 0) { + if (cur_row_elecnt == last_dim_size) + ss << ",\n"; + else + ss << "...\n"; + cur_row_elecnt = 0; + } else { + ss << "\n\t\t"; + } + cur_line_elecnt = 0; + } + } + } + if (i < sz) { + ss << "... ... ... ... \n"; + } + ss << "}"; + return ss.str(); + } + + template + friend std::ostream& operator<<(std::ostream& os, const plain_tensor& dt); +}; + +template +std::ostream& operator<<(std::ostream& os, const plain_tensor& dt) { + os << dt.repr(); + return os; +} + +} // namespace llmdnn diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e9cc194..90594d5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -32,7 +32,6 @@ install(TARGETS ${PROJECT_NAME} # config file install(EXPORT ${PROJECT_NAME}Targets FILE ${PROJECT_NAME}Config.cmake - NAMESPACE ${PROJECT_NAME}:: DESTINATION ${CMAKE_DST}) # version file diff --git a/src/common/llm_plain_tensor.cpp b/src/common/llm_plain_tensor.cpp new file mode 100644 index 0000000..c322719 --- /dev/null +++ b/src/common/llm_plain_tensor.cpp @@ -0,0 +1,169 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include + +#include "bf16.hpp" +#include "llm_plain_tensor.hpp" + +namespace llmdnn { + +plain_tensor::plain_tensor() { +} + +plain_tensor::~plain_tensor() { +} + +plain_tensor plain_tensor::index(const std::initializer_list& indices) const { + plain_tensor sub_tensor; + assert(indices.size() <= m_rank); + int i_src = 0; + int i_dst = 0; + sub_tensor.m_capacity = 0; + size_t off = 0; + for (auto idx : indices) { + auto src_dim = m_dims[i_src]; + auto src_stride = m_strides[i_src]; + idx.regularize(src_dim); + off += idx.start * src_stride; + if (idx.slice_with_squeeze()) { + // no output dimension + i_src++; + continue; + } + sub_tensor.m_dims[i_dst] = idx.count; + sub_tensor.m_strides[i_dst] = src_stride; + i_dst++; + i_src++; + } + sub_tensor.m_rank = i_dst; // index may imply squeeze + sub_tensor.m_ptr = std::shared_ptr(reinterpret_cast(m_ptr.get()) + off, [](void*) {}); + return sub_tensor; +} + +// slice: return a sub-view (w/o ownership/refcount to original data) +plain_tensor plain_tensor::slice(int axis, int start, int end) const { + plain_tensor sub_tensor; + assert(axis < m_rank); + + sub_tensor.m_capacity = 0; + sub_tensor.m_rank = m_rank; // slice dosen't change rank & strides + for (size_t i = 0; i < m_rank; i++) { + sub_tensor.m_strides[i] = m_strides[i]; + sub_tensor.m_dims[i] = m_dims[i]; + } + sub_tensor.m_dims[axis] = end - start; + + auto off = start * m_strides[axis]; + auto* data = reinterpret_cast(m_ptr.get()) + off; + sub_tensor.m_ptr = std::shared_ptr(reinterpret_cast(data), [](void*) {}); + + return sub_tensor; +} + +bool plain_tensor::is_dense() const { + // check if it's dense tensor + size_t stride = m_element_size; + for (int i = m_rank - 1; i >= 0; i--) { + if (m_strides[i] != stride) + return false; + stride *= m_dims[i]; + } + return true; +} + +/* + suppose current shape is [a0,a1,...,am] + and target shape is [b0,b1,...,bn] + reshape is only valid when (a0*a1*...*am) == (b0*b1*...*bn) <======= (A) + + uniform a tensor's shape into groups from last to first, the dimension is merged + into current group if the subtensor in the group is still dense after merge. + otherwise a new group is formed. + + then reshape is performed on group basis, the check (A) is performed on group bases. + which means any reshape inside the group is OK, but not across the group boundary. + + this can be done in one-loop, while group is forming, and checks are performed. + + simplified form is when whole tensor is dense +*/ +plain_tensor plain_tensor::reshape(const std::initializer_list& target_shape) const { + // only valid for dense memory + plain_tensor new_tensor_view; + assert(is_dense()); + //assert(shape_size(target_shape) == shape_size(m_dims)); + new_tensor_view.resize(std::vector(target_shape), m_ptr.get(), m_element_size); + return new_tensor_view; +} + +plain_tensor plain_tensor::permute(const std::initializer_list& order) const { + plain_tensor new_tensor_view; + assert(order.size() == m_rank); + new_tensor_view.m_capacity = 0; + new_tensor_view.m_ptr = m_ptr; + new_tensor_view.m_rank = m_rank; + auto it_order = order.begin(); + // also should check order has no repeat element + for (size_t i = 0; i < m_rank; i++) { + auto j = *it_order++; + assert(j >= 0 && j < m_rank); + new_tensor_view.m_dims[i] = m_dims[j]; + new_tensor_view.m_strides[i] = m_strides[j]; + } + return new_tensor_view; +} + +void plain_tensor::resize(const std::vector& new_dims, void* data, size_t element_size) { + // initialize strides for compact/dense tensor + m_element_size = element_size; + m_rank = new_dims.size(); + assert(m_rank <= PLAINTENSOR_RANK_MAX); + size_t stride = element_size; + for (int i = m_rank - 1; i >= 0; i--) { + m_dims[i] = new_dims[i]; + m_strides[i] = stride; + stride *= new_dims[i]; + } + + if (!data) { + auto capacity_new = m_strides[0] * m_dims[0]; + if (capacity_new > m_capacity) { + m_ptr = std::shared_ptr(aligned_alloc(64, capacity_new), [](void* p) { + ::free(p); + }); + m_capacity = capacity_new; + } + } else { + // m_capacity is zero to indicate that we don't own the memory + m_capacity = 0; + m_ptr = std::shared_ptr(reinterpret_cast(data), [](void*) {}); + } +} + +void plain_tensor::assert_dims(const std::initializer_list& expect_dims) const { + if (m_rank != expect_dims.size()) { + asm("int3"); + std::cout << "dims not same\n"; + } + if (!std::equal(expect_dims.begin(), expect_dims.end(), m_dims)) { + std::stringstream ss; + ss << " m_dims=["; + for (size_t i = 0; i < m_rank; i++) + ss << m_dims[i] << ","; + ss << "] expect_dims=["; + for (auto& i : expect_dims) + ss << i << ","; + ss << "]"; + asm("int3"); + std::cout << ss.str(); + } +} + +} // namespace llmdnn diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 5667216..4b052c2 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -2,10 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include #include #include "common/simple_parallel.hpp" +#include "common/tensor2d.hpp" #include "common/utility.hpp" #include "utility_kernel_avx512.hpp" #include "mm_kernel_common_amx.hpp" @@ -27,6 +29,7 @@ struct mha_gpt_impl_amx : public mha_gpt::impl { void mha_bf16(const mha_gpt::exec_param ¶m); void mha_i8(const mha_gpt::exec_param ¶m); + size_t head_size_aligned; size_t bufferMatMul0OutSize; size_t bufferMatMul1OutSize; @@ -62,6 +65,7 @@ bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { // attn_output: [batch, query_seq_len, num_heads * head_size] size_t numThreads = getTotalThreads(); if (_create_param.qkv_precision == dnnl_s8) { + head_size_aligned = rndup(_create_param.head_size, 64); qKtrGemm_i8xi8.resize(numThreads); for (size_t i = 0; i < numThreads; i++) { qKtrGemm_i8xi8[i] = std::make_shared>(false, !param.is_bloom); @@ -79,6 +83,7 @@ bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { [](void * p) { ::free(p); }); memset(qkvQuantBuf.get(), 0, sizeof(param.head_size * sizeof(float))); } else { + head_size_aligned = rndup(_create_param.head_size, 32); gemAvB_BF16xBF16.resize(numThreads); for (size_t i = 0; i < numThreads; i++) { gemAvB_BF16xBF16[i] = std::make_shared>(); @@ -94,7 +99,7 @@ bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { } bufferMatMul0OutSize = _create_param.max_seq_len * rndup(_create_param.max_seq_len * sizeof(float), 64); - bufferMatMul1OutSize = _create_param.max_seq_len * _create_param.head_size_aligned * sizeof(float); + bufferMatMul1OutSize = _create_param.max_seq_len * head_size_aligned * sizeof(float); bufferMatMul0Out = std::shared_ptr( reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul0OutSize)), @@ -108,36 +113,34 @@ bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { } void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { - uint8_t* pQIn0 = param.q; - auto pKIn0 = param.k; - auto attn_masks = param.attention_mask; - auto pVIn0 = param.v; - uint8_t* pout = param.attn_output; - auto alibi = param.alibi; + auto& q = param.q; + auto& k = param.k; + auto& v = param.v; + auto* attn_masks = param.attention_mask.data(); + uint8_t* pout = param.attn_output.data(); + auto alibi = param.alibi.data(); auto outPrcSize = get_precision_size(_create_param.qkv_precision); auto& gemAvB_ops = gemAvB_BF16xBF16; auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; auto& qKVGemm_ops = qKVGemm_BF16xBF16; bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 32 && _create_param.head_size <= 32 * 6 && !_create_param.is_bloom; - size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; - size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; size_t head_stride_in_attn = _create_param.head_size; size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; size_t causal_mask_offset_start = param.key_seq_len - param.query_seq_len; if (is_vector) { parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { - auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q) * get_precision_size(_create_param.qkv_precision); - auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); - auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pQIn0_aux = &q.at({i0, i1}); + auto pKIn0_aux = &k.at({i0, i1}); + auto pVIn0_aux = &v.at({i0, i1}); auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); - - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), k.m_strides[2]); // N: key_seq_len, K: head_size // q[1, K] * transpose(k[N, K]) ==> // k[N, K] * transpose(q[1, K]) ==> @@ -149,12 +152,12 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr); auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); - tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), v.m_strides[2]); + tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp(matQKV); (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, - _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); + _create_param.head_size, head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); }); } else { auto numThreads = getTotalThreads(); @@ -178,21 +181,21 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { // q: [batch, num_heads, query_seq_len, head_size] // k: [batch, num_heads, key_seq_len, head_size] // v: [batch, num_heads, value_seq_len, head_size] - auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q + seq_start * _create_param.head_size_aligned) * get_precision_size(_create_param.qkv_precision); - auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); - auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pQIn0_aux = &q.at({static_cast(i0), static_cast(i1), static_cast(seq_start)}); + auto pKIn0_aux = &k.at({static_cast(i0), static_cast(i1)}); + auto pVIn0_aux = &v.at({static_cast(i0), static_cast(i1)}); auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); - tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), q.m_strides[2]); tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); amx_kernel::PP::BiasGeluStore pp(matQK); if (!_create_param.is_bloom) { - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), k.m_strides[2]); (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); } else { - tensor2D matK(_create_param.head_size, param.key_seq_len, reinterpret_cast(pKIn0_aux), param.key_seq_len * sizeof(ov::bfloat16)); + tensor2D matK(_create_param.head_size, param.key_seq_len, reinterpret_cast(pKIn0_aux), k.m_strides[3]); (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); } prev_k = pKIn0_aux; @@ -241,13 +244,13 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; tensor2D matQKBF16(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); - tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), v.m_strides[2]); + tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp2(matQKV); (*qKVGemm_ops[threadNum])(matQKBF16, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); prev_v = pVIn0_aux; memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, - _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); + _create_param.head_size, head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); } }); @@ -255,12 +258,12 @@ void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { } void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { - uint8_t* pQIn0 = param.q; - auto pKIn0 = param.k; - auto attn_masks = param.attention_mask; - auto pVIn0 = param.v; - uint8_t* pout = param.attn_output; - auto alibi = param.alibi; + auto& q = param.q; + auto& k = param.k; + auto& v = param.v; + auto attn_masks = param.attention_mask.data(); + uint8_t* pout = param.attn_output.data(); + auto alibi = param.alibi.data(); auto outPrcSize = get_precision_size(_create_param.dst_precision); auto& gemAvB_ops = gemAvB_i8xi8; @@ -277,24 +280,22 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { if (param.qkv_quant.size() == 1) { std::fill(qkvQuantBuf.get() + 1, qkvQuantBuf.get() + _create_param.head_size, *qkvQuantBuf.get()); } - size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; - size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; size_t head_stride_in_attn = _create_param.head_size; size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; size_t causal_mask_offset_start = param.key_seq_len - param.query_seq_len; if (is_vector) { parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { - auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q) * get_precision_size(_create_param.qkv_precision); - auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); - auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pQIn0_aux = &q.at({i0, i1}); + auto pKIn0_aux = &k.at({i0, i1}); + auto pVIn0_aux = &v.at({i0, i1}); auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), k.m_strides[2]); // N: key_seq_len, K: head_size // q[1, K] * transpose(k[N, K]) ==> // k[N, K] * transpose(q[1, K]) ==> @@ -307,12 +308,12 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, param.qk_quant); auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); - tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); - tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), v.m_strides[2]); + tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp(matQKV); (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, - _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); + _create_param.head_size, head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); }); } else { auto numThreads = getTotalThreads(); @@ -336,21 +337,21 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { // q: [batch, num_heads, query_seq_len, head_size] // k: [batch, num_heads, key_seq_len, head_size] // v: [batch, num_heads, value_seq_len, head_size] - auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q + seq_start * _create_param.head_size_aligned); - auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv; - auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv; + auto pQIn0_aux = &q.at({static_cast(i0), static_cast(i1), static_cast(seq_start)}); + auto pKIn0_aux = &k.at({static_cast(i0), static_cast(i1)}); + auto pVIn0_aux = &v.at({static_cast(i0), static_cast(i1)}); auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); - tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), q.m_strides[2]); tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); amx_kernel::PP::BiasGeluStore pp(matQK); if (!_create_param.is_bloom) { - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), k.m_strides[2]); (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); } else { - tensor2D matK(_create_param.head_size, param.key_seq_len, reinterpret_cast(pKIn0_aux), param.key_seq_len * sizeof(int8_t)); + tensor2D matK(_create_param.head_size, param.key_seq_len, reinterpret_cast(pKIn0_aux), k.m_strides[3]); (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); } prev_k = pKIn0_aux; @@ -399,15 +400,15 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; tensor2D matQKI8(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); - tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); - tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), v.m_strides[2]); + tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp2(matQKV); (*qKVGemm_ops[threadNum])(matQKI8, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); prev_v = pVIn0_aux; // matmul1: [batch, num_heads, query_seq_len, head_size] // attn_output: [batch, query_seq_len, num_heads * head_size] memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, - _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); + _create_param.head_size, head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); } }); @@ -415,6 +416,24 @@ void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { } void mha_gpt_impl_amx::exec(const mha_gpt::exec_param& param) { + if (param.q.m_rank != 4 || param.k.m_rank != 4 || param.v.m_rank != 4) { + std::cout << "q,k,v rank does not equal 4.\n"; + return; + } + auto b = param.q.m_dims[0]; + auto hn = param.q.m_dims[1]; + auto qs = param.q.m_dims[2]; + auto hs = param.q.m_dims[3]; + auto ks = param.k.m_dims[2]; + + if (!(b == param.k.m_dims[0] && b == param.v.m_dims[0] && + hn == param.k.m_dims[1] && hn == param.v.m_dims[1] && + ks == param.v.m_dims[2] && + hs == param.k.m_dims[3] && hs == param.v.m_dims[3])) { + std::cout << "dim of q,k,v is error.\n"; + return; + } + if (_create_param.qkv_precision == dnnl_f32) { assert(false); } else if (_create_param.qkv_precision == dnnl_bf16) { diff --git a/tests/script/ext/attn_gpt.cpp b/tests/script/ext/attn_gpt.cpp index 8b7511e..10aa193 100644 --- a/tests/script/ext/attn_gpt.cpp +++ b/tests/script/ext/attn_gpt.cpp @@ -85,7 +85,6 @@ bool attn_gpt::create(const attn_gpt::create_param& param) { llmdnn::mha_gpt::create_param mha_param = {0}; mha_param.num_heads = param.num_heads; mha_param.head_size = param.head_size; - mha_param.head_size_aligned = param.head_size_aligned; mha_param.normal_factor = param.normal_factor; mha_param.qkv_precision = param.qkv_precision; mha_param.dst_precision = param.dst_precision; @@ -129,13 +128,12 @@ void attn_gpt::exec(const attn_gpt::exec_param& param) { mha_param.batch = param.batch; mha_param.query_seq_len = param.query_seq_len; mha_param.key_seq_len = param.query_seq_len + param.past_seq_len; - mha_param.q = emb_param.query_dst; - mha_param.attn_output = param.attn_output; - mha_param.head_stride_in_kv = param.head_stride_in_kv; + mha_param.q.resize({param.batch, _create_param.num_heads, param.query_seq_len, _create_param.head_size_aligned}, reinterpret_cast(emb_param.query_dst)); + mha_param.attn_output.resize({param.batch, param.query_seq_len, _create_param.num_heads * _create_param.head_size}, reinterpret_cast(param.attn_output)); mha_param.is_causal_in_attention = param.is_causal_in_attention; - mha_param.attention_mask = param.attention_mask; - mha_param.k = emb_param.layer_past_key_dst; - mha_param.v = emb_param.layer_past_value_dst; + mha_param.attention_mask.resize({param.batch, 1, param.query_seq_len, mha_param.key_seq_len}, static_cast(param.attention_mask)); + mha_param.k.resize({param.batch, _create_param.num_heads, _create_param.max_seq_len, _create_param.head_size_aligned}, reinterpret_cast(param.layer_past_key_dst[0])); + mha_param.v.resize({param.batch, _create_param.num_heads, _create_param.max_seq_len, _create_param.head_size_aligned}, reinterpret_cast(param.layer_past_value_dst[0])); _mha_gpt->exec(mha_param); } diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index c3ee7ff..eb734fc 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -2,9 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include #include #include "alloca.h" +#include "common/bf16.hpp" #include "module.hpp" #include "common/utility.hpp" #include "utility_kernel_amx.hpp" @@ -17,7 +19,6 @@ void regclass_mha_gpt(pybind11::module m) { cls.def("create", [] (llmdnn::mha_gpt& self, const size_t num_heads, const size_t head_size, - const size_t head_size_aligned, const float normal_factor, const std::string qkv_precision_name, const std::string dst_precision_name, @@ -26,7 +27,6 @@ void regclass_mha_gpt(pybind11::module m) { llmdnn::mha_gpt::create_param param = {0}; param.num_heads = num_heads; param.head_size = head_size; - param.head_size_aligned = head_size_aligned; param.normal_factor = normal_factor; param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); @@ -41,7 +41,6 @@ void regclass_mha_gpt(pybind11::module m) { }, py::arg("num_heads"), py::arg("head_size"), - py::arg("head_size_aligned"), py::arg("normal_factor"), py::arg("qkv_precision_name"), py::arg("dst_precision_name"), @@ -78,18 +77,17 @@ void regclass_mha_gpt(pybind11::module m) { head_size = head_size == 0 ? head_size_aligned : head_size; auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}); AT_ASSERT((int64_t)param.key_seq_len == attn_len); - param.q = reinterpret_cast(q.data_ptr()); - param.attn_output = reinterpret_cast(out.data_ptr()); - param.head_stride_in_kv = max_seq_len * head_size_aligned; + param.q.resize({q.size(0), q.size(1), q.size(2), q.size(3)}, reinterpret_cast(q.data_ptr())); + param.attn_output.resize({batch, query_seq_len, num_heads * head_size}, reinterpret_cast(out.data_ptr())); param.is_causal_in_attention = attn_mask.size(2) != 1; - param.attention_mask = attn_mask.data_ptr(); - param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - for (int i = 0; i < batch; i++) { - param.k[i] = reinterpret_cast(k[i].data_ptr()); - param.v[i] = reinterpret_cast(v[i].data_ptr()); + param.attention_mask.resize({attn_mask.size(0), attn_mask.size(1), attn_mask.size(2), attn_mask.size(3)}, attn_mask.data_ptr()); + param.k.resize({k.size(0), k.size(1), k.size(2), k.size(3)}, reinterpret_cast(k.data_ptr())); + if (alibi.dim() == 3) { + std::swap(param.k.m_dims[2], param.k.m_dims[3]); + std::swap(param.k.m_strides[2], param.k.m_strides[3]); + param.alibi.resize({alibi.size(0), alibi.size(1), alibi.size(2)}, alibi.data_ptr()); } - param.alibi = alibi.data_ptr(); + param.v.resize({v.size(0), v.size(1), v.size(2), v.size(3)}, reinterpret_cast(v.data_ptr())); self.exec(param); return out; @@ -133,22 +131,19 @@ void regclass_mha_gpt(pybind11::module m) { head_size = head_size == 0 ? head_size_aligned : head_size; auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}, torch::TensorOptions(torch::kInt8)); AT_ASSERT((int64_t)param.key_seq_len == attn_len); - param.q = reinterpret_cast(q.data_ptr()); - param.attn_output = reinterpret_cast(out.data_ptr()); - param.head_stride_in_kv = max_seq_len * head_size_aligned; - param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + + param.q.resize({q.size(0), q.size(1), q.size(2), q.size(3)}, reinterpret_cast(q.data_ptr())); + param.attn_output.resize({batch, query_seq_len, num_heads * head_size}, reinterpret_cast(out.data_ptr())); param.is_causal_in_attention = attn_mask.size(2) != 1; - param.attention_mask = attn_mask.data_ptr(); + param.attention_mask.resize({attn_mask.size(0), attn_mask.size(1), attn_mask.size(2), attn_mask.size(3)}, attn_mask.data_ptr()); + param.k.resize({k.size(0), k.size(1), k.size(2), k.size(3)}, reinterpret_cast(k.data_ptr())); + param.v.resize({v.size(0), v.size(1), v.size(2), v.size(3)}, reinterpret_cast(v.data_ptr())); + param.q_dequant = q_dequant; param.k_dequant = k_dequant; param.v_dequant = v_dequant; param.qk_quant = qk_quant; param.qkv_quant = qkv_quant; - for (int i = 0; i < batch; i++) { - param.k[i] = reinterpret_cast(k[i].data_ptr()); - param.v[i] = reinterpret_cast(v[i].data_ptr()); - } self.exec(param); return out; diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py index 2780c94..3ea5aa9 100644 --- a/tests/script/ext/setup.py +++ b/tests/script/ext/setup.py @@ -16,7 +16,7 @@ debug = False if 'DEBUG_EXT' in os.environ: debug = True if os.environ['DEBUG_EXT'] == '1' else False -extra_args = ['-fopenmp', +extra_args = ['-fopenmp', '-Wno-narrowing', '-Wno-attributes', '-march=native'] cpu_extensions_lib_dir = f'{os.getcwd()}/../../../build/lib' if debug: diff --git a/tests/script/test_mha_bloom.py b/tests/script/test_mha_bloom.py index 0e98148..fee6f39 100644 --- a/tests/script/test_mha_bloom.py +++ b/tests/script/test_mha_bloom.py @@ -173,7 +173,7 @@ def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, is normal_factor = 1.0 / math.sqrt(head_size) qkv_precision_name = 's8' if is_int8 else 'bf16' dst_precision_name = 's8' if is_int8 else 'bf16' - self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + self.mha.create(num_heads, head_size, normal_factor, qkv_precision_name, dst_precision_name, max_seq_len, True) def forward(self, query, key, value, alibi, attention_mask): diff --git a/tests/script/test_mha_chatglm.py b/tests/script/test_mha_chatglm.py index dd33799..7aa2d3b 100644 --- a/tests/script/test_mha_chatglm.py +++ b/tests/script/test_mha_chatglm.py @@ -118,11 +118,11 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi normal_factor = 1.0 / math.sqrt(head_size) qkv_precision_name = 's8' if is_int8 else 'bf16' dst_precision_name = 's8' if is_int8 else 'bf16' - self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + self.mha.create(num_heads, head_size, normal_factor, qkv_precision_name, dst_precision_name, max_seq_len) def forward(self, query, key, value, attention_mask, head_size, key_seq_len): - return self.mha.exec(query, key, value, torch.tensor(1.0), attention_mask, head_size, key_seq_len) + return self.mha.exec(query, key, value, torch.tensor([]), attention_mask, head_size, key_seq_len) def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py index 8419c8b..b3e7c8c 100644 --- a/tests/script/test_mha_gpt.py +++ b/tests/script/test_mha_gpt.py @@ -115,11 +115,11 @@ def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, is normal_factor = 1.0 / math.sqrt(head_size) qkv_precision_name = 's8' if is_int8 else 'bf16' dst_precision_name = 's8' if is_int8 else 'bf16' - self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + self.mha.create(num_heads, head_size, normal_factor, qkv_precision_name, dst_precision_name, max_seq_len) def forward(self, query, key, value, attention_mask): - return self.mha.exec(query, key, value, torch.tensor(1.0), attention_mask) + return self.mha.exec(query, key, value, torch.tensor([]), attention_mask) def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant From 6205ab796e49493a99df315e889a9eae5faa0897 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 13 Jul 2023 00:05:32 +0800 Subject: [PATCH 26/54] remove chatglm dependency --- tests/script/models/chatglm-6b/LICENSE | 201 --- tests/script/models/chatglm-6b/MODEL_LICENSE | 65 - tests/script/models/chatglm-6b/README.md | 89 - tests/script/models/chatglm-6b/config.json | 28 - .../chatglm-6b/configuration_chatglm.py | 103 -- .../models/chatglm-6b/modeling_chatglm.org.py | 1450 ---------------- .../models/chatglm-6b/modeling_chatglm.py | 1490 ----------------- .../chatglm-6b/pytorch_model.bin.index.json | 375 ----- .../script/models/chatglm-6b/quantization.py | 201 --- .../chatglm-6b/test_modeling_chatglm.py | 245 --- .../models/chatglm-6b/tokenization_chatglm.py | 443 ----- .../models/chatglm-6b/tokenizer_config.json | 20 - tests/script/requirements.txt | 1 - 13 files changed, 4711 deletions(-) delete mode 100644 tests/script/models/chatglm-6b/LICENSE delete mode 100644 tests/script/models/chatglm-6b/MODEL_LICENSE delete mode 100644 tests/script/models/chatglm-6b/README.md delete mode 100644 tests/script/models/chatglm-6b/config.json delete mode 100644 tests/script/models/chatglm-6b/configuration_chatglm.py delete mode 100644 tests/script/models/chatglm-6b/modeling_chatglm.org.py delete mode 100644 tests/script/models/chatglm-6b/modeling_chatglm.py delete mode 100644 tests/script/models/chatglm-6b/pytorch_model.bin.index.json delete mode 100644 tests/script/models/chatglm-6b/quantization.py delete mode 100644 tests/script/models/chatglm-6b/test_modeling_chatglm.py delete mode 100644 tests/script/models/chatglm-6b/tokenization_chatglm.py delete mode 100644 tests/script/models/chatglm-6b/tokenizer_config.json diff --git a/tests/script/models/chatglm-6b/LICENSE b/tests/script/models/chatglm-6b/LICENSE deleted file mode 100644 index ac4aee5..0000000 --- a/tests/script/models/chatglm-6b/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright Zhengxiao Du - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/MODEL_LICENSE b/tests/script/models/chatglm-6b/MODEL_LICENSE deleted file mode 100644 index f8e2731..0000000 --- a/tests/script/models/chatglm-6b/MODEL_LICENSE +++ /dev/null @@ -1,65 +0,0 @@ -The ChatGLM-6B License - -一、定义 - -“许可方”是指分发其软件的 ChatGLM-6B 模型团队。 - -“软件”是指根据本许可提供的 ChatGLM-6B 模型参数。 - -2. 许可授予 - -根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可,仅用于您的非商业研究目的。 - -上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 - -3.限制 - -您不得出于任何商业、军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 - -您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 - -4.免责声明 - -本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 - -5. 责任限制 - -除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 - -6.争议解决 - -本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 - -请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 glm-130b@googlegroups.com 与我们联系。 - -1. Definitions - -“Licensor” means the ChatGLM-6B Model Team that distributes its Software. - -“Software” means the ChatGLM-6B model parameters made available under this license. - -2. License Grant - -Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -3. Restriction - -You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. - -You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. - -4. Disclaimer - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -5. Limitation of Liability - -EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. - -6. Dispute Resolution - -This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. - -Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. diff --git a/tests/script/models/chatglm-6b/README.md b/tests/script/models/chatglm-6b/README.md deleted file mode 100644 index 857edd7..0000000 --- a/tests/script/models/chatglm-6b/README.md +++ /dev/null @@ -1,89 +0,0 @@ ---- -language: -- zh -- en -tags: -- glm -- chatglm -- thudm ---- -# ChatGLM-6B -

- 🌐 Blog • 💻 Github Repo • 🐦 Twitter • 📃 [GLM@ACL 22] [GitHub] • 📃 [GLM-130B@ICLR 23] [GitHub]
-

- -

- 👋 Join our Slack and WeChat -

- -## 介绍 -ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。 - -ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference. - -## 软件依赖 - -```shell -pip install protobuf==3.20.0 transformers==4.27.1 icetk cpm_kernels -``` - -## 代码调用 - -可以通过如下代码调用 ChatGLM-6B 模型来生成对话: - -```ipython ->>> from transformers import AutoTokenizer, AutoModel ->>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) ->>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() ->>> response, history = model.chat(tokenizer, "你好", history=[]) ->>> print(response) -你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。 ->>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history) ->>> print(response) -晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法: - -1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。 -2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。 -3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。 -4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。 -5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。 -6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。 - -如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。 -``` - -关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。 - -For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B). - -## Change Log -* v1.1.0 ([942945d](https://huggingface.co/THUDM/chatglm-6b/commit/942945df047dee66f653c68ae0e56655045f1741)): 更新 v1.1 版本 checkpoint -* v0.1.0 ([f831824](https://huggingface.co/THUDM/chatglm-6b/commit/f83182484538e663a03d3f73647f10f89878f438)) - -## 协议 - -本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。 - -## 引用 - -如果你觉得我们的工作有帮助的话,请考虑引用下列论文: - -``` -@inproceedings{ - zeng2023glm-130b, - title={{GLM}-130B: An Open Bilingual Pre-trained Model}, - author={Aohan Zeng and Xiao Liu and Zhengxiao Du and Zihan Wang and Hanyu Lai and Ming Ding and Zhuoyi Yang and Yifan Xu and Wendi Zheng and Xiao Xia and Weng Lam Tam and Zixuan Ma and Yufei Xue and Jidong Zhai and Wenguang Chen and Zhiyuan Liu and Peng Zhang and Yuxiao Dong and Jie Tang}, - booktitle={The Eleventh International Conference on Learning Representations (ICLR)}, - year={2023}, - url={https://openreview.net/forum?id=-Aw0rrrPUF} -} -``` -``` -@inproceedings{du2022glm, - title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling}, - author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie}, - booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, - pages={320--335}, - year={2022} -} -``` \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/config.json b/tests/script/models/chatglm-6b/config.json deleted file mode 100644 index 7cc6e70..0000000 --- a/tests/script/models/chatglm-6b/config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "_name_or_path": "THUDM/chatglm-6b", - "architectures": [ - "ChatGLMModel" - ], - "auto_map": { - "AutoConfig": "configuration_chatglm.ChatGLMConfig", - "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", - "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration" - }, - "bos_token_id": 130004, - "eos_token_id": 130005, - "mask_token_id": 130000, - "gmask_token_id": 130001, - "pad_token_id": 3, - "hidden_size": 4096, - "inner_hidden_size": 16384, - "layernorm_epsilon": 1e-05, - "max_sequence_length": 2048, - "model_type": "chatglm", - "num_attention_heads": 32, - "num_layers": 28, - "position_encoding_2d": true, - "torch_dtype": "float32", - "transformers_version": "4.23.1", - "use_cache": true, - "vocab_size": 130528 -} diff --git a/tests/script/models/chatglm-6b/configuration_chatglm.py b/tests/script/models/chatglm-6b/configuration_chatglm.py deleted file mode 100644 index 78f3425..0000000 --- a/tests/script/models/chatglm-6b/configuration_chatglm.py +++ /dev/null @@ -1,103 +0,0 @@ -""" ChatGLM model configuration """ - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class ChatGLMConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`~ChatGLMModel`]. - It is used to instantiate an ChatGLM model according to the specified arguments, defining the model - architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used - to control the model outputs. Read the documentation from [`PretrainedConfig`] - for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 150528): - Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`~ChatGLMModel`] or - [`~TFChatGLMModel`]. - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the encoder layers and the pooler layer. - num_hidden_layers (`int`, *optional*, defaults to 28): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - inner_hidden_size (`int`, *optional*, defaults to 16384): - Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - max_sequence_length (`int`, *optional*, defaults to 512): - The maximum sequence length that this model might ever be used with. - Typically set this to something large just in case (e.g., 512 or 1024 or 2048). - layernorm_epsilon (`float`, *optional*, defaults to 1e-5): - The epsilon used by the layer normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether the model should return the last key/values attentions (not used by all models). - Example: - - ```python - >>> from configuration_chatglm import ChatGLMConfig - >>> from modeling_chatglm import ChatGLMModel - - >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration - >>> configuration = ChatGLMConfig() - - >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration - >>> model = ChatGLMModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ``` -""" - model_type = "chatglm" - - def __init__( - self, - vocab_size=150528, - hidden_size=4096, - num_layers=28, - num_attention_heads=32, - layernorm_epsilon=1e-5, - use_cache=False, - bos_token_id=150004, - eos_token_id=150005, - mask_token_id=150000, - gmask_token_id=150001, - pad_token_id=0, - max_sequence_length=2048, - inner_hidden_size=16384, - position_encoding_2d=True, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs - ): - self.num_layers = num_layers - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.max_sequence_length = max_sequence_length - self.layernorm_epsilon = layernorm_epsilon - self.inner_hidden_size = inner_hidden_size - self.use_cache = use_cache - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.mask_token_id = mask_token_id - self.gmask_token_id = gmask_token_id - self.position_encoding_2d = position_encoding_2d - self.quantization_bit = quantization_bit - self.pre_seq_len = pre_seq_len - self.prefix_projection = prefix_projection - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs - ) diff --git a/tests/script/models/chatglm-6b/modeling_chatglm.org.py b/tests/script/models/chatglm-6b/modeling_chatglm.org.py deleted file mode 100644 index d24bf52..0000000 --- a/tests/script/models/chatglm-6b/modeling_chatglm.org.py +++ /dev/null @@ -1,1450 +0,0 @@ -""" PyTorch ChatGLM model. """ - -import math -import copy -import os -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" -_CONFIG_FOR_DOC = "ChatGLM6BConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm-6b", - # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm -] - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert ( - pointer.shape == array.shape - ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(config.hidden_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -@torch.jit.script -def gelu_impl(x): - """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * - (1.0 + 0.044715 * x * x))) - - -def gelu(x): - return gelu_impl(x) - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000, precision=torch.half, learnable=False): - super().__init__() - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - inv_freq = inv_freq.half() - self.learnable = learnable - if learnable: - self.inv_freq = torch.nn.Parameter(inv_freq) - self.max_seq_len_cached = None - else: - self.register_buffer('inv_freq', inv_freq) - self.max_seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - self.precision = precision - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - pass - - def forward(self, x, seq_dim=1, seq_len=None): - if seq_len is None: - seq_len = x.shape[seq_dim] - if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): - self.max_seq_len_cached = None if self.learnable else seq_len - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - if self.precision == torch.bfloat16: - emb = emb.float() - - # [sx, 1 (b * np), hn] - cos_cached = emb.cos()[:, None, :] - sin_cached = emb.sin()[:, None, :] - if self.precision == torch.bfloat16: - cos_cached = cos_cached.bfloat16() - sin_cached = sin_cached.bfloat16() - if self.learnable: - return cos_cached, sin_cached - self.cos_cached, self.sin_cached = cos_cached, sin_cached - return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] - - def _apply(self, fn): - if self.cos_cached is not None: - self.cos_cached = fn(self.cos_cached) - if self.sin_cached is not None: - self.sin_cached = fn(self.sin_cached) - return super()._apply(fn) - - -def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions - - -@torch.jit.script -def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): - # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] - cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ - F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) - q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - return q, k - - -def attention_fn( - self, - query_layer, - key_layer, - value_layer, - attention_mask, - hidden_size_per_partition, - layer_id, - layer_past=None, - scaling_attention_score=True, - use_cache=False, -): - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - key_layer = torch.cat((past_key, key_layer), dim=1) - value_layer = torch.cat((past_value, value_layer), dim=1) - - # batch, seqlen, num_attention_heads, hidden_size_per_attention_head - b, seq_len, nh, hidden_size = key_layer.shape - - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - query_key_layer_scaling_coeff = float(layer_id + 1) - if scaling_attention_score: - query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) - - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) - - # # [sq, b, np, hn] -> [sq, b * np, hn] - # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # # [sk, b, np, hn] -> [sk, b * np, hn] - # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # [b, sq, np, hn] -> [b, np, sq, hn] - query_layer = query_layer.transpose(1, 2) - # [b, sk, np, hn] -> [b, np, sk, hn] - key_layer = key_layer.transpose(1, 2) - # [b, np, sq, hn] -> [b * np, sq, hn] - query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) - # [b, np, sk, hn] -> [b * np, sk, hn] - key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - - matmul_result = torch.zeros( - 1, 1, 1, - dtype=query_layer.dtype, - device=query_layer.device, - ) - - matmul_result = torch.baddbmm( - matmul_result, - query_layer, # [b * np, sq, hn] - key_layer.transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=1.0, - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - if self.scale_mask_softmax: - self.scale_mask_softmax.scale = query_key_layer_scaling_coeff - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) - else: - if not (attention_mask == 0).all(): - # if auto-regressive, skip - attention_scores.masked_fill_(attention_mask, -10000.0) - dtype = attention_scores.dtype - attention_scores = attention_scores.float() - attention_scores = attention_scores * query_key_layer_scaling_coeff - - attention_probs = F.softmax(attention_scores, dim=-1) - - attention_probs = attention_probs.type(dtype) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [b, sk, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) - - # # change view [sk, b * np, hn] - # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # [b, sk, np, hn] -> [b, np, sk, hn] - value_layer = value_layer.transpose(1, 2) - # [b, np, sk, hn] -> [b * np, sk, hn] - value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [b, sq, np, hn] - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - - # [b, sq, np, hn] --> [b, sq, hp] - new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, present, attention_probs) - - return outputs - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class SelfAttention(torch.nn.Module): - def __init__(self, hidden_size, num_attention_heads, - layer_id, hidden_size_per_attention_head=None, bias=True, - params_dtype=torch.float, position_encoding_2d=True, empty_init=True): - if empty_init: - init_method = skip_init - else: - init_method = default_init - super(SelfAttention, self).__init__() - - self.layer_id = layer_id - self.hidden_size = hidden_size - self.hidden_size_per_partition = hidden_size - self.num_attention_heads = num_attention_heads - self.num_attention_heads_per_partition = num_attention_heads - self.position_encoding_2d = position_encoding_2d - self.rotary_emb = RotaryEmbedding( - self.hidden_size // (self.num_attention_heads * 2) - if position_encoding_2d - else self.hidden_size // self.num_attention_heads, - base=10000, - precision=torch.half, - learnable=False, - ) - - self.scale_mask_softmax = None - - if hidden_size_per_attention_head is None: - self.hidden_size_per_attention_head = hidden_size // num_attention_heads - else: - self.hidden_size_per_attention_head = hidden_size_per_attention_head - - self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head - - # Strided linear layer. - self.query_key_value = init_method( - torch.nn.Linear, - hidden_size, - 3 * self.inner_hidden_size, - bias=bias, - dtype=params_dtype, - ) - - self.dense = init_method( - torch.nn.Linear, - self.inner_hidden_size, - hidden_size, - bias=bias, - dtype=params_dtype, - ) - - @staticmethod - def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - def split_tensor_along_last_dim(self, tensor, num_partitions, - contiguous_split_chunks=False): - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - """ - hidden_states: [batch, seq_len, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # [batch, seq_len, 3 * hidden_size] - mixed_raw_layer = self.query_key_value(hidden_states) - - # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] - new_tensor_shape = mixed_raw_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) - - # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] - (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) - - if self.position_encoding_2d: - q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) - k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) - cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) - # xxx - position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ - position_ids[:, 1, :].contiguous() - q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) - q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) - query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) - key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) - else: - position_ids = position_ids.transpose(0, 1) - cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) - # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] - query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) - - # [batch, seq_len, hidden_size] - context_layer, present, attention_probs = attention_fn( - self=self, - query_layer=query_layer, - key_layer=key_layer, - value_layer=value_layer, - attention_mask=attention_mask, - hidden_size_per_partition=self.hidden_size_per_partition, - layer_id=layer_id, - layer_past=layer_past, - use_cache=use_cache - ) - - output = self.dense(context_layer) - - outputs = (output, present) - - if output_attentions: - outputs += (attention_probs,) - - return outputs # output, present, attention_probs - - -class GEGLU(torch.nn.Module): - def __init__(self): - super().__init__() - self.activation_fn = F.gelu - - def forward(self, x): - # dim=-1 breaks in jit for pt<1.10 - x1, x2 = x.chunk(2, dim=(x.ndim - 1)) - return x1 * self.activation_fn(x2) - - -class GLU(torch.nn.Module): - def __init__(self, hidden_size, inner_hidden_size=None, - layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): - super(GLU, self).__init__() - if empty_init: - init_method = skip_init - else: - init_method = default_init - self.layer_id = layer_id - self.activation_func = activation_func - - # Project to 4h. - self.hidden_size = hidden_size - if inner_hidden_size is None: - inner_hidden_size = 4 * hidden_size - self.inner_hidden_size = inner_hidden_size - self.dense_h_to_4h = init_method( - torch.nn.Linear, - self.hidden_size, - self.inner_hidden_size, - bias=bias, - dtype=params_dtype, - ) - # Project back to h. - self.dense_4h_to_h = init_method( - torch.nn.Linear, - self.inner_hidden_size, - self.hidden_size, - bias=bias, - dtype=params_dtype, - ) - - def forward(self, hidden_states): - """ - hidden_states: [seq_len, batch, hidden_size] - """ - - # [seq_len, batch, inner_hidden_size] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - - intermediate_parallel = self.activation_func(intermediate_parallel) - - output = self.dense_4h_to_h(intermediate_parallel) - - return output - - -class GLMBlock(torch.nn.Module): - def __init__( - self, - hidden_size, - num_attention_heads, - layernorm_epsilon, - layer_id, - inner_hidden_size=None, - hidden_size_per_attention_head=None, - layernorm=LayerNorm, - use_bias=True, - params_dtype=torch.float, - num_layers=28, - position_encoding_2d=True, - empty_init=True - ): - super(GLMBlock, self).__init__() - # Set output layer initialization if not provided. - - self.layer_id = layer_id - - # Layernorm on the input data. - self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - - self.position_encoding_2d = position_encoding_2d - - # Self attention. - self.attention = SelfAttention( - hidden_size, - num_attention_heads, - layer_id, - hidden_size_per_attention_head=hidden_size_per_attention_head, - bias=use_bias, - params_dtype=params_dtype, - position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init - ) - - # Layernorm on the input data. - self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - - self.num_layers = num_layers - - # GLU - self.mlp = GLU( - hidden_size, - inner_hidden_size=inner_hidden_size, - bias=use_bias, - layer_id=layer_id, - params_dtype=params_dtype, - empty_init=empty_init - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - """ - hidden_states: [seq_len, batch, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # Layer norm at the begining of the transformer layer. - # [batch, seq_len, hidden_size] - attention_input = self.input_layernorm(hidden_states) - - # Self attention. - attention_outputs = self.attention( - attention_input, - position_ids, - attention_mask=attention_mask, - layer_id=layer_id, - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions - ) - - attention_output = attention_outputs[0] - - outputs = attention_outputs[1:] - - # Residual connection. - alpha = (2 * self.num_layers) ** 0.5 - hidden_states = attention_input * alpha + attention_output - - mlp_input = self.post_attention_layernorm(hidden_states) - - # MLP. - mlp_output = self.mlp(mlp_input) - - # Second residual connection. - output = mlp_input * alpha + mlp_output - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, device): - batch_size, seq_length = input_ids.shape - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) - attention_mask.tril_() - for i, context_length in enumerate(context_lengths): - attention_mask[i, :, :context_length] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() - - return attention_mask - - def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): - batch_size, seq_length = input_ids.shape - if use_gmasks is None: - use_gmasks = [False] * batch_size - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - if self.position_encoding_2d: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - for i, context_length in enumerate(context_lengths): - position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [torch.cat(( - torch.zeros(context_length, dtype=torch.long, device=device), - torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 - )) for context_length in context_lengths] - block_position_ids = torch.stack(block_position_ids, dim=0) - position_ids = torch.stack((position_ids, block_position_ids), dim=1) - else: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - for i, context_length in enumerate(context_lengths): - if not use_gmasks[i]: - position_ids[i, context_length:] = mask_positions[i] - - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, ChatGLMModel): - module.gradient_checkpointing = value - - -CHATGLM_6B_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general - usage and behavior. - - Parameters: - config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CHATGLM_6B_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`ChatGLM6BTokenizer`]. - See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range `[0, config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert *input_ids* indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", - CHATGLM_6B_START_DOCSTRING, -) -class ChatGLMModel(ChatGLMPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well - as a decoder, in which case a layer of cross-attention is added between - the self-attention layers, following the architecture described in [Attention is - all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, - Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the - `is_decoder` argument of the configuration set to `True`. - To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` - argument and `add_cross_attention` set to `True`; an - `encoder_hidden_states` is then expected as an input to the forward pass. - """ - - def __init__(self, config: ChatGLMConfig, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - # recording parameters - self.max_sequence_length = config.max_sequence_length - self.hidden_size = config.hidden_size - self.params_dtype = torch.half - self.num_attention_heads = config.num_attention_heads - self.vocab_size = config.vocab_size - self.num_layers = config.num_layers - self.layernorm_epsilon = config.layernorm_epsilon - self.inner_hidden_size = config.inner_hidden_size - self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads - self.position_encoding_2d = config.position_encoding_2d - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - - self.word_embeddings = init_method( - torch.nn.Embedding, - num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, - dtype=self.params_dtype - ) - self.gradient_checkpointing = False - - def get_layer(layer_id): - return GLMBlock( - self.hidden_size, - self.num_attention_heads, - self.layernorm_epsilon, - layer_id, - inner_hidden_size=self.inner_hidden_size, - hidden_size_per_attention_head=self.hidden_size_per_attention_head, - layernorm=LayerNorm, - use_bias=True, - params_dtype=self.params_dtype, - position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init - ) - - self.layers = torch.nn.ModuleList( - [get_layer(layer_id) for layer_id in range(self.num_layers)] - ) - - # Final layer norm before output. - self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) - - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - # total_params = sum(p.numel() for p in self.parameters()) - # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) - # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) - - def get_input_embeddings(self): - return self.word_embeddings - - def set_input_embeddings(self, new_embeddings: torch.Tensor): - self.word_embeddings = new_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.num_attention_heads, - self.hidden_size // self.num_attention_heads - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - # past_key_values = [(v[0], v[1]) for v in past_key_values] - return past_key_values - - @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - if past_key_values is None: - if self.pre_seq_len is not None: - past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, - dtype=inputs_embeds.dtype) - else: - past_key_values = tuple([None] * len(self.layers)) - - if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - device=input_ids.device - ) - - - if position_ids is None: - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - - position_ids = self.get_position_ids( - input_ids, - mask_positions=mask_positions, - device=input_ids.device, - use_gmasks=use_gmasks - ) - - if self.pre_seq_len is not None and attention_mask is not None: - prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( - attention_mask.device) - prefix_attention_mask = (prefix_attention_mask < 0.5).bool() - attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) - - # [seq_len, batch, hidden_size] - # hidden_states = inputs_embeds.transpose(0, 1) - # xxx - hidden_states = inputs_embeds - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if attention_mask is None: - attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() - else: - attention_mask = attention_mask.to(hidden_states.device) - - for i, layer in enumerate(self.layers): - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - layer_past = past_key_values[i] - - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - position_ids, - attention_mask, - torch.tensor(i), - layer_past, - use_cache, - output_attentions - ) - else: - layer_ret = layer( - hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - layer_id=torch.tensor(i), - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions - ) - - hidden_states = layer_ret[0] - - if use_cache: - presents = presents + (layer_ret[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) - - # Final layer norm. - hidden_states = self.final_layernorm(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - - # self.hidden_size = config.hidden_size - # self.params_dtype = torch.half - # self.vocab_size = config.vocab_size - self.max_sequence_length = config.max_sequence_length - - self.position_encoding_2d = config.position_encoding_2d - - self.transformer = ChatGLMModel(config, empty_init=empty_init) - - self.lm_head = init_method( - nn.Linear, - config.hidden_size, - config.vocab_size, - bias=False, - dtype=torch.half - ) - - self.config = config - - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - if attention_mask is not None and attention_mask.dtype == torch.bool: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) - new_attention_mask = attention_mask[:, :, -1:].clone() - new_attention_mask[..., -1] = False - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, new_attention_mask], dim=2 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id[:, 1, :] += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past: Optional[torch.Tensor] = None, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - **kwargs - ) -> dict: - batch_size, seq_length = input_ids.shape - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - - # only last token for input_ids if past is not None - if past is not None or past_key_values is not None: - last_token = input_ids[:, -1].unsqueeze(-1) - if attention_mask is not None and attention_mask.dtype == torch.bool: - attention_mask = attention_mask[:, :, -1:] - else: - attention_mask = None - if position_ids is not None: - position_ids = position_ids[..., -1:] - else: - context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] - if self.position_encoding_2d: - position_ids = torch.tensor( - [[mask_position, seq_length - context_length] for mask_position, context_length in - zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) - else: - position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, - device=input_ids.device).unsqueeze(-1) - - if past is None: - past = past_key_values - return { - "input_ids": last_token, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask - } - else: - if attention_mask is not None and attention_mask.dtype != torch.bool: - logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") - attention_mask = None - if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - device=input_ids.device - ) - if position_ids is None: - position_ids = self.get_position_ids( - input_ids, - device=input_ids.device, - mask_positions=mask_positions, - use_gmasks=use_gmasks - ) - - return { - "input_ids": input_ids, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - punkts = [ - [",", ","], - ["!", "!"], - [":", ":"], - [";", ";"], - ["\?", "?"], - ] - for item in punkts: - response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) - response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) - return response - - @torch.no_grad() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if not history: - prompt = query - else: - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if not history: - prompt = query - else: - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - for outputs in self.stream_generate(**inputs, **gen_kwargs): - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - new_history = history + [(query, response)] - yield response, new_history - - @torch.no_grad() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - yield input_ids - - def quantize(self, bits: int, empty_init=False, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) - return self diff --git a/tests/script/models/chatglm-6b/modeling_chatglm.py b/tests/script/models/chatglm-6b/modeling_chatglm.py deleted file mode 100644 index 5d851ff..0000000 --- a/tests/script/models/chatglm-6b/modeling_chatglm.py +++ /dev/null @@ -1,1490 +0,0 @@ -""" PyTorch ChatGLM model. """ - -import math -import copy -import os -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any -import llmdnn as ld - -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" -_CONFIG_FOR_DOC = "ChatGLM6BConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm-6b", - # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm -] - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert ( - pointer.shape == array.shape - ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(config.hidden_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -@torch.jit.script -def gelu_impl(x): - """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * - (1.0 + 0.044715 * x * x))) - - -def gelu(x): - return gelu_impl(x) - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000, precision=torch.half, learnable=False): - super().__init__() - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - inv_freq = inv_freq.half() - self.learnable = learnable - if learnable: - self.inv_freq = torch.nn.Parameter(inv_freq) - self.max_seq_len_cached = None - else: - self.register_buffer('inv_freq', inv_freq) - self.max_seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - self.precision = precision - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - pass - - def forward(self, x, seq_dim=1, seq_len=None): - if seq_len is None: - seq_len = x.shape[seq_dim] - if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): - self.max_seq_len_cached = None if self.learnable else seq_len - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - if self.precision == torch.bfloat16: - emb = emb.float() - - # [sx, 1 (b * np), hn] - cos_cached = emb.cos()[:, None, :] - sin_cached = emb.sin()[:, None, :] - if self.precision == torch.bfloat16: - cos_cached = cos_cached.bfloat16() - sin_cached = sin_cached.bfloat16() - if self.learnable: - return cos_cached, sin_cached - self.cos_cached, self.sin_cached = cos_cached, sin_cached - return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] - - def _apply(self, fn): - if self.cos_cached is not None: - self.cos_cached = fn(self.cos_cached) - if self.sin_cached is not None: - self.sin_cached = fn(self.sin_cached) - return super()._apply(fn) - - -def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions - - -@torch.jit.script -def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): - # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] - cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ - F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) - q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - return q, k - - -def attention_fn( - self, - query_layer, - key_layer, - value_layer, - attention_mask, - hidden_size_per_partition, - layer_id, - layer_past=None, - scaling_attention_score=True, - use_cache=False, -): - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - key_layer = torch.cat((past_key, key_layer), dim=1) - value_layer = torch.cat((past_value, value_layer), dim=1) - - # batch, seqlen, num_attention_heads, hidden_size_per_attention_head - b, seq_len, nh, hidden_size = key_layer.shape - - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - query_key_layer_scaling_coeff = float(layer_id + 1) - if scaling_attention_score: - query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) - - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) - - # # [sq, b, np, hn] -> [sq, b * np, hn] - # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # # [sk, b, np, hn] -> [sk, b * np, hn] - # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # [b, sq, np, hn] -> [b, np, sq, hn] - query_layer = query_layer.transpose(1, 2) - # [b, sk, np, hn] -> [b, np, sk, hn] - key_layer = key_layer.transpose(1, 2) - # [b, np, sq, hn] -> [b * np, sq, hn] - query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) - # [b, np, sk, hn] -> [b * np, sk, hn] - key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - - matmul_result = torch.zeros( - 1, 1, 1, - dtype=query_layer.dtype, - device=query_layer.device, - ) - - matmul_result = torch.baddbmm( - matmul_result, - query_layer, # [b * np, sq, hn] - key_layer.transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=1.0, - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - if self.scale_mask_softmax: - self.scale_mask_softmax.scale = query_key_layer_scaling_coeff - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) - else: - if not (attention_mask == 0).all(): - # if auto-regressive, skip - attention_scores.masked_fill_(attention_mask, -10000.0) - dtype = attention_scores.dtype - attention_scores = attention_scores.float() - attention_scores = attention_scores * query_key_layer_scaling_coeff - - attention_probs = F.softmax(attention_scores, dim=-1) - - attention_probs = attention_probs.type(dtype) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [b, sk, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) - - # # change view [sk, b * np, hn] - # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # [b, sk, np, hn] -> [b, np, sk, hn] - value_layer = value_layer.transpose(1, 2) - # [b, np, sk, hn] -> [b * np, sk, hn] - value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [b, sq, np, hn] - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - - # [b, sq, np, hn] --> [b, sq, hp] - new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, present, attention_probs) - - return outputs - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class SelfAttention(torch.nn.Module): - def __init__(self, hidden_size, num_attention_heads, - layer_id, hidden_size_per_attention_head=None, bias=True, - params_dtype=torch.float, position_encoding_2d=True, empty_init=True, max_sequence_length=2048): - if empty_init: - init_method = skip_init - else: - init_method = default_init - super(SelfAttention, self).__init__() - - self.layer_id = layer_id - self.hidden_size = hidden_size - self.hidden_size_per_partition = hidden_size - self.num_attention_heads = num_attention_heads - self.num_attention_heads_per_partition = num_attention_heads - self.position_encoding_2d = position_encoding_2d - self.rotary_emb = RotaryEmbedding( - self.hidden_size // (self.num_attention_heads * 2) - if position_encoding_2d - else self.hidden_size // self.num_attention_heads, - base=10000, - precision=torch.half, - learnable=False, - ) - - self.scale_mask_softmax = None - - if hidden_size_per_attention_head is None: - self.hidden_size_per_attention_head = hidden_size // num_attention_heads - else: - self.hidden_size_per_attention_head = hidden_size_per_attention_head - - self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head - - # Strided linear layer. - self.query_key_value = init_method( - torch.nn.Linear, - hidden_size, - 3 * self.inner_hidden_size, - bias=bias, - dtype=params_dtype, - ) - - self.dense = init_method( - torch.nn.Linear, - self.inner_hidden_size, - hidden_size, - bias=bias, - dtype=params_dtype, - ) - - self.attn = ld.attn_gpt() - head_size = hidden_size // num_attention_heads - self.head_size_aligned = (head_size + 31) // 32 * 32 - self.max_sequence_length = max_sequence_length - normal_factor = 1.0 / math.sqrt(head_size) - rotary_ndims = int(head_size * 0.5) - self.attn.create(num_attention_heads, head_size, self.head_size_aligned, - normal_factor, 'bf16', 'bf16', max_sequence_length, rotary_ndims, True) - self.layer_past_key_padded = None - self.layer_past_value_padded = None - self.past_seq_len = 0 - inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_sequence_length - t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] - - @staticmethod - def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - def split_tensor_along_last_dim(self, tensor, num_partitions, - contiguous_split_chunks=False): - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - """ - hidden_states: [batch, seq_len, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # [batch, seq_len, 3 * hidden_size] - mixed_raw_layer = self.query_key_value(hidden_states) - position_ids = position_ids.contiguous() - if self.layer_past_key_padded is None: - shape = (hidden_states.size(0), self.num_attention_heads, self.max_sequence_length, self.head_size_aligned) - self.layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) - self.layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) - if layer_past is None: - self.past_seq_len = 0 - context_layer = self.attn.exec_position(mixed_raw_layer, self.layer_past_key_padded, self.layer_past_value_padded, self.past_seq_len, attention_mask, position_ids, self.cos_cached, self.sin_cached) - present = (self.layer_past_key_padded, self.layer_past_value_padded) - attention_probs = None - self.past_seq_len += mixed_raw_layer.size(1) - - if 0: - # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] - new_tensor_shape = mixed_raw_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) - - # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] - (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) - - if self.position_encoding_2d: - q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) - k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) - cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) - # xxx - position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ - position_ids[:, 1, :].contiguous() - q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) - q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) - query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) - key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) - else: - position_ids = position_ids.transpose(0, 1) - cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) - # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] - query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) - - # [batch, seq_len, hidden_size] - context_layer, present, attention_probs = attention_fn( - self=self, - query_layer=query_layer, - key_layer=key_layer, - value_layer=value_layer, - attention_mask=attention_mask, - hidden_size_per_partition=self.hidden_size_per_partition, - layer_id=layer_id, - layer_past=layer_past, - use_cache=use_cache - ) - - output = self.dense(context_layer) - - outputs = (output, present) - - if output_attentions: - outputs += (attention_probs,) - - return outputs # output, present, attention_probs - - -class GEGLU(torch.nn.Module): - def __init__(self): - super().__init__() - self.activation_fn = F.gelu - - def forward(self, x): - # dim=-1 breaks in jit for pt<1.10 - x1, x2 = x.chunk(2, dim=(x.ndim - 1)) - return x1 * self.activation_fn(x2) - - -class GLU(torch.nn.Module): - def __init__(self, hidden_size, inner_hidden_size=None, - layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): - super(GLU, self).__init__() - if empty_init: - init_method = skip_init - else: - init_method = default_init - self.layer_id = layer_id - self.activation_func = activation_func - - # Project to 4h. - self.hidden_size = hidden_size - if inner_hidden_size is None: - inner_hidden_size = 4 * hidden_size - self.inner_hidden_size = inner_hidden_size - self.dense_h_to_4h = init_method( - torch.nn.Linear, - self.hidden_size, - self.inner_hidden_size, - bias=bias, - dtype=params_dtype, - ) - # Project back to h. - self.dense_4h_to_h = init_method( - torch.nn.Linear, - self.inner_hidden_size, - self.hidden_size, - bias=bias, - dtype=params_dtype, - ) - - def forward(self, hidden_states): - """ - hidden_states: [seq_len, batch, hidden_size] - """ - - # [seq_len, batch, inner_hidden_size] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - - intermediate_parallel = self.activation_func(intermediate_parallel) - - output = self.dense_4h_to_h(intermediate_parallel) - - return output - - -class GLMBlock(torch.nn.Module): - def __init__( - self, - hidden_size, - num_attention_heads, - layernorm_epsilon, - layer_id, - inner_hidden_size=None, - hidden_size_per_attention_head=None, - layernorm=LayerNorm, - use_bias=True, - params_dtype=torch.float, - num_layers=28, - position_encoding_2d=True, - empty_init=True, - max_sequence_length=2048 - ): - super(GLMBlock, self).__init__() - # Set output layer initialization if not provided. - - self.layer_id = layer_id - - # Layernorm on the input data. - self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - - self.position_encoding_2d = position_encoding_2d - - # Self attention. - self.attention = SelfAttention( - hidden_size, - num_attention_heads, - layer_id, - hidden_size_per_attention_head=hidden_size_per_attention_head, - bias=use_bias, - params_dtype=params_dtype, - position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init, - max_sequence_length=max_sequence_length - ) - - # Layernorm on the input data. - self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - - self.num_layers = num_layers - - # GLU - self.mlp = GLU( - hidden_size, - inner_hidden_size=inner_hidden_size, - bias=use_bias, - layer_id=layer_id, - params_dtype=params_dtype, - empty_init=empty_init - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - """ - hidden_states: [seq_len, batch, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # Layer norm at the begining of the transformer layer. - # [batch, seq_len, hidden_size] - attention_input = self.input_layernorm(hidden_states) - - # Self attention. - attention_outputs = self.attention( - attention_input, - position_ids, - attention_mask=attention_mask, - layer_id=layer_id, - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions - ) - - attention_output = attention_outputs[0] - - outputs = attention_outputs[1:] - - # Residual connection. - alpha = (2 * self.num_layers) ** 0.5 - hidden_states = attention_input * alpha + attention_output - - mlp_input = self.post_attention_layernorm(hidden_states) - - # MLP. - mlp_output = self.mlp(mlp_input) - - # Second residual connection. - output = mlp_input * alpha + mlp_output - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, device): - batch_size, seq_length = input_ids.shape - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) - attention_mask.tril_() - for i, context_length in enumerate(context_lengths): - attention_mask[i, :, :context_length] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() - - return attention_mask - - def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): - batch_size, seq_length = input_ids.shape - if use_gmasks is None: - use_gmasks = [False] * batch_size - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - if self.position_encoding_2d: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - for i, context_length in enumerate(context_lengths): - position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [torch.cat(( - torch.zeros(context_length, dtype=torch.long, device=device), - torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 - )) for context_length in context_lengths] - block_position_ids = torch.stack(block_position_ids, dim=0) - position_ids = torch.stack((position_ids, block_position_ids), dim=1) - else: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - for i, context_length in enumerate(context_lengths): - if not use_gmasks[i]: - position_ids[i, context_length:] = mask_positions[i] - - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, ChatGLMModel): - module.gradient_checkpointing = value - - -CHATGLM_6B_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general - usage and behavior. - - Parameters: - config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CHATGLM_6B_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`ChatGLM6BTokenizer`]. - See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range `[0, config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert *input_ids* indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", - CHATGLM_6B_START_DOCSTRING, -) -class ChatGLMModel(ChatGLMPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well - as a decoder, in which case a layer of cross-attention is added between - the self-attention layers, following the architecture described in [Attention is - all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, - Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the - `is_decoder` argument of the configuration set to `True`. - To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` - argument and `add_cross_attention` set to `True`; an - `encoder_hidden_states` is then expected as an input to the forward pass. - """ - - def __init__(self, config: ChatGLMConfig, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - # recording parameters - self.max_sequence_length = config.max_sequence_length - self.hidden_size = config.hidden_size - self.params_dtype = torch.half - self.num_attention_heads = config.num_attention_heads - self.vocab_size = config.vocab_size - self.num_layers = config.num_layers - self.layernorm_epsilon = config.layernorm_epsilon - self.inner_hidden_size = config.inner_hidden_size - self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads - self.position_encoding_2d = config.position_encoding_2d - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - - self.word_embeddings = init_method( - torch.nn.Embedding, - num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, - dtype=self.params_dtype - ) - self.gradient_checkpointing = False - - def get_layer(layer_id): - return GLMBlock( - self.hidden_size, - self.num_attention_heads, - self.layernorm_epsilon, - layer_id, - inner_hidden_size=self.inner_hidden_size, - hidden_size_per_attention_head=self.hidden_size_per_attention_head, - layernorm=LayerNorm, - use_bias=True, - params_dtype=self.params_dtype, - position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init, - max_sequence_length=self.max_sequence_length - ) - - self.layers = torch.nn.ModuleList( - [get_layer(layer_id) for layer_id in range(self.num_layers)] - ) - - # Final layer norm before output. - self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) - - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - # total_params = sum(p.numel() for p in self.parameters()) - # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) - # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) - - def get_input_embeddings(self): - return self.word_embeddings - - def set_input_embeddings(self, new_embeddings: torch.Tensor): - self.word_embeddings = new_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.num_attention_heads, - self.hidden_size // self.num_attention_heads - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - # past_key_values = [(v[0], v[1]) for v in past_key_values] - return past_key_values - - @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - if past_key_values is None: - if self.pre_seq_len is not None: - past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, - dtype=inputs_embeds.dtype) - else: - past_key_values = tuple([None] * len(self.layers)) - - if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - device=input_ids.device - ) - - - if position_ids is None: - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - - position_ids = self.get_position_ids( - input_ids, - mask_positions=mask_positions, - device=input_ids.device, - use_gmasks=use_gmasks - ) - - if self.pre_seq_len is not None and attention_mask is not None: - prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( - attention_mask.device) - prefix_attention_mask = (prefix_attention_mask < 0.5).bool() - attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) - - # [seq_len, batch, hidden_size] - # hidden_states = inputs_embeds.transpose(0, 1) - # xxx - hidden_states = inputs_embeds - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if attention_mask is None: - attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() - else: - attention_mask = attention_mask.to(hidden_states.device) - - for i, layer in enumerate(self.layers): - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - layer_past = past_key_values[i] - - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - position_ids, - attention_mask, - torch.tensor(i), - layer_past, - use_cache, - output_attentions - ) - else: - layer_ret = layer( - hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - layer_id=torch.tensor(i), - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions - ) - - hidden_states = layer_ret[0] - - if use_cache: - presents = presents + (layer_ret[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) - - # Final layer norm. - hidden_states = self.final_layernorm(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - - # self.hidden_size = config.hidden_size - # self.params_dtype = torch.half - # self.vocab_size = config.vocab_size - self.max_sequence_length = config.max_sequence_length - - self.position_encoding_2d = config.position_encoding_2d - - self.transformer = ChatGLMModel(config, empty_init=empty_init) - - self.lm_head = init_method( - nn.Linear, - config.hidden_size, - config.vocab_size, - bias=False, - dtype=torch.half - ) - - self.config = config - - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - if attention_mask is not None: # and attention_mask.dtype == torch.bool: - right = torch.empty((*attention_mask.shape[:3], 1), dtype=torch.float32) - right[:] = -10000.0 - attention_mask = torch.cat( - [attention_mask, right], dim=3) - new_attention_mask = attention_mask[:, :, -1:].clone() - new_attention_mask[..., -1] = 0 - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, new_attention_mask], dim=2 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id[:, 1, :] += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past: Optional[torch.Tensor] = None, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - **kwargs - ) -> dict: - batch_size, seq_length = input_ids.shape - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - - # only last token for input_ids if past is not None - if past is not None or past_key_values is not None: - last_token = input_ids[:, -1].unsqueeze(-1) - if attention_mask is not None: # and attention_mask.dtype == torch.bool: - attention_mask = attention_mask[:, :, -1:] - # else: - # attention_mask = None - if position_ids is not None: - position_ids = position_ids[..., -1:] - # else: - # context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] - # if self.position_encoding_2d: - # position_ids = torch.tensor( - # [[mask_position, seq_length - context_length] for mask_position, context_length in - # zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) - # else: - # position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, - # device=input_ids.device).unsqueeze(-1) - - if past is None: - past = past_key_values - return { - "input_ids": last_token, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask - } - else: - # if attention_mask is not None and attention_mask.dtype != torch.bool: - # logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") - # attention_mask = None - # if attention_mask is None: - # attention_mask = self.get_masks( - # input_ids, - # device=input_ids.device - # ) - # if position_ids is None: - # position_ids = self.get_position_ids( - # input_ids, - # device=input_ids.device, - # mask_positions=mask_positions, - # use_gmasks=use_gmasks - # ) - - return { - "input_ids": input_ids, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - punkts = [ - [",", ","], - ["!", "!"], - [":", ":"], - [";", ";"], - ["\?", "?"], - ] - for item in punkts: - response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) - response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) - return response - - @torch.no_grad() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if not history: - prompt = query - else: - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if not history: - prompt = query - else: - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - for outputs in self.stream_generate(**inputs, **gen_kwargs): - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - new_history = history + [(query, response)] - yield response, new_history - - @torch.no_grad() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - yield input_ids - - def quantize(self, bits: int, empty_init=False, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) - return self diff --git a/tests/script/models/chatglm-6b/pytorch_model.bin.index.json b/tests/script/models/chatglm-6b/pytorch_model.bin.index.json deleted file mode 100644 index b8ada2b..0000000 --- a/tests/script/models/chatglm-6b/pytorch_model.bin.index.json +++ /dev/null @@ -1,375 +0,0 @@ -{ - "metadata": { - "total_size": 13744473856 - }, - "weight_map": { - "lm_head.weight": "pytorch_model-00008-of-00008.bin", - "transformer.final_layernorm.bias": "pytorch_model-00007-of-00008.bin", - "transformer.final_layernorm.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.0.attention.dense.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.attention.dense.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.input_layernorm.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.mlp.dense_4h_to_h.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.attention.dense.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.attention.dense.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.input_layernorm.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.1.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin", - "transformer.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin", - "transformer.layers.10.attention.dense.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.attention.dense.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.10.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.11.attention.dense.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.11.attention.dense.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.11.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.11.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.11.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", - "transformer.layers.11.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.11.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.11.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.11.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.11.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.11.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.11.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.11.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.attention.dense.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.attention.dense.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.12.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.attention.dense.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.attention.dense.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.13.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.attention.dense.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.attention.dense.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.14.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.attention.dense.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.attention.dense.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.15.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.16.attention.dense.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.attention.dense.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", - "transformer.layers.16.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", - "transformer.layers.16.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", - "transformer.layers.16.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.16.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.attention.dense.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.attention.dense.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.17.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.attention.dense.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.attention.dense.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.18.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.attention.dense.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.attention.dense.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.19.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.2.attention.dense.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.attention.dense.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.2.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.20.attention.dense.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.attention.dense.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.20.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.20.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", - "transformer.layers.20.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", - "transformer.layers.21.attention.dense.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.attention.dense.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.21.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.attention.dense.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.attention.dense.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.22.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.attention.dense.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.attention.dense.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.23.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.attention.dense.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.attention.dense.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.24.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.attention.dense.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.attention.dense.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.25.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.25.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.25.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.25.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", - "transformer.layers.25.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", - "transformer.layers.26.attention.dense.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.attention.dense.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.attention.query_key_value.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.attention.query_key_value.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.attention.rotary_emb.inv_freq": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.input_layernorm.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.input_layernorm.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.post_attention_layernorm.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.26.post_attention_layernorm.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.attention.dense.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.attention.dense.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.attention.query_key_value.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.attention.query_key_value.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.attention.rotary_emb.inv_freq": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.input_layernorm.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.input_layernorm.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.post_attention_layernorm.bias": "pytorch_model-00007-of-00008.bin", - "transformer.layers.27.post_attention_layernorm.weight": "pytorch_model-00007-of-00008.bin", - "transformer.layers.3.attention.dense.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.attention.dense.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.3.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.attention.dense.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.attention.dense.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.4.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.attention.dense.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.attention.dense.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.5.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.attention.dense.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.attention.dense.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.6.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.6.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.6.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.6.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", - "transformer.layers.6.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", - "transformer.layers.7.attention.dense.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.attention.dense.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.7.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.attention.dense.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.attention.dense.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.8.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.attention.dense.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.attention.dense.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", - "transformer.layers.9.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", - "transformer.word_embeddings.weight": "pytorch_model-00001-of-00008.bin" - } -} diff --git a/tests/script/models/chatglm-6b/quantization.py b/tests/script/models/chatglm-6b/quantization.py deleted file mode 100644 index 6f469f6..0000000 --- a/tests/script/models/chatglm-6b/quantization.py +++ /dev/null @@ -1,201 +0,0 @@ -from torch.nn import Linear -from torch.nn.parameter import Parameter - -import bz2 -import torch -import base64 -import ctypes -from transformers.utils import logging - -from typing import List -from functools import partial - -logger = logging.get_logger(__name__) - -try: - from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up - - class Kernel: - def __init__(self, code: bytes, function_names: List[str]): - self.code = code - self._function_names = function_names - self._cmodule = LazyKernelCModule(self.code) - - for name in self._function_names: - setattr(self, name, KernelFunction(self._cmodule, name)) - - quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" - - kernels = Kernel( - bz2.decompress(base64.b64decode(quantization_code)), - [ - "int4WeightCompression", - "int4WeightExtractionFloat", - "int4WeightExtractionHalf", - "int8WeightExtractionFloat", - "int8WeightExtractionHalf", - ], - ) -except Exception as exception: - kernels = None - logger.warning("Failed to load cpm_kernels:" + str(exception)) - - -class W8A16Linear(torch.autograd.Function): - @staticmethod - def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): - ctx.inp_shape = inp.size() - ctx.weight_bit_width = weight_bit_width - out_features = quant_w.size(0) - inp = inp.contiguous().view(-1, inp.size(-1)) - weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) - ctx.weight_shape = weight.size() - output = inp.mm(weight.t()) - ctx.save_for_backward(inp, quant_w, scale_w) - return output.view(*(ctx.inp_shape[:-1] + (out_features,))) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - inp, quant_w, scale_w = ctx.saved_tensors - weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) - grad_output = grad_output.contiguous().view(-1, weight.size(0)) - grad_input = grad_output.mm(weight) - grad_weight = grad_output.t().mm(inp) - return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None - - -def compress_int4_weight(weight: torch.Tensor): # (n, m) - with torch.cuda.device(weight.device): - n, m = weight.size(0), weight.size(1) - assert m % 2 == 0 - m = m // 2 - out = torch.empty(n, m, dtype=torch.int8, device="cuda") - stream = torch.cuda.current_stream() - - gridDim = (n, 1, 1) - blockDim = (min(round_up(m, 32), 1024), 1, 1) - - kernels.int4WeightCompression( - gridDim, - blockDim, - 0, - stream, - [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], - ) - return out - - -def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): - if source_bit_width == 8: - func = kernels.int8WeightExtractionHalf - elif source_bit_width == 4: - func = kernels.int4WeightExtractionHalf - else: - assert False, "Unsupported bit-width" - - with torch.cuda.device(weight.device): - n, m = weight.size(0), weight.size(1) - out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda") - stream = torch.cuda.current_stream() - - gridDim = (n, 1, 1) - blockDim = (min(round_up(m, 32), 1024), 1, 1) - - func( - gridDim, - blockDim, - 0, - stream, - [ - ctypes.c_void_p(weight.data_ptr()), - ctypes.c_void_p(scale_list.data_ptr()), - ctypes.c_void_p(out.data_ptr()), - ctypes.c_int32(n), - ctypes.c_int32(m), - ], - ) - return out - - -class QuantizedLinear(Linear): - def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs): - super(QuantizedLinear, self).__init__(*args, **kwargs) - self.weight_bit_width = weight_bit_width - - shape = self.weight.shape - del self.weight - - if weight_tensor is None or empty_init: - self.weight = torch.empty( - shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] - ) - self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) - else: - self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() - self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) - if weight_bit_width == 4: - self.weight = compress_int4_weight(self.weight) - - self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) - if bias_tensor is not None: - self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False) - else: - self.bias = None - - def forward(self, input): - output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) - if self.bias is not None: - output = output + self.bias - return output - - -def quantize(model, weight_bit_width, empty_init=False, **kwargs): - """Replace fp16 linear with quantized linear""" - - for layer in model.layers: - layer.attention.query_key_value = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()), - bias_tensor=layer.attention.query_key_value.bias, - in_features=layer.attention.query_key_value.in_features, - out_features=layer.attention.query_key_value.out_features, - bias=True, - dtype=torch.half, - device=layer.attention.query_key_value.weight.device, - empty_init=empty_init - ) - layer.attention.dense = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()), - bias_tensor=layer.attention.dense.bias, - in_features=layer.attention.dense.in_features, - out_features=layer.attention.dense.out_features, - bias=True, - dtype=torch.half, - device=layer.attention.dense.weight.device, - empty_init=empty_init - ) - layer.mlp.dense_h_to_4h = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), - bias_tensor=layer.mlp.dense_h_to_4h.bias, - in_features=layer.mlp.dense_h_to_4h.in_features, - out_features=layer.mlp.dense_h_to_4h.out_features, - bias=True, - dtype=torch.half, - device=layer.mlp.dense_h_to_4h.weight.device, - empty_init=empty_init - ) - layer.mlp.dense_4h_to_h = QuantizedLinear( - weight_bit_width=weight_bit_width, - weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), - bias_tensor=layer.mlp.dense_4h_to_h.bias, - in_features=layer.mlp.dense_4h_to_h.in_features, - out_features=layer.mlp.dense_4h_to_h.out_features, - bias=True, - dtype=torch.half, - device=layer.mlp.dense_4h_to_h.weight.device, - empty_init=empty_init - ) - return model diff --git a/tests/script/models/chatglm-6b/test_modeling_chatglm.py b/tests/script/models/chatglm-6b/test_modeling_chatglm.py deleted file mode 100644 index 814c8bd..0000000 --- a/tests/script/models/chatglm-6b/test_modeling_chatglm.py +++ /dev/null @@ -1,245 +0,0 @@ -import datetime -import math -import unittest -import torch -import random -import time - -from transformers import AutoTokenizer, AutoModel -from transformers.testing_utils import require_torch, slow, torch_device - -from torch.profiler import profile, record_function, ProfilerActivity - -def set_random_seed(seed): - import random - - random.seed(seed) - - # pytorch RNGs - import torch - - torch.manual_seed(seed) - torch.backends.cudnn.deterministic = True - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - # numpy RNG - import numpy as np - - np.random.seed(seed) - - - -def ids_tensor(shape, vocab_size): - # Creates a random int32 tensor of the shape within the vocab size - total_dims = 1 - for dim in shape: - total_dims *= dim - - values = [] - for _ in range(total_dims): - values.append(random.randint(0, vocab_size - 1)) - - return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() - - -def get_model_and_tokenizer(): - model = AutoModel.from_pretrained("/home/luocheng/openvino/src/plugins/intel_cpu/thirdparty/llmdnn/tests/script/models/chatglm-6b", trust_remote_code=True) - model.to(torch_device, dtype=torch.bfloat16) - print(f"torch_device={torch_device}") - model.eval() - tokenizer = AutoTokenizer.from_pretrained("/home/luocheng/openvino/src/plugins/intel_cpu/thirdparty/llmdnn/tests/script/models/chatglm-6b", trust_remote_code=True) - return model, tokenizer - - -@require_torch -class ChatGLMGenerationTest(unittest.TestCase): - def get_generation_kwargs(self): - pass - - def ntest_chat(self): - print("======================test_chat") - model, tokenizer = get_model_and_tokenizer() - prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] - history = [] - set_random_seed(42) - expected_responses = [ - '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', - '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', - '清华大学创建于 1911 年。' - ] - for (prompt, expected_response) in zip(prompts, expected_responses): - response, history = model.chat(tokenizer, prompt, history=history) - print(repr(response)) - self.assertEqual(expected_response, response) - - def ntest_stream_chat(self): - print("======================test_stream_chat") - model, tokenizer = get_model_and_tokenizer() - prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] - history = [] - expected_responses = [ - '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', - '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', - '清华大学创建于 1911 年。' - ] - set_random_seed(42) - for prompt, expected_response in zip(prompts, expected_responses): - response = "" - for idx, (response, history) in enumerate(model.stream_chat(tokenizer, prompt, history=history)): - pass - print(repr(response)) - self.assertEqual(expected_response, response) - - def ntest_generation(self): - print("======================test_generation") - model, tokenizer = get_model_and_tokenizer() - sentence = "晚上睡不着怎么办" - parameters = [(False, 2048, 1), - (False, 64, 1), - (True, 2048, 1), - (True, 64, 1), - (True, 2048, 4)] - expected_out_sentences = [ - '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。\n\n3. 避免刺激性物质:避免饮用含咖啡因的饮料,如咖啡、茶和可乐,并尽可能减少饮酒。\n\n4. 放松身心:尝试进行放松的活动,如冥想、深呼吸、瑜伽或听轻柔的音乐。\n\n5. 避免在床上做其他事情:例如看电视、使用电脑或智能手机等。\n\n6. 练习放松技巧:例如渐进性肌肉松弛法、冥想或深呼吸练习。\n\n7. 寻求帮助:如果长时间都无法正常入睡,可以考虑咨询医生或专业心理医生,寻求更进一步的帮助。\n\n希望这些方法能有助于入睡。', - '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。', - '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体释放褪黑素,进而导致难以入睡。建议你在睡前一小时停止使用这些设备。\n\n3. 创建舒适的睡眠环境:确保卧室安静、黑暗、凉爽,舒适的床垫和枕头,保持卧室温度适宜,这有助于让你更容易入睡。\n\n4. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽或轻松的散步,减轻压力和焦虑,让你更容易入睡。\n\n5. 避免咖啡因和酒精:咖啡因和酒精会让大脑更加兴奋,进而干扰身体入睡过程。建议在睡前几小时避免饮用这些物质。\n\n6. 做一些安静的活动:阅读一本书、听轻柔的音乐、绣或者绘画等安静的活动,有助于自己放松身心,进而更容易入睡。\n\n如果采取以上这些方法仍然无法入睡,建议咨询医生或专业的睡眠专家,获取更好的建议和帮助。', - '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体', - '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 建立规律的睡眠时间表:尽量在同一时间入睡和起床,即使在周末和假期也要尽量保持一致。\n\n2. 创造舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,使用舒适的床垫和枕头等。\n\n3. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽、听轻柔的音乐等,缓解压力和紧张情绪。\n\n4. 避免刺激性物质:避免饮用咖啡、茶、可乐等含咖啡因的饮料,避免吸烟和饮酒等刺激性物质。\n\n5. 避免躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,就不要躺在床上翻来覆去,而是起床去做一些放松的活动,直到感到困倦为止。\n\n6. 练习放松技巧:如果感到焦虑或紧张,可以尝试进行一些放松技巧,如渐进性肌肉松弛、冥想等。\n\n7. 改善睡眠障碍:如果已经尝试了上述方法仍然无法入睡,可以考虑咨询医生,了解是否存在其他睡眠障碍问题,并接受相应的治疗。'] - for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): - set_random_seed(42) - inputs = tokenizer([sentence,], return_tensors="pt", padding=True) - inputs = inputs.to(torch_device) - print(inputs) - outputs = model.generate( - **inputs, - do_sample=do_sample, - max_length=max_length, - num_beams=num_beams - ) - print(outputs) - outputs = outputs.tolist()[0] - out_sentence = tokenizer.decode(outputs, skip_special_tokens=True) - print(out_sentence) - self.assertEqual(expected_output_sentence, out_sentence) - - def test_batch_generation(self): - print("======================test_batch_generation") - model, tokenizer = get_model_and_tokenizer() - sentences = [ - "你好", - "介绍一下清华大学" - ] - parameters = [(False, 2048, 1), - (False, 64, 1), - (True, 2048, 1), - (True, 64, 1), - (True, 2048, 4)] - expected_out_sentences = [ - ['你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', - '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回清华园。新中国成立后,清华学校更名为清华大学。\n\n清华大学是中国最顶尖的大学之一,在工程、科学、技术、经济、管理等领域都有很高的学术声誉和影响力。学校拥有世界一流的教学设施和科学研究平台,有多个学院和研究中心,包括工程学院、自然科学学院、人文学院、社会科学学院、经济管理学院、法学院、美术学院、医学院、器学院等。\n\n清华大学的本科生招生始于2000年,实行全面二孩政策后,本科生招生规模不断扩大。截至2022年,清华大学共有本科生近3万人,研究生近2万人,其中国际学生占比约为10%。清华大学的本科生教育注重通识教育和个性化培养,强调实践、创新、国际化和综合素质。'], - [ - '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', - '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回' - ], - [ - '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', - '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后闭校。1949 年 10 月开学复校,成为我国第一个社会主义大学生活了的高校。截至 2023 年,清华学校共管辖 2 个学院、13 个系,有本科专业 60 个,研究生专业 190 个。' - ], - [ - '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', - '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后' - ], - [ - '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', - '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,与北京大学、南开大学组建国立长沙临时大学,1938年迁至 昆明改名为国立西南联合大学,1946年迁回北京。新中国成立后,清华学校更名为清华大学。' - ] - ] - for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): - set_random_seed(42) - inputs = tokenizer(sentences, return_tensors="pt", padding=True) - inputs = inputs.to(torch_device) - print(inputs) - outputs = model.generate( - **inputs, - do_sample=do_sample, - max_length=max_length, - num_beams=num_beams - ) - print(outputs) - batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) - print(batch_out_sentence) - self.assertListEqual(expected_output_sentence, batch_out_sentence) - -def mytest(use_jit = False): - model, tokenizer = get_model_and_tokenizer() - - # import intel_extension_for_pytorch as ipex - # model = ipex.optimize(model, dtype=torch.bfloat16) - - - sentence = "世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?" - parameters = [(False, 2048, 1), - #(True, 2048, 1), - #(True, 2048, 4) - ] - expected_out_sentences = [ - '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。\n\n3. 避免刺激性物质:避免饮用含咖啡因的饮料,如咖啡、茶和可乐,并尽可能减少饮酒。\n\n4. 放松身心:尝试进行放松的活动,如冥想、深呼吸、瑜伽或听轻柔的音乐。\n\n5. 避免在床上做其他事情:例如看电视、使用电脑或智能手机等。\n\n6. 练习放松技巧:例如渐进性肌肉松弛法、冥想或深呼吸练习。\n\n7. 寻求帮助:如果长时间都无法正常入睡,可以考虑咨询医生或专业心理医生,寻求更进一步的帮助。\n\n希望这些方法能有助于入睡。', - '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。', - '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体释放褪黑素,进而导致难以入睡。建议你在睡前一小时停止使用这些设备。\n\n3. 创建舒适的睡眠环境:确保卧室安静、黑暗、凉爽,舒适的床垫和枕头,保持卧室温度适宜,这有助于让你更容易入睡。\n\n4. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽或轻松的散步,减轻压力和焦虑,让你更容易入睡。\n\n5. 避免咖啡因和酒精:咖啡因和酒精会让大脑更加兴奋,进而干扰身体入睡过程。建议在睡前几小时避免饮用这些物质。\n\n6. 做一些安静的活动:阅读一本书、听轻柔的音乐、绣或者绘画等安静的活动,有助于自己放松身心,进而更容易入睡。\n\n如果采取以上这些方法仍然无法入睡,建议咨询医生或专业的睡眠专家,获取更好的建议和帮助。', - '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体', - '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 建立规律的睡眠时间表:尽量在同一时间入睡和起床,即使在周末和假期也要尽量保持一致。\n\n2. 创造舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,使用舒适的床垫和枕头等。\n\n3. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽、听轻柔的音乐等,缓解压力和紧张情绪。\n\n4. 避免刺激性物质:避免饮用咖啡、茶、可乐等含咖啡因的饮料,避免吸烟和饮酒等刺激性物质。\n\n5. 避免躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,就不要躺在床上翻来覆去,而是起床去做一些放松的活动,直到感到困倦为止。\n\n6. 练习放松技巧:如果感到焦虑或紧张,可以尝试进行一些放松技巧,如渐进性肌肉松弛、冥想等。\n\n7. 改善睡眠障碍:如果已经尝试了上述方法仍然无法入睡,可以考虑咨询医生,了解是否存在其他睡眠障碍问题,并接受相应的治疗。'] - - jit_model_generated = False - f = open('result.torch.txt', 'w') - for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): - set_random_seed(42) - inputs = tokenizer([sentence,], return_tensors="pt", padding=True) - inputs = inputs.to(torch_device) - #print(inputs) - inputs.data['position_ids'] = inputs.data['position_ids'].to(torch.int32) - attn_mask = torch.zeros_like(inputs.data['attention_mask'], dtype=torch.float32) - inputs.data['attention_mask'] = attn_mask.masked_fill_(inputs.data['attention_mask'], -10000.0) - - if not jit_model_generated and use_jit: - print("generating jit model...") - with torch.no_grad(), torch.cpu.amp.autocast(): - model = torch.jit.trace(model, inputs) - model = torch.jit.freeze(model) - jit_model_generated = True - print("done") - - for repeat in range(1): - t0 = time.time() - - # with profile(activities=[ProfilerActivity.CPU], record_shapes=False) as prof: - # with record_function("model_inference"): - with torch.no_grad(), torch.cpu.amp.autocast(): - outputs = model.generate( - **inputs, - do_sample=do_sample, - max_length=max_length, - num_beams=num_beams - ) - t1 = time.time() - - #prof.export_chrome_trace("trace.json") - #print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) - - #print(outputs) - outputs = outputs.tolist()[0] - - if repeat == 0: - out_sentence = tokenizer.decode(outputs, skip_special_tokens=True) - print(f"{out_sentence}") - print(f" #tokens={len(outputs)},do_sample={do_sample},max_length={max_length},num_beams={num_beams}") - f.write(out_sentence) - - print(f" [{repeat}] ::: {(t1-t0)*1e3/len(outputs)} ms/token") - - f.close() - -# numactl -C 0-46 python ./test_modeling_chatglm.py -if __name__ == '__main__': - mytest() - #unittest.main() \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/tokenization_chatglm.py b/tests/script/models/chatglm-6b/tokenization_chatglm.py deleted file mode 100644 index 69ee85c..0000000 --- a/tests/script/models/chatglm-6b/tokenization_chatglm.py +++ /dev/null @@ -1,443 +0,0 @@ -"""Tokenization classes for ChatGLM.""" -from typing import List, Optional, Union -import os - -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.utils import logging, PaddingStrategy -from transformers.tokenization_utils_base import EncodedInput, BatchEncoding -from typing import Dict -import sentencepiece as spm -import numpy as np - -logger = logging.get_logger(__name__) - -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "THUDM/chatglm-6b": 2048, -} - - -class TextTokenizer: - def __init__(self, model_path): - self.sp = spm.SentencePieceProcessor() - self.sp.Load(model_path) - self.num_tokens = self.sp.vocab_size() - - def encode(self, text): - return self.sp.EncodeAsIds(text) - - def decode(self, ids: List[int]): - return self.sp.DecodeIds(ids) - - def tokenize(self, text): - return self.sp.EncodeAsPieces(text) - - def convert_tokens_to_string(self, tokens): - return self.sp.DecodePieces(tokens) - - def convert_tokens_to_ids(self, tokens): - return [self.sp.PieceToId(token) for token in tokens] - - def convert_token_to_id(self, token): - return self.sp.PieceToId(token) - - def convert_id_to_token(self, idx): - return self.sp.IdToPiece(idx) - - def __len__(self): - return self.num_tokens - - -class SPTokenizer: - def __init__( - self, - vocab_file, - num_image_tokens=20000, - max_blank_length=80, - byte_fallback=True, - ): - assert vocab_file is not None - self.vocab_file = vocab_file - self.num_image_tokens = num_image_tokens - self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] - self.max_blank_length = max_blank_length - self.byte_fallback = byte_fallback - self.text_tokenizer = TextTokenizer(vocab_file) - - def _get_text_tokenizer(self): - return self.text_tokenizer - - @staticmethod - def get_blank_token(length: int): - assert length >= 2 - return f"<|blank_{length}|>" - - @staticmethod - def get_tab_token(): - return f"<|tab|>" - - @property - def num_text_tokens(self): - return self.text_tokenizer.num_tokens - - @property - def num_tokens(self): - return self.num_image_tokens + self.num_text_tokens - - @staticmethod - def _encode_whitespaces(text: str, max_len: int = 80): - text = text.replace("\t", SPTokenizer.get_tab_token()) - for i in range(max_len, 1, -1): - text = text.replace(" " * i, SPTokenizer.get_blank_token(i)) - return text - - def _preprocess(self, text: str, linebreak=True, whitespaces=True): - if linebreak: - text = text.replace("\n", "") - if whitespaces: - text = self._encode_whitespaces(text, max_len=self.max_blank_length) - return text - - def encode( - self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True - ) -> List[int]: - """ - @param text: Text to encode. - @param linebreak: Whether to encode newline (\n) in text. - @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. - @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. - @param add_dummy_prefix: Whether to add dummy blank space in the beginning. - """ - text = self._preprocess(text, linebreak, whitespaces) - if not add_dummy_prefix: - text = "" + text - tmp = self._get_text_tokenizer().encode(text) - tokens = [x + self.num_image_tokens for x in tmp] - return tokens if add_dummy_prefix else tokens[2:] - - def postprocess(self, text): - text = text.replace("", "\n") - text = text.replace(SPTokenizer.get_tab_token(), "\t") - for i in range(2, self.max_blank_length + 1): - text = text.replace(self.get_blank_token(i), " " * i) - return text - - def decode(self, text_ids: List[int]) -> str: - ids = [int(_id) - self.num_image_tokens for _id in text_ids] - ids = [_id for _id in ids if _id >= 0] - text = self._get_text_tokenizer().decode(ids) - text = self.postprocess(text) - return text - - def decode_tokens(self, tokens: List[str]) -> str: - text = self._get_text_tokenizer().convert_tokens_to_string(tokens) - text = self.postprocess(text) - return text - - def tokenize( - self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True - ) -> List[str]: - """ - @param text: Text to encode. - @param linebreak: Whether to encode newline (\n) in text. - @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. - @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. - @param add_dummy_prefix: Whether to add dummy blank space in the beginning. - """ - text = self._preprocess(text, linebreak, whitespaces) - if not add_dummy_prefix: - text = "" + text - tokens = self._get_text_tokenizer().tokenize(text) - return tokens if add_dummy_prefix else tokens[2:] - - def __getitem__(self, x: Union[int, str]): - if isinstance(x, int): - if x < self.num_image_tokens: - return "".format(x) - else: - return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens) - elif isinstance(x, str): - if x.startswith("") and x[7:-1].isdigit(): - return int(x[7:-1]) - else: - return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens - else: - raise ValueError("The key should be str or int.") - - -class ChatGLMTokenizer(PreTrainedTokenizer): - """ - Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. - - Args: - vocab_file (`str`): - Path to the vocabulary file. - """ - - vocab_files_names = {"vocab_file": "ice_text.model"} - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask", "position_ids"] - - def __init__( - self, - vocab_file, - do_lower_case=False, - remove_space=False, - bos_token='', - eos_token='', - end_token='', - mask_token='[MASK]', - gmask_token='[gMASK]', - padding_side="left", - pad_token="", - unk_token="", - num_image_tokens=20000, - **kwargs - ) -> None: - super().__init__( - do_lower_case=do_lower_case, - remove_space=remove_space, - padding_side=padding_side, - bos_token=bos_token, - eos_token=eos_token, - end_token=end_token, - mask_token=mask_token, - gmask_token=gmask_token, - pad_token=pad_token, - unk_token=unk_token, - num_image_tokens=num_image_tokens, - **kwargs - ) - - self.do_lower_case = do_lower_case - self.remove_space = remove_space - self.vocab_file = vocab_file - - self.bos_token = bos_token - self.eos_token = eos_token - self.end_token = end_token - self.mask_token = mask_token - self.gmask_token = gmask_token - - self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens) - - """ Initialisation """ - - @property - def gmask_token_id(self) -> Optional[int]: - if self.gmask_token is None: - return None - return self.convert_tokens_to_ids(self.gmask_token) - - @property - def end_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been - set. - """ - if self.end_token is None: - return None - return self.convert_tokens_to_ids(self.end_token) - - @property - def vocab_size(self): - """ Returns vocab size """ - return self.sp_tokenizer.num_tokens - - def get_vocab(self): - """ Returns vocab as a dict """ - vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} - vocab.update(self.added_tokens_encoder) - return vocab - - def preprocess_text(self, inputs): - if self.remove_space: - outputs = " ".join(inputs.strip().split()) - else: - outputs = inputs - - if self.do_lower_case: - outputs = outputs.lower() - - return outputs - - def _tokenize(self, text, **kwargs): - """ Returns a tokenized string. """ - text = self.preprocess_text(text) - - seq = self.sp_tokenizer.tokenize(text) - - return seq - - def convert_tokens_to_string(self, tokens: List[str]) -> str: - return self.sp_tokenizer.decode_tokens(tokens) - - def _decode( - self, - token_ids: Union[int, List[int]], - **kwargs - ) -> str: - if isinstance(token_ids, int): - token_ids = [token_ids] - if len(token_ids) == 0: - return "" - if self.pad_token_id in token_ids: # remove pad - token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) - return super()._decode(token_ids, **kwargs) - - def _convert_token_to_id(self, token): - """ Converts a token (str) in an id using the vocab. """ - return self.sp_tokenizer[token] - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - return self.sp_tokenizer[index] - - def save_vocabulary(self, save_directory, filename_prefix=None): - """ - Save the vocabulary and special tokens file to a directory. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. - filename_prefix (`str`, *optional*): - An optional prefix to add to the named of the saved files. - - Returns: - `Tuple(str)`: Paths to the files saved. - """ - if os.path.isdir(save_directory): - vocab_file = os.path.join( - save_directory, self.vocab_files_names["vocab_file"] - ) - else: - vocab_file = save_directory - - with open(self.vocab_file, 'rb') as fin: - proto_str = fin.read() - - with open(vocab_file, "wb") as writer: - writer.write(proto_str) - - return (vocab_file,) - - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and - adding special tokens. A BERT sequence has the following format: - - - single sequence: `[CLS] X [SEP]` - - pair of sequences: `[CLS] A [SEP] B [SEP]` - - Args: - token_ids_0 (`List[int]`): - List of IDs to which the special tokens will be added. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. - """ - gmask_id = self.sp_tokenizer[self.gmask_token] - eos_id = self.sp_tokenizer[self.eos_token] - token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] - if token_ids_1 is not None: - token_ids_0 = token_ids_0 + token_ids_1 + [eos_id] - return token_ids_0 - - def _pad( - self, - encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], - max_length: Optional[int] = None, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - ) -> dict: - """ - Pad encoded inputs (on left/right and up to predefined length or max length in the batch) - - Args: - encoded_inputs: - Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). - max_length: maximum length of the returned list and optionally padding length (see below). - Will truncate by taking into account the special tokens. - padding_strategy: PaddingStrategy to use for padding. - - - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: - - - 'left': pads on the left of the sequences - - 'right': pads on the right of the sequences - pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. - This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability - `>= 7.5` (Volta). - return_attention_mask: - (optional) Set to False to avoid returning attention mask (default: set to model specifics) - """ - # Load from model defaults - bos_token_id = self.sp_tokenizer[self.bos_token] - mask_token_id = self.sp_tokenizer[self.mask_token] - gmask_token_id = self.sp_tokenizer[self.gmask_token] - assert self.padding_side == "left" - - required_input = encoded_inputs[self.model_input_names[0]] - seq_length = len(required_input) - - if padding_strategy == PaddingStrategy.LONGEST: - max_length = len(required_input) - - if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length - - # Initialize attention mask if not present. - if max_length is not None: - if "attention_mask" not in encoded_inputs: - if bos_token_id in required_input: - context_length = required_input.index(bos_token_id) - else: - context_length = seq_length - attention_mask = np.ones((1, seq_length, seq_length)) - attention_mask = np.tril(attention_mask) - attention_mask[:, :, :context_length] = 1 - attention_mask = np.bool_(attention_mask < 0.5) - encoded_inputs["attention_mask"] = attention_mask - - if "position_ids" not in encoded_inputs: - if bos_token_id in required_input: - context_length = required_input.index(bos_token_id) - else: - context_length = seq_length - position_ids = np.arange(seq_length, dtype=np.int64) - mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id - if mask_token in required_input: - mask_position = required_input.index(mask_token) - position_ids[context_length:] = mask_position - block_position_ids = np.concatenate( - [np.zeros(context_length, dtype=np.int64), - np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) - encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) - - if needs_to_be_padded: - difference = max_length - len(required_input) - - if "attention_mask" in encoded_inputs: - encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], - pad_width=[(0, 0), (difference, 0), (difference, 0)], - mode='constant', constant_values=True) - if "token_type_ids" in encoded_inputs: - encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ - "token_type_ids" - ] - if "special_tokens_mask" in encoded_inputs: - encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] - if "position_ids" in encoded_inputs: - encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], - pad_width=[(0, 0), (difference, 0)]) - encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input - - return encoded_inputs diff --git a/tests/script/models/chatglm-6b/tokenizer_config.json b/tests/script/models/chatglm-6b/tokenizer_config.json deleted file mode 100644 index f8221e0..0000000 --- a/tests/script/models/chatglm-6b/tokenizer_config.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "name_or_path": "THUDM/chatglm-6b", - "bos_token": "", - "eos_token": "", - "end_token": "", - "gmask_token": "[gMASK]", - "mask_token": "[MASK]", - "pad_token": "", - "unk_token": "", - "remove_space": false, - "do_lower_case": false, - "tokenizer_class": "ChatGLMTokenizer", - "num_image_tokens": 0, - "auto_map": { - "AutoTokenizer": [ - "tokenization_chatglm.ChatGLMTokenizer", - null - ] - } -} diff --git a/tests/script/requirements.txt b/tests/script/requirements.txt index b526b7c..8ec05ac 100644 --- a/tests/script/requirements.txt +++ b/tests/script/requirements.txt @@ -2,4 +2,3 @@ numpy==1.24.2 torch==2.0.1+cpu pytest -ninja \ No newline at end of file From 27f324a99e3e4ead9a4775c7bffe02faad48b071 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 13 Jul 2023 01:36:00 +0800 Subject: [PATCH 27/54] remove warning --- include/llm_plain_tensor.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/llm_plain_tensor.hpp b/include/llm_plain_tensor.hpp index bfc0faa..441bd5d 100644 --- a/include/llm_plain_tensor.hpp +++ b/include/llm_plain_tensor.hpp @@ -32,11 +32,11 @@ struct plain_tensor_base { } size_t size(int i) { - assert(i < m_rank); + assert(static_cast(i) < m_rank); return m_dims[i]; } size_t stride(int i) { - assert(i < m_rank); + assert(static_cast(i) < m_rank); return m_strides[i]; } }; From 03b7b37f4dd5cf33a6ecb4f24564546292a2240c Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 13 Jul 2023 17:39:31 +0800 Subject: [PATCH 28/54] remove share_ptr from interface --- include/llm_emb_gpt.hpp | 4 +++- include/llm_mha_gpt.hpp | 32 +++---------------------- include/llm_plain_tensor.hpp | 41 +++++++++++++++++---------------- src/common/llm_plain_tensor.cpp | 16 ++++++------- src/emb_gpt_api.cpp | 4 ++++ src/emb_gpt_avx512.cpp | 4 ++-- src/emb_gpt_avx512.hpp | 2 +- src/mha_gpt_amx.cpp | 4 ++-- src/mha_gpt_amx.hpp | 2 +- src/mha_gpt_api.cpp | 4 ++++ 10 files changed, 49 insertions(+), 64 deletions(-) diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp index dfc4947..de6ce90 100644 --- a/include/llm_emb_gpt.hpp +++ b/include/llm_emb_gpt.hpp @@ -45,15 +45,17 @@ class emb_gpt { }; emb_gpt(); + ~emb_gpt(); bool create(const create_param& param); void exec(const exec_param& param); struct impl { + virtual ~impl() {} virtual bool create(const create_param& param) = 0; virtual void exec(const exec_param& param) = 0; }; protected: - std::shared_ptr _impl; + impl* _impl; }; } diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp index c905fbe..e38d6c4 100644 --- a/include/llm_mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -12,34 +12,6 @@ namespace llmdnn { -// pattern is: -// query:[batch, num_heads, query_seq_len, head_size] key:[batch, num_heads, key_seq_len, head_size] -// \ | -// \ Transpose0: [batch, num_heads, head_size, key_seq_len] -// \ / -// \ / -// \ / -// MatMul0: [batch, num_heads, query_seq_len, key_seq_len] -// | -// | norm_factor(const): [1] -// | / -// Multiply: [batch, num_heads, query_seq_len, key_seq_len] -// | -// | causal_mask: [1, 1, query_seq_len, key_seq_len] -// | / -// Select(only for 1x300): [batch, num_heads, query_seq_len, key_seq_len] -// | -// | attention_mask:[batch, 1, 1, key_seq_len] -// | / -// Add: [batch, num_heads, query_seq_len, key_seq_len] -// | -// SoftMax: [batch, num_heads, query_seq_len, key_seq_len] -// | -// \ value:[batch, num_heads, key_seq_len, head_size] -// \ / -// MatMul1: [batch, num_heads, query_seq_len, head_size] -// | -// Transpose1(only for 1x300): [batch, query_seq_len, num_heads * head_size] class mha_gpt { public: struct create_param { @@ -75,15 +47,17 @@ class mha_gpt { }; mha_gpt(); + ~mha_gpt(); bool create(const create_param& param); void exec(const exec_param& param); struct impl { + virtual ~impl() {} virtual bool create(const create_param& param) = 0; virtual void exec(const exec_param& param) = 0; }; protected: - std::shared_ptr _impl; + impl* _impl; }; } diff --git a/include/llm_plain_tensor.hpp b/include/llm_plain_tensor.hpp index 441bd5d..35669c4 100644 --- a/include/llm_plain_tensor.hpp +++ b/include/llm_plain_tensor.hpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace llmdnn { @@ -20,15 +21,12 @@ struct plain_tensor_base { size_t m_dims[PLAINTENSOR_RANK_MAX]; size_t m_rank; - std::shared_ptr m_ptr; + void* m_ptr = nullptr; size_t m_capacity = 0; size_t m_element_size = 1; - uint8_t* batched_ptr_buff[8]; - std::vector batched_ptr_backup; - operator bool() { - return static_cast(m_ptr); + return m_ptr != nullptr; } size_t size(int i) { @@ -44,6 +42,21 @@ struct plain_tensor_base { struct plain_tensor : public plain_tensor_base { plain_tensor(); ~plain_tensor(); + plain_tensor(const plain_tensor&) = delete; + plain_tensor& operator = (const plain_tensor&) = delete; + plain_tensor(plain_tensor&& t) { + memcpy(reinterpret_cast(this), &t, sizeof(plain_tensor_base)); + t.m_capacity = 0; + t.m_ptr = nullptr; + } + plain_tensor& operator == (plain_tensor&& t) { + if (m_capacity && m_ptr) + free(m_ptr); + memcpy(reinterpret_cast(this), &t, sizeof(plain_tensor_base)); + t.m_capacity = 0; + t.m_ptr = nullptr; + return *this; + } struct tensor_index { int start; @@ -116,7 +129,7 @@ struct plain_tensor : public plain_tensor_base { template DT* data() const { - return reinterpret_cast(m_ptr.get()); + return reinterpret_cast(m_ptr); } template @@ -127,7 +140,7 @@ struct plain_tensor : public plain_tensor_base { auto coordinate = (it != index.end()) ? (*it++) : 0; off += stride * coordinate; } - return *reinterpret_cast(reinterpret_cast(m_ptr.get()) + off); + return *reinterpret_cast(reinterpret_cast(m_ptr) + off); } template @@ -136,18 +149,6 @@ struct plain_tensor : public plain_tensor_base { } void assert_dims(const std::initializer_list& expect_dims) const; - uint8_t** get_batched_ptrs() { - uint8_t** ret_ptrs = batched_ptr_buff; - auto batch_size = m_dims[0]; - if (batch_size > sizeof(batched_ptr_buff) / sizeof(batched_ptr_buff[0])) { - batched_ptr_backup.resize(batch_size); - ret_ptrs = &batched_ptr_backup[0]; - } - for (size_t b = 0; b < batch_size; b++) { - ret_ptrs[b] = &at({b}); - } - return ret_ptrs; - } template std::string repr(int max_total_lines = 16, int lines_per_row = 1) const { @@ -175,7 +176,7 @@ struct plain_tensor : public plain_tensor_base { int cur_line_elecnt = 0; int cur_row_elecnt = 0; size_t i; - auto* p = reinterpret_cast(m_ptr.get()); + auto* p = reinterpret_cast(m_ptr); for (i = 0; i < sz && max_total_lines > 0; i++) { if ((i % last_dim_size) == 0) { ss << row_id << ":\t\t"; diff --git a/src/common/llm_plain_tensor.cpp b/src/common/llm_plain_tensor.cpp index c322719..fde69a8 100644 --- a/src/common/llm_plain_tensor.cpp +++ b/src/common/llm_plain_tensor.cpp @@ -18,6 +18,8 @@ plain_tensor::plain_tensor() { } plain_tensor::~plain_tensor() { + if (m_capacity && m_ptr) + free(m_ptr); } plain_tensor plain_tensor::index(const std::initializer_list& indices) const { @@ -43,7 +45,7 @@ plain_tensor plain_tensor::index(const std::initializer_list& indi i_src++; } sub_tensor.m_rank = i_dst; // index may imply squeeze - sub_tensor.m_ptr = std::shared_ptr(reinterpret_cast(m_ptr.get()) + off, [](void*) {}); + sub_tensor.m_ptr = reinterpret_cast(m_ptr) + off; return sub_tensor; } @@ -61,8 +63,8 @@ plain_tensor plain_tensor::slice(int axis, int start, int end) const { sub_tensor.m_dims[axis] = end - start; auto off = start * m_strides[axis]; - auto* data = reinterpret_cast(m_ptr.get()) + off; - sub_tensor.m_ptr = std::shared_ptr(reinterpret_cast(data), [](void*) {}); + auto* data = reinterpret_cast(m_ptr) + off; + sub_tensor.m_ptr = reinterpret_cast(data); return sub_tensor; } @@ -99,7 +101,7 @@ plain_tensor plain_tensor::reshape(const std::initializer_list& target_s plain_tensor new_tensor_view; assert(is_dense()); //assert(shape_size(target_shape) == shape_size(m_dims)); - new_tensor_view.resize(std::vector(target_shape), m_ptr.get(), m_element_size); + new_tensor_view.resize(std::vector(target_shape), m_ptr, m_element_size); return new_tensor_view; } @@ -135,15 +137,13 @@ void plain_tensor::resize(const std::vector& new_dims, void* data, size_ if (!data) { auto capacity_new = m_strides[0] * m_dims[0]; if (capacity_new > m_capacity) { - m_ptr = std::shared_ptr(aligned_alloc(64, capacity_new), [](void* p) { - ::free(p); - }); + m_ptr = aligned_alloc(64, capacity_new); m_capacity = capacity_new; } } else { // m_capacity is zero to indicate that we don't own the memory m_capacity = 0; - m_ptr = std::shared_ptr(reinterpret_cast(data), [](void*) {}); + m_ptr = reinterpret_cast(data); } } diff --git a/src/emb_gpt_api.cpp b/src/emb_gpt_api.cpp index b8a7b84..96b2d53 100644 --- a/src/emb_gpt_api.cpp +++ b/src/emb_gpt_api.cpp @@ -13,6 +13,10 @@ namespace llmdnn { emb_gpt::emb_gpt(): _impl(new_impl_avx512()) { } +emb_gpt::~emb_gpt() { + delete _impl; +} + bool emb_gpt::create(const create_param& param) { return _impl->create(param); } diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index b290da2..d03a49b 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -217,8 +217,8 @@ void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { } } -std::shared_ptr new_impl_avx512() { - return std::make_shared(); +emb_gpt::impl* new_impl_avx512() { + return new emb_gpt_impl_avx512(); } } \ No newline at end of file diff --git a/src/emb_gpt_avx512.hpp b/src/emb_gpt_avx512.hpp index 62e3bea..9a92742 100644 --- a/src/emb_gpt_avx512.hpp +++ b/src/emb_gpt_avx512.hpp @@ -12,6 +12,6 @@ namespace llmdnn { -std::shared_ptr new_impl_avx512(); +emb_gpt::impl* new_impl_avx512(); } diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 4b052c2..bc34fed 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -445,8 +445,8 @@ void mha_gpt_impl_amx::exec(const mha_gpt::exec_param& param) { } } -std::shared_ptr new_impl_amx() { - return std::make_shared(); +mha_gpt::impl* new_impl_amx() { + return new mha_gpt_impl_amx(); } } \ No newline at end of file diff --git a/src/mha_gpt_amx.hpp b/src/mha_gpt_amx.hpp index 9409af9..c012ec2 100644 --- a/src/mha_gpt_amx.hpp +++ b/src/mha_gpt_amx.hpp @@ -12,6 +12,6 @@ namespace llmdnn { -std::shared_ptr new_impl_amx(); +mha_gpt::impl* new_impl_amx(); } diff --git a/src/mha_gpt_api.cpp b/src/mha_gpt_api.cpp index e702e8e..cf67cf3 100644 --- a/src/mha_gpt_api.cpp +++ b/src/mha_gpt_api.cpp @@ -13,6 +13,10 @@ namespace llmdnn { mha_gpt::mha_gpt(): _impl(new_impl_amx()) { } +mha_gpt::~mha_gpt() { + delete _impl; +} + bool mha_gpt::create(const create_param& param) { return _impl->create(param); } From 106402c58529e3fa6abc4071fa7654f719fa29d9 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 14 Jul 2023 04:25:06 +0800 Subject: [PATCH 29/54] refactor: emb/mha use tensor as input parameter --- include/llm_emb_gpt.hpp | 59 +- include/llm_mha_gpt.hpp | 57 +- .../{llm_plain_tensor.hpp => llm_tensor.hpp} | 130 ++-- .../{llm_plain_tensor.cpp => tensor.cpp} | 39 +- src/emb_gpt_api.cpp | 26 +- src/emb_gpt_avx512.cpp | 271 +++------ src/emb_gpt_avx512.hpp | 13 +- src/mha_gpt_amx.cpp | 560 +++++++----------- src/mha_gpt_api.cpp | 8 +- src/utility_kernel_avx512.hpp | 17 + tests/script/ext/attn_gpt.cpp | 249 -------- tests/script/ext/emb_gpt.cpp | 192 ++---- tests/script/ext/mha_gpt.cpp | 153 +---- tests/script/ext/module.cpp | 1 - tests/script/ext/module.hpp | 1 - tests/script/ext/setup.py | 2 +- tests/script/test_attn_chatglm.py | 452 -------------- tests/script/test_mha_bloom.py | 24 +- tests/script/test_mha_chatglm.py | 259 -------- tests/script/test_mha_gpt.py | 84 +-- tests/script/test_rotary_pastkv.py | 49 +- tests/script/test_rotary_pastkv_chatglm.py | 50 +- 22 files changed, 581 insertions(+), 2115 deletions(-) rename include/{llm_plain_tensor.hpp => llm_tensor.hpp} (68%) rename src/common/{llm_plain_tensor.cpp => tensor.cpp} (82%) delete mode 100644 tests/script/ext/attn_gpt.cpp delete mode 100644 tests/script/test_attn_chatglm.py delete mode 100644 tests/script/test_mha_chatglm.py diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp index de6ce90..c642a2a 100644 --- a/include/llm_emb_gpt.hpp +++ b/include/llm_emb_gpt.hpp @@ -8,54 +8,21 @@ #include #include #include "llm_types.hpp" +#include "llm_tensor.hpp" namespace llmdnn { -class emb_gpt { -public: - struct create_param { - size_t num_heads; - size_t head_size; - size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv - // supported (qkv, dst): (bf16, bf16) - data_type_t qkv_precision; - data_type_t dst_precision; - size_t rotary_dims; - bool use_position2d; // chatglm true, other false - }; - struct exec_param { - size_t batch; - size_t query_seq_len; - size_t past_seq_len; - uint8_t* q; // shape: [batch, query_seq_len, hidden size], inner stride is ldq - uint8_t* k; // shape: [batch, query_seq_len, hidden size], inner stride is ldk - uint8_t* v; // shape: [batch, query_seq_len, hidden size], inner stride is ldv - size_t ldq; // inner stride of q - size_t ldk; // inner stride of k - size_t ldv; // inner stride of v - uint8_t* query_dst; // rotary embbeding dst - uint8_t** layer_past_key_src; // past key src - uint8_t** layer_past_value_src; // past value src - uint8_t** layer_past_key_dst; // past key dst, if layer_past_key_src!=layer_past_key_dst, will copy layer_past_key_src to layer_past_key_dst - uint8_t** layer_past_value_dst; // past value dst, if layer_past_value!=layer_past_value_dst, will copy layer_past_value to layer_past_value_dst - float* cos; // cos lookup table, shape: [max_seq_len, rotary_dims] - float* sin; // sin lookup table, shape: [max_seq_len, rotary_dims] - int* position2d_ids; // shape: [batch, 2, query_seq_len] - size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer - }; - - emb_gpt(); - ~emb_gpt(); - bool create(const create_param& param); - void exec(const exec_param& param); - - struct impl { - virtual ~impl() {} - virtual bool create(const create_param& param) = 0; - virtual void exec(const exec_param& param) = 0; - }; -protected: - impl* _impl; -}; +void emb_gpt(const tensor& q_src, // q shape: [batch, query_seq_len, head_num, head_size] + const tensor& k_src, // k shape: [batch, query_seq_len, head_num, head_size] + const tensor& v_src, // v shape: [batch, query_seq_len, head_num, head_size] + const tensor& k_past, // k_past shape: [batch, num_heads, past_seq_len, head_size] + const tensor& v_past, // v_past shape: [batch, num_heads, past_seq_len, head_size] + const tensor& q_dst, // q_dst, shape: [batch, num_heads, query_seq_len, head_size] + const tensor& k_dst, // k_past shape: [batch, num_heads, query_seq_len+past_seq_len, head_size] + // if k_past!=k_past_dst, will copy k_past to k_past_dst + const tensor& v_dst, // v_past shape: [batch, num_heads, query_seq_len+past_seq_len, head_size] + const tensor& cos, // cos lookup table, shape: [1, 1, max_seq_len, rotary_dims] + const tensor& sin, // sin lookup table, shape: [1, 1, max_seq_len, rotary_dims] + const tensor& position2d_ids); // shape: [batch, 2, query_seq_len] } diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp index e38d6c4..9cc98f5 100644 --- a/include/llm_mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -8,53 +8,36 @@ #include #include #include "llm_types.hpp" -#include "llm_plain_tensor.hpp" +#include "llm_tensor.hpp" namespace llmdnn { class mha_gpt { public: - struct create_param { - size_t num_heads; - size_t head_size; - size_t max_seq_len; // max seq length for computing the size of matmul tmp result - float normal_factor; - // supported (qkv, dst): (bf16, bf16), (s8, s8) - data_type_t qkv_precision; - data_type_t dst_precision; - bool is_bloom; // for bloom mha - }; - struct exec_param { - size_t batch; - size_t query_seq_len; - size_t key_seq_len; - bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. - plain_tensor q; // q buffer, shape: [batch, num_heads, query_seq_len, head_size] - plain_tensor k; // k buffer, shape: [batch, num_heads, key_seq_len, head_size] - plain_tensor v; // v buffer, shape: [batch, num_heads, value_seq_len, head_size] - plain_tensor attention_mask; // attention mask, shape: - // [batch, 1, 1, key_seq_len], when is_causal_in_attention is false - // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true - plain_tensor attn_output; // output, compact, shape: [batch, query_seq_len, num_heads * head_size] - plain_tensor alibi; // only is_bloom is true will use, shape: [batch, num_heads, 1, key_seq_len] - // expected quant schema: - // q,k,v use per tensor quant, attn_output may use per tensor/channel quant - float q_dequant; - float k_dequant; - float v_dequant; - float qk_quant; - std::vector qkv_quant; // size==1 per tensor, size==head_size per channel - }; - mha_gpt(); ~mha_gpt(); - bool create(const create_param& param); - void exec(const exec_param& param); + + void exec(const tensor& q, // q shape: [batch, num_heads, query_seq_len, head_size] + const tensor& k, // k shape: [batch, num_heads, key_seq_len, head_size] + const tensor& v, // v shape: [batch, num_heads, value_seq_len, head_size] + const tensor& output, // output, compact, shape: [batch, query_seq_len, num_heads * head_size] + const tensor& attn_mask, // attention mask[opt], shape: + // [batch, 1, 1, key_seq_len], + // [batch, 1, query_seq_len, key_seq_len] + const tensor& alibi, // alibi[opt] shape: [batch, num_heads, 1, key_seq_len] + float normal_factor, + bool use_causal_mask = false);// add causal mask struct impl { virtual ~impl() {} - virtual bool create(const create_param& param) = 0; - virtual void exec(const exec_param& param) = 0; + virtual void exec(const tensor& q, + const tensor& k, + const tensor& v, + const tensor& output, + const tensor& attn_mask, + const tensor& alibi, + float normal_factor, + bool use_causal_mask = false) = 0; }; protected: impl* _impl; diff --git a/include/llm_plain_tensor.hpp b/include/llm_tensor.hpp similarity index 68% rename from include/llm_plain_tensor.hpp rename to include/llm_tensor.hpp index 35669c4..10671d4 100644 --- a/include/llm_plain_tensor.hpp +++ b/include/llm_tensor.hpp @@ -13,50 +13,86 @@ #include #include +#include "llm_types.hpp" + +// forward declaration +namespace ov { +class bfloat16; +}; + namespace llmdnn { -#define PLAINTENSOR_RANK_MAX 8 -struct plain_tensor_base { - size_t m_strides[PLAINTENSOR_RANK_MAX]; - size_t m_dims[PLAINTENSOR_RANK_MAX]; - size_t m_rank; +template +struct precision_of { + static constexpr data_type_t value = dnnl_data_type_undef; +}; - void* m_ptr = nullptr; - size_t m_capacity = 0; - size_t m_element_size = 1; +template <> +struct precision_of { + static constexpr data_type_t value = dnnl_f32; +}; - operator bool() { - return m_ptr != nullptr; - } +template <> +struct precision_of { + static constexpr data_type_t value = dnnl_s32; +}; - size_t size(int i) { - assert(static_cast(i) < m_rank); - return m_dims[i]; - } - size_t stride(int i) { - assert(static_cast(i) < m_rank); - return m_strides[i]; - } +template <> +struct precision_of { + static constexpr data_type_t value = dnnl_bf16; +}; + +template <> +struct precision_of { + static constexpr data_type_t value = dnnl_u8; +}; + +template <> +struct precision_of { + static constexpr data_type_t value = dnnl_s8; }; -struct plain_tensor : public plain_tensor_base { - plain_tensor(); - ~plain_tensor(); - plain_tensor(const plain_tensor&) = delete; - plain_tensor& operator = (const plain_tensor&) = delete; - plain_tensor(plain_tensor&& t) { - memcpy(reinterpret_cast(this), &t, sizeof(plain_tensor_base)); + +#define TENSOR_RANK_MAX 8 +struct tensor { + size_t m_strides[TENSOR_RANK_MAX]; + size_t m_dims[TENSOR_RANK_MAX]; + size_t m_rank = 0; + + void* m_ptr = nullptr; + size_t m_capacity = 0; // 0 means not own m_ptr + size_t m_element_size = 0; + data_type_t m_dtype = dnnl_data_type_undef; + + tensor(); + ~tensor(); + tensor(const tensor&) = delete; + tensor& operator = (const tensor&) = delete; + tensor(tensor&& t) { + memcpy(reinterpret_cast(this), &t, sizeof(*this)); t.m_capacity = 0; t.m_ptr = nullptr; } - plain_tensor& operator == (plain_tensor&& t) { + tensor& operator = (tensor&& t) { if (m_capacity && m_ptr) free(m_ptr); - memcpy(reinterpret_cast(this), &t, sizeof(plain_tensor_base)); + memcpy(reinterpret_cast(this), &t, sizeof(*this)); t.m_capacity = 0; t.m_ptr = nullptr; return *this; } + operator bool() const { + return m_ptr != nullptr; + } + + size_t size(int i) const { + assert(static_cast(i) < m_rank); + return m_dims[i]; + } + size_t stride(int i) const { + assert(static_cast(i) < m_rank); + return m_strides[i]; + } struct tensor_index { int start; @@ -93,39 +129,31 @@ struct plain_tensor : public plain_tensor_base { } }; - plain_tensor index(const std::initializer_list& indices) const; + tensor index(const std::initializer_list& indices) const; // slice: return a sub-view (w/o ownership/refcount to original data) - plain_tensor slice(int axis, int start, int end) const; + tensor slice(int axis, int start, int end) const; bool is_dense() const; - /* - suppose current shape is [a0,a1,...,am] - and target shape is [b0,b1,...,bn] - reshape is only valid when (a0*a1*...*am) == (b0*b1*...*bn) <======= (A) - - uniform a tensor's shape into groups from last to first, the dimension is merged - into current group if the subtensor in the group is still dense after merge. - otherwise a new group is formed. - - then reshape is performed on group basis, the check (A) is performed on group bases. - which means any reshape inside the group is OK, but not across the group boundary. + tensor reshape(const std::initializer_list& target_shape) const; - this can be done in one-loop, while group is forming, and checks are performed. + tensor permute(const std::initializer_list& order) const; - simplified form is when whole tensor is dense - */ - plain_tensor reshape(const std::initializer_list& target_shape) const; - - plain_tensor permute(const std::initializer_list& order) const; + template + void resize(const size_t* new_dims, size_t dim_num, DT* data = nullptr) { + resize(new_dims, dim_num, data, sizeof(DT), precision_of
::value); + } template void resize(const std::vector& new_dims, DT* data = nullptr) { - resize(new_dims, data, sizeof(DT)); + resize(new_dims.data(), new_dims.size(), data); } - void resize(const std::vector& new_dims, void* data, size_t element_size); + void resize(const size_t* new_dims, size_t dim_num, void* data, size_t element_size, data_type_t dtype); + void resize(const std::vector& new_dims, void* data, size_t element_size, data_type_t dtype) { + resize(new_dims.data(), new_dims.size(), data, element_size, dtype); + } template DT* data() const { @@ -213,11 +241,11 @@ struct plain_tensor : public plain_tensor_base { } template - friend std::ostream& operator<<(std::ostream& os, const plain_tensor& dt); + friend std::ostream& operator<<(std::ostream& os, const tensor& dt); }; template -std::ostream& operator<<(std::ostream& os, const plain_tensor& dt) { +std::ostream& operator<<(std::ostream& os, const tensor& dt) { os << dt.repr(); return os; } diff --git a/src/common/llm_plain_tensor.cpp b/src/common/tensor.cpp similarity index 82% rename from src/common/llm_plain_tensor.cpp rename to src/common/tensor.cpp index fde69a8..b8e4258 100644 --- a/src/common/llm_plain_tensor.cpp +++ b/src/common/tensor.cpp @@ -10,20 +10,20 @@ #include #include "bf16.hpp" -#include "llm_plain_tensor.hpp" +#include "llm_tensor.hpp" namespace llmdnn { -plain_tensor::plain_tensor() { +tensor::tensor() { } -plain_tensor::~plain_tensor() { +tensor::~tensor() { if (m_capacity && m_ptr) free(m_ptr); } -plain_tensor plain_tensor::index(const std::initializer_list& indices) const { - plain_tensor sub_tensor; +tensor tensor::index(const std::initializer_list& indices) const { + tensor sub_tensor; assert(indices.size() <= m_rank); int i_src = 0; int i_dst = 0; @@ -46,12 +46,12 @@ plain_tensor plain_tensor::index(const std::initializer_list& indi } sub_tensor.m_rank = i_dst; // index may imply squeeze sub_tensor.m_ptr = reinterpret_cast(m_ptr) + off; - return sub_tensor; + return std::move(sub_tensor); } // slice: return a sub-view (w/o ownership/refcount to original data) -plain_tensor plain_tensor::slice(int axis, int start, int end) const { - plain_tensor sub_tensor; +tensor tensor::slice(int axis, int start, int end) const { + tensor sub_tensor; assert(axis < m_rank); sub_tensor.m_capacity = 0; @@ -66,10 +66,10 @@ plain_tensor plain_tensor::slice(int axis, int start, int end) const { auto* data = reinterpret_cast(m_ptr) + off; sub_tensor.m_ptr = reinterpret_cast(data); - return sub_tensor; + return std::move(sub_tensor); } -bool plain_tensor::is_dense() const { +bool tensor::is_dense() const { // check if it's dense tensor size_t stride = m_element_size; for (int i = m_rank - 1; i >= 0; i--) { @@ -96,17 +96,17 @@ bool plain_tensor::is_dense() const { simplified form is when whole tensor is dense */ -plain_tensor plain_tensor::reshape(const std::initializer_list& target_shape) const { +tensor tensor::reshape(const std::initializer_list& target_shape) const { // only valid for dense memory - plain_tensor new_tensor_view; + tensor new_tensor_view; assert(is_dense()); //assert(shape_size(target_shape) == shape_size(m_dims)); - new_tensor_view.resize(std::vector(target_shape), m_ptr, m_element_size); + new_tensor_view.resize(std::vector(target_shape), m_ptr, m_element_size, m_dtype); return new_tensor_view; } -plain_tensor plain_tensor::permute(const std::initializer_list& order) const { - plain_tensor new_tensor_view; +tensor tensor::permute(const std::initializer_list& order) const { + tensor new_tensor_view; assert(order.size() == m_rank); new_tensor_view.m_capacity = 0; new_tensor_view.m_ptr = m_ptr; @@ -122,11 +122,12 @@ plain_tensor plain_tensor::permute(const std::initializer_list& order) c return new_tensor_view; } -void plain_tensor::resize(const std::vector& new_dims, void* data, size_t element_size) { +void tensor::resize(const size_t* new_dims, size_t dim_num, void* data, size_t element_size, data_type_t dtype) { // initialize strides for compact/dense tensor m_element_size = element_size; - m_rank = new_dims.size(); - assert(m_rank <= PLAINTENSOR_RANK_MAX); + m_dtype = dtype; + m_rank = dim_num; + assert(m_rank <= TENSOR_RANK_MAX); size_t stride = element_size; for (int i = m_rank - 1; i >= 0; i--) { m_dims[i] = new_dims[i]; @@ -147,7 +148,7 @@ void plain_tensor::resize(const std::vector& new_dims, void* data, size_ } } -void plain_tensor::assert_dims(const std::initializer_list& expect_dims) const { +void tensor::assert_dims(const std::initializer_list& expect_dims) const { if (m_rank != expect_dims.size()) { asm("int3"); std::cout << "dims not same\n"; diff --git a/src/emb_gpt_api.cpp b/src/emb_gpt_api.cpp index 96b2d53..8444c7d 100644 --- a/src/emb_gpt_api.cpp +++ b/src/emb_gpt_api.cpp @@ -9,20 +9,18 @@ namespace llmdnn { -// interface -emb_gpt::emb_gpt(): _impl(new_impl_avx512()) { -} - -emb_gpt::~emb_gpt() { - delete _impl; -} - -bool emb_gpt::create(const create_param& param) { - return _impl->create(param); -} - -void emb_gpt::exec(const exec_param& param) { - _impl->exec(param); +void emb_gpt(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids) { + emb_gpt_avx512(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); } } \ No newline at end of file diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index d03a49b..e0f0f67 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -2,10 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include #include #include +#include "common/bf16.hpp" #include "common/simple_parallel.hpp" #include "common/utility.hpp" #include "utility_kernel_avx512.hpp" @@ -18,207 +20,120 @@ using namespace ov::cpu; namespace llmdnn { -struct emb_gpt_impl_avx512 : public emb_gpt::impl { - bool create(const emb_gpt::create_param& param) override; - void exec(const emb_gpt::exec_param& param) override; - - void memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, - size_t batch, size_t past_seq_len, size_t head_stride_in_kv); - void applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv, float* cos, float* sin); - void applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv, float* cos, float* sin); - - emb_gpt::create_param _create_param; - size_t _head_num = 32; - size_t _size_per_head = 80; - size_t _hidden_size = 32 * 80; - size_t _rotary_dim = 20; - // aligned to cache line - size_t _size_per_head_aligned = 80; - int64_t _input_type_size = 1; - int64_t _output_type_size = 1; - bool _use_position2d = false; -}; - -bool emb_gpt_impl_avx512::create(const emb_gpt::create_param& param) { - if (param.qkv_precision != dnnl_bf16) { - std::cout << "input precision must be bf16 or int8.\n"; - return false; - } - // TODO: support s8 - // if (param.dst_precision != dnnl_bf16 && param.dst_precision != dnnl_s8) { - // std::cout << "dst precision must be bf16 or int8.\n"; - // return false; - // } - _create_param = param; - - _head_num = param.num_heads; - _size_per_head = param.head_size; - _size_per_head_aligned = param.head_size_aligned; - _hidden_size = param.head_size * param.num_heads; - _rotary_dim = param.rotary_dims; - _input_type_size = sizeof(ov::bfloat16); - _output_type_size = sizeof(ov::bfloat16); - if (param.dst_precision == dnnl_s8) - _output_type_size = sizeof(int8_t); - - _use_position2d = param.use_position2d; - - return true; -} - -void emb_gpt_impl_avx512::memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, - size_t batch, size_t past_seq_len, size_t head_stride_in_kv) { - parallel_for3d(batch, _head_num, past_seq_len, [&](size_t b, size_t h, size_t s) { - auto k_dst_batch = pastk_dst[b]; - auto v_dst_batch = pastv_dst[b]; - auto k_src_batch = pastk_src[b]; - auto v_src_batch = pastv_src[b]; - auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto k_src_seq = k_src_batch + s * _size_per_head_aligned * _output_type_size; - auto v_src_seq = v_src_batch + s * _size_per_head_aligned * _output_type_size; - auto* k_src_f = k_src_seq + h * past_seq_len * _size_per_head_aligned * _output_type_size; - auto* k_dst_f = k_dst_seq + h * head_stride_in_kv * _output_type_size; - auto* v_src_f = v_src_seq + h * past_seq_len * _size_per_head_aligned * _output_type_size; - auto* v_dst_f = v_dst_seq + h * head_stride_in_kv * _output_type_size; - - memcpy(k_dst_f, k_src_f, _output_type_size * _size_per_head); - memcpy(v_dst_f, v_src_f, _output_type_size * _size_per_head); - }); -} - -// q_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] -// q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] -// kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] -// kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] -void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv, float* cos, float* sin) { - auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; - auto* cos_cached = cos + past_seq_len * _rotary_dim; - auto* sin_cached = sin + past_seq_len * _rotary_dim; - parallel_for3d(batch, _head_num, q_seq_len, [&](size_t b, size_t h, size_t s) { - // q, k rotary encoding - auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; - auto k_dst_batch = k_dst[b] + key_offset; - auto v_dst_batch = v_dst[b] + key_offset; - auto q_src_batch = q_src + b * _head_num * ldq * q_seq_len * _input_type_size; - auto k_src_batch = k_src + b * _head_num * ldk * q_seq_len * _input_type_size; - auto v_src_batch = v_src + b * _head_num * ldv * q_seq_len * _input_type_size; - auto q_dst_seq = q_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto q_src_seq = q_src_batch + s * _head_num * ldq * _input_type_size; - auto k_src_seq = k_src_batch + s * _head_num * ldk * _input_type_size; - auto v_src_seq = v_src_batch + s * _head_num * ldv * _input_type_size; - auto* q_src_f = reinterpret_cast(q_src_seq + h * ldq * _input_type_size); - auto* k_src_f = reinterpret_cast(k_src_seq + h * ldk * _input_type_size); - auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); - auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); - rotary_avx512(_rotary_dim, cos_cached + s * _rotary_dim, sin_cached + s * _rotary_dim, q_src_f, k_src_f, q_dst_f, k_dst_f); - - // q, k concat - memcpy(reinterpret_cast(q_dst_f) + _rotary_dim * _output_type_size, reinterpret_cast(q_src_f) + _rotary_dim * _input_type_size, _output_type_size * (_size_per_head - _rotary_dim)); - memcpy(reinterpret_cast(k_dst_f) + _rotary_dim * _output_type_size, reinterpret_cast(k_src_f) + _rotary_dim * _input_type_size, _output_type_size * (_size_per_head - _rotary_dim)); - // v concat - memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, - static_cast(v_src_seq) + h * ldv * _input_type_size, - _size_per_head * _output_type_size); +static void memcpy_past_kv(const tensor& k_past, const tensor& v_past, const tensor& k_dst, const tensor& v_dst) { + auto batch = k_past.m_dims[0]; + auto head_num = k_past.m_dims[1]; + auto past_seq_len = k_past.m_dims[2]; + auto size = k_past.m_dims[3]; + parallel_for3d(batch, head_num, past_seq_len, [&](size_t b, size_t h, size_t s) { + memcpy(&k_dst.at({b, h, s}), &k_past.at({b, h, s}), k_past.m_strides[2]); + memcpy(&v_dst.at({b, h, s}), &v_past.at({b, h, s}), v_past.m_strides[2]); }); } -// q_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] -// q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] -// kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] -// kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] +// q_src shape: [batch, q_seq_len, head_hum, head_size] +// q_dst shape: [batch, head_hum, q_seq_len, head_size] +// kv_src shape: [batch, q_seq_len, head_hum, head_size] +// kv_past shape: [batch, head_hum, past_seq_len, head_size] +// kv_dst shape: [batch, head_hum, q_seq_len+past_seq_len, head_size] // position2d_ids: [batch, 2, q_seq_len] -void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, - size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv, float* cos, float* sin) { - auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; - auto* cos_cached = cos; - auto* sin_cached = sin; - parallel_for3d(batch, _head_num, q_seq_len, [&](size_t b, size_t h, size_t s) { +// cos/sin: [max_seq_len, rotary_dims] +static void rotary_emb_position2d(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids) { + auto batch = k_past.m_dims[0]; + auto head_num = k_past.m_dims[1]; + auto past_seq_len = k_past.m_dims[2]; + auto head_size = k_past.m_dims[3]; + auto query_seq_len = q_src.m_dims[1]; + auto rotary_ndim = cos.m_dims[3]; + + parallel_for3d(batch, head_num, query_seq_len, [&](size_t b, size_t h, size_t s) { + auto kv_dst_s = s + past_seq_len; // q, k rotary encoding - auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; - auto k_dst_batch = k_dst[b] + key_offset; - auto v_dst_batch = v_dst[b] + key_offset; - auto pos_batch = position2d_ids + b * 2 * q_seq_len; - auto block_batch = pos_batch + q_seq_len; - auto q_src_batch = q_src + b * _head_num * ldq * q_seq_len * _input_type_size; - auto k_src_batch = k_src + b * _head_num * ldk * q_seq_len * _input_type_size; - auto v_src_batch = v_src + b * _head_num * ldv * q_seq_len * _input_type_size; - auto q_dst_seq = q_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; - auto q_src_seq = q_src_batch + s * _head_num * ldq * _input_type_size; - auto k_src_seq = k_src_batch + s * _head_num * ldk * _input_type_size; - auto v_src_seq = v_src_batch + s * _head_num * ldv * _input_type_size; - auto* q_src_f = reinterpret_cast(q_src_seq + h * ldq * _input_type_size); - auto* k_src_f = reinterpret_cast(k_src_seq + h * ldk * _input_type_size); - auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); - auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); - rotary_avx512(_rotary_dim, cos_cached + pos_batch[s] * _rotary_dim, sin_cached + pos_batch[s] * _rotary_dim, q_src_f, k_src_f, q_dst_f, k_dst_f); - rotary_avx512(_rotary_dim, cos_cached + block_batch[s] * _rotary_dim, sin_cached + block_batch[s] * _rotary_dim, - q_src_f + _rotary_dim, - k_src_f + _rotary_dim, - q_dst_f + _rotary_dim, - k_dst_f + _rotary_dim); + if (position2d_ids) { + auto pos = position2d_ids.at({b, 0, s}); + rotary_avx512(rotary_ndim, &cos.at({0, 0, pos}), &sin.at({0, 0, pos}), + &q_src.at({b, s, h}), + &k_src.at({b, s, h}), + &q_dst.at({b, h, s}), + &k_dst.at({b, h, kv_dst_s})); + pos = position2d_ids.at({b, 1, s}); + rotary_avx512(rotary_ndim, &cos.at({0, 0, pos}), &sin.at({0, 0, pos}), + &q_src.at({b, s, h, rotary_ndim}), + &k_src.at({b, s, h, rotary_ndim}), + &q_dst.at({b, h, s, rotary_ndim}), + &k_dst.at({b, h, kv_dst_s, rotary_ndim})); + } else { + rotary_avx512(rotary_ndim, &cos.at({0, 0, s + past_seq_len}), &sin.at({0, 0, s + past_seq_len}), + &q_src.at({b, s, h}), + &k_src.at({b, s, h}), + &q_dst.at({b, h, s}), + &k_dst.at({b, h, kv_dst_s})); + memcpy(&q_dst.at({b, h, s, rotary_ndim}), &q_src.at({b, s, h, rotary_ndim}), (head_size - rotary_ndim) * sizeof(ov::bfloat16)); + memcpy(&k_dst.at({b, h, kv_dst_s, rotary_ndim}), &k_src.at({b, s, h, rotary_ndim}), (head_size - rotary_ndim) * sizeof(ov::bfloat16)); + } // v concat - memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, - static_cast(v_src_seq) + h * ldv* _input_type_size, - _size_per_head * _output_type_size); + memcpy(&v_dst.at({b, h, kv_dst_s}), &v_src.at({b, s, h}), head_size * sizeof(ov::bfloat16)); }); } -void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { +void emb_gpt_avx512(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids) { + if (q_src.m_rank != 4 || k_src.m_rank != 4 || v_src.m_rank != 4 || k_past.m_rank != 4 || v_past.m_rank != 4 || q_dst.m_rank != 4|| + k_dst.m_rank != 4 || v_dst.m_rank != 4 || cos.m_rank != 4 || sin.m_rank != 4) { + std::cout << "emb_gpt_avx512: rank is not correct: should be 4\n"; + return; + } + if (position2d_ids) { + if (position2d_ids.m_rank != 3) { + std::cout << "emb_gpt_avx512: position2d_ids rank should be 3\n"; + return; + } + if (position2d_ids.m_dims[0] != q_src.m_dims[0] || position2d_ids.m_dims[1] != 2 || position2d_ids.m_dims[2] != q_src.m_dims[1]) { + std::cout << "emb_gpt_avx512: position2d_ids dims should be [batch, 2, seq_len]\n"; + return; + } + } + // [batch, seq_len, (num_heads * 3 * head_size)] // --> [batch, seq_len, num_heads, 3 * head_size] - auto query = param.q; - auto key = param.k; - auto value = param.v; - auto query_dst = param.query_dst; - auto key_dst = param.layer_past_key_dst; - auto value_dst = param.layer_past_value_dst; - auto batch = param.batch; - auto query_seq_len = param.query_seq_len; - auto past_seq_len = param.past_seq_len; - auto head_stride_in_kv = param.head_stride_in_kv; + auto past_seq_len = k_past.m_dims[2]; // past kv src != dst, copy src to dst first - if (param.layer_past_key_src && param.layer_past_key_src[0] != param.layer_past_key_dst[0] && past_seq_len) - memcpyPastKV(param.layer_past_key_src, param.layer_past_value_src, param.layer_past_key_dst, param.layer_past_value_dst, batch, past_seq_len, head_stride_in_kv); + if (k_past.m_ptr != k_dst.m_ptr && past_seq_len) + memcpy_past_kv(k_past, v_past, k_dst, v_dst); // transpose + rotary embbeding: - // transpose: [batch, seq_len, num_attention_heads, 3 * head_size] --> - // 3 [batch, num_attention_heads, seq_len, head_size] + // transpose: [batch, seq_len, head_hum, 3 * head_size] --> + // 3 [batch, head_hum, seq_len, head_size] // rotary embbeding: part of key will write to past_key, part of query will write to tempory buffer - if (_create_param.dst_precision == dnnl_s8) { - // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) - // key pass part(past_key): key = torch.cat((key, key_pass), dim=-1) - // value(pastKeys): value = torch.cat((past_value, value), dim=-2) - // applyRotaryPosEmbMemcpyQuant(query, key, queryTranspose.get(), current_k_bufs, _output_type_size * new_seq_offset * _size_per_head_aligned, - // _cos_cached.get(), _sin_cached.get(), batch, seq_len, new_seq_offset, value, current_v_bufs); + if (q_src.m_dtype == dnnl_s8) { assert(false); } else { // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) // key pass part(past_key): key = torch.cat((key, key_pass), dim=-1) // value(pastKeys): value = torch.cat((past_value, value), dim=-2) - // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] - // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] - if (_use_position2d) { - applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, param.ldq, param.ldk, param.ldv, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, - param.position2d_ids, head_stride_in_kv, param.cos, param.sin); - } else { - applyRotaryPosEmbMemcpy(query, key, value, param.ldq, param.ldk, param.ldv, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, head_stride_in_kv, - param.cos, param.sin); - } + rotary_emb_position2d(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); } } -emb_gpt::impl* new_impl_avx512() { - return new emb_gpt_impl_avx512(); -} - } \ No newline at end of file diff --git a/src/emb_gpt_avx512.hpp b/src/emb_gpt_avx512.hpp index 9a92742..0ef4d43 100644 --- a/src/emb_gpt_avx512.hpp +++ b/src/emb_gpt_avx512.hpp @@ -12,6 +12,15 @@ namespace llmdnn { -emb_gpt::impl* new_impl_avx512(); - +void emb_gpt_avx512(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids); } diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index bc34fed..846a89e 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -9,6 +9,7 @@ #include "common/simple_parallel.hpp" #include "common/tensor2d.hpp" #include "common/utility.hpp" +#include "llm_types.hpp" #include "utility_kernel_avx512.hpp" #include "mm_kernel_common_amx.hpp" #include "softmax_kernel_avx512.hpp" @@ -21,21 +22,18 @@ using namespace ov::cpu; namespace llmdnn { struct mha_gpt_impl_amx : public mha_gpt::impl { - bool create(const mha_gpt::create_param& param) override; - void exec(const mha_gpt::exec_param& param) override; + void create(data_type_t in_type, size_t seq_len, size_t head_size, bool is_bloom); + void exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask) override; - mha_gpt::create_param _create_param; + void mha_bf16(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask); - void mha_bf16(const mha_gpt::exec_param ¶m); - void mha_i8(const mha_gpt::exec_param ¶m); + size_t _head_size_aligned = 0; + size_t _buffer_mat0_out_size = 0; + size_t _buffer_mat1_out_size = 0; + size_t _num_threads = 0; - size_t head_size_aligned; - size_t bufferMatMul0OutSize; - size_t bufferMatMul1OutSize; - - std::shared_ptr bufferMatMul0Out; - std::shared_ptr bufferMatMul1Out; - std::shared_ptr qkvQuantBuf; + std::shared_ptr _buffer_mat0_out; + std::shared_ptr _buffer_mat1_out; std::vector>> gemAvB_BF16xBF16; std::vector>> qKtrGemm_BF16xBF16; @@ -46,402 +44,264 @@ struct mha_gpt_impl_amx : public mha_gpt::impl { std::vector>> gemAvB_i8xi8; }; -bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { - if (param.qkv_precision != dnnl_bf16 && param.qkv_precision != dnnl_s8) { - std::cout << "input precision must be bf16 or int8.\n"; - return false; - } - if (param.dst_precision != dnnl_bf16 && param.dst_precision != dnnl_s8) { - std::cout << "dst precision must be bf16 or int8.\n"; - return false; - } - _create_param = param; - - // q: [batch, num_heads, query_seq_len, head_size] - // k: [batch, num_heads, maxSeqLen(valid: key_seq_len), head_size] - // v: [batch, num_heads, maxSeqLen(valid: value_seq_len), head_size] +void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_size, bool is_bloom) { + // q: [batch, head_num, query_seq_len, head_size] + // k: [batch, head_num, maxSeqLen(valid: key_seq_len), head_size] + // v: [batch, head_num, maxSeqLen(valid: value_seq_len), head_size] // attention_mask: [batch, 1, 1, maxSeqLen(valid: key_seq_len)] - // matmul1: [batch, num_heads, query_seq_len, head_size] - // attn_output: [batch, query_seq_len, num_heads * head_size] - size_t numThreads = getTotalThreads(); - if (_create_param.qkv_precision == dnnl_s8) { - head_size_aligned = rndup(_create_param.head_size, 64); - qKtrGemm_i8xi8.resize(numThreads); - for (size_t i = 0; i < numThreads; i++) { - qKtrGemm_i8xi8[i] = std::make_shared>(false, !param.is_bloom); - } - qKVGemm_u8xi8.resize(numThreads); - for (size_t i = 0; i < numThreads; i++) { - qKVGemm_u8xi8[i] = std::make_shared>(false, false); - } - gemAvB_i8xi8.resize(numThreads); - for (size_t i = 0; i < numThreads; i++) { - gemAvB_i8xi8[i] = std::make_shared>(); - } - qkvQuantBuf = std::shared_ptr( - reinterpret_cast(aligned_alloc(64, param.head_size * sizeof(float))), - [](void * p) { ::free(p); }); - memset(qkvQuantBuf.get(), 0, sizeof(param.head_size * sizeof(float))); - } else { - head_size_aligned = rndup(_create_param.head_size, 32); - gemAvB_BF16xBF16.resize(numThreads); - for (size_t i = 0; i < numThreads; i++) { - gemAvB_BF16xBF16[i] = std::make_shared>(); - } - qKtrGemm_BF16xBF16.resize(numThreads); - for (size_t i = 0; i < numThreads; i++) { - qKtrGemm_BF16xBF16[i] = std::make_shared>(false, !param.is_bloom); - } - qKVGemm_BF16xBF16.resize(numThreads); - for (size_t i = 0; i < numThreads; i++) { - qKVGemm_BF16xBF16[i] = std::make_shared>(false, false); + // matmul1: [batch, head_num, query_seq_len, head_size] + // attn_output: [batch, query_seq_len, head_num * head_size] + if (_num_threads == 0) { + _num_threads = getTotalThreads(); + if (in_type == dnnl_s8) { + _head_size_aligned = rndup(head_size, 64); + qKtrGemm_i8xi8.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + qKtrGemm_i8xi8[i] = std::make_shared>(false, !is_bloom); + } + qKVGemm_u8xi8.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + qKVGemm_u8xi8[i] = std::make_shared>(false, false); + } + gemAvB_i8xi8.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + gemAvB_i8xi8[i] = std::make_shared>(); + } + } else { + _head_size_aligned = rndup(head_size, 32); + gemAvB_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + gemAvB_BF16xBF16[i] = std::make_shared>(); + } + qKtrGemm_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + qKtrGemm_BF16xBF16[i] = std::make_shared>(false, !is_bloom); + } + qKVGemm_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + qKVGemm_BF16xBF16[i] = std::make_shared>(false, false); + } } } - bufferMatMul0OutSize = _create_param.max_seq_len * rndup(_create_param.max_seq_len * sizeof(float), 64); - bufferMatMul1OutSize = _create_param.max_seq_len * head_size_aligned * sizeof(float); - - bufferMatMul0Out = std::shared_ptr( - reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul0OutSize)), - [](void * p) { ::free(p); }); - memset(bufferMatMul0Out.get(), 0, numThreads * bufferMatMul0OutSize); - bufferMatMul1Out = std::shared_ptr( - reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul1OutSize)), - [](void * p) { ::free(p); }); - memset(bufferMatMul1Out.get(), 0, numThreads * bufferMatMul1OutSize); - return true; + auto buffer_mat0_out_size = seq_len * rndup(seq_len * sizeof(float), 64); + if (buffer_mat0_out_size > _buffer_mat0_out_size) { + _buffer_mat0_out_size = seq_len * rndup(seq_len * sizeof(float), 64) * 3 / 2; + _buffer_mat1_out_size = seq_len * _head_size_aligned * sizeof(float) * 3 / 2; + + _buffer_mat0_out = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat0_out_size)), + [](void * p) { ::free(p); }); + memset(_buffer_mat0_out.get(), 0, _num_threads * _buffer_mat0_out_size); + _buffer_mat1_out = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat1_out_size)), + [](void * p) { ::free(p); }); + memset(_buffer_mat1_out.get(), 0, _num_threads * _buffer_mat1_out_size); + } } -void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { - auto& q = param.q; - auto& k = param.k; - auto& v = param.v; - auto* attn_masks = param.attention_mask.data(); - uint8_t* pout = param.attn_output.data(); - auto alibi = param.alibi.data(); +void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, + const tensor& alibi, float normal_factor, bool use_causal_mask) { + auto batch = q.m_dims[0]; + auto head_num = q.m_dims[1]; + auto query_seq_len = q.m_dims[2]; + auto head_size = q.m_dims[3]; + auto key_seq_len = k.m_dims[2]; + bool is_bloom = k.m_strides[3] > k.m_strides[2]; + + uint8_t* out = output.data(); - auto outPrcSize = get_precision_size(_create_param.qkv_precision); auto& gemAvB_ops = gemAvB_BF16xBF16; auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; auto& qKVGemm_ops = qKVGemm_BF16xBF16; - bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 32 && _create_param.head_size <= 32 * 6 && !_create_param.is_bloom; - size_t head_stride_in_attn = _create_param.head_size; - size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; - size_t causal_mask_offset_start = param.key_seq_len - param.query_seq_len; + bool use_vector = query_seq_len == 1 && head_size >= 32 && head_size <= 32 * 6 && !is_bloom && !alibi && attn_mask; + size_t head_stride_in_attn = head_size; + size_t batch_stride_in_attn = head_size * head_num * query_seq_len; + size_t causal_mask_offset_start = use_causal_mask ? key_seq_len - query_seq_len : key_seq_len; - if (is_vector) { - parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { - auto pQIn0_aux = &q.at({i0, i1}); - auto pKIn0_aux = &k.at({i0, i1}); - auto pVIn0_aux = &v.at({i0, i1}); + if (use_vector) { + parallel_for2d(batch, head_num, [&](size_t thread_id, size_t i0, size_t i1) { + auto q_sub = &q.at({i0, i1}); + auto k_sub = &k.at({i0, i1}); + auto v_sub = &v.at({i0, i1}); - auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; + auto mat0_out = reinterpret_cast(_buffer_mat0_out.get() + thread_id * _buffer_mat0_out_size); + auto mat1_out = reinterpret_cast(_buffer_mat1_out.get() + thread_id * _buffer_mat1_out_size); - auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); - auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); - - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), k.m_strides[2]); + tensor2D matK(key_seq_len, head_size, reinterpret_cast(k_sub), k.m_strides[2]); // N: key_seq_len, K: head_size // q[1, K] * transpose(k[N, K]) ==> // k[N, K] * transpose(q[1, K]) ==> // k[N, K] * q[K, 1] - (*gemAvB_ops[threadNum])(matK, reinterpret_cast(pQIn0_aux), reinterpret_cast(bufferMatMul0Out_local)); - - float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); - mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, _create_param.normal_factor, pAddIn1_aux, param.key_seq_len); - softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr); - auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; - tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), v.m_strides[2]); - tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), head_size_aligned * sizeof(float)); + (*gemAvB_ops[thread_id])(matK, reinterpret_cast(q_sub), reinterpret_cast(mat0_out)); + + float* pMatMul0Out = reinterpret_cast(mat0_out); + mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, normal_factor, &attn_mask.at({i0}), key_seq_len); + softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, key_seq_len, nullptr); + auto out_sub = out + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * sizeof(ov::bfloat16); + tensor2D matQK(query_seq_len, key_seq_len, reinterpret_cast(mat0_out), rndup(key_seq_len * sizeof(ov::bfloat16), 64)); + tensor2D matV(key_seq_len, head_size, reinterpret_cast(v_sub), v.m_strides[2]); + tensor2D matQKV(query_seq_len, head_size, reinterpret_cast(mat1_out), _head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp(matQKV); - (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); - memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, - _create_param.head_size, head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); + (*qKVGemm_ops[thread_id])(matQK, matV, 0, head_size, pp); + memcpy2d_stride_avx512(reinterpret_cast(out_sub), reinterpret_cast(mat1_out), query_seq_len, + head_size, _head_size_aligned * sizeof(float), head_num * head_size * sizeof(ov::bfloat16), nullptr); }); } else { - auto numThreads = getTotalThreads(); - int seq_cout_all = rndup(param.query_seq_len, 32) / 32; - int work_amount = param.batch * _create_param.num_heads * seq_cout_all; - parallel_for(numThreads, [&](int threadNum) { - int i0; - int i1; - int seq; - int start {0}, end {0}; - splitter(work_amount, static_cast(numThreads), threadNum, start, end); + size_t seq_cout_all = rndup(query_seq_len, 32) / 32; + auto work_amount = batch * head_num * seq_cout_all; + parallel_for(_num_threads, [&](size_t thread_id) { + size_t i0; + size_t i1; + size_t seq; + size_t start {0}, end {0}; + splitter(work_amount, _num_threads, thread_id, start, end); if (start >= work_amount) return; - parallel_it_init(start, i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); - uint8_t* prev_k = nullptr; - uint8_t* prev_v = nullptr; + parallel_it_init(start, i0, batch, i1, head_num, seq, seq_cout_all); + ov::bfloat16* prev_k = nullptr; + ov::bfloat16* prev_v = nullptr; for (int iwork = start; iwork < end; ++iwork) { - int seq_start = seq * 32; - int seq_end = std::min(static_cast(seq_start) + 32, param.query_seq_len); - int seq_cout = seq_end - seq_start; - // q: [batch, num_heads, query_seq_len, head_size] - // k: [batch, num_heads, key_seq_len, head_size] - // v: [batch, num_heads, value_seq_len, head_size] - auto pQIn0_aux = &q.at({static_cast(i0), static_cast(i1), static_cast(seq_start)}); - auto pKIn0_aux = &k.at({static_cast(i0), static_cast(i1)}); - auto pVIn0_aux = &v.at({static_cast(i0), static_cast(i1)}); - - auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); - auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + auto seq_start = seq * 32; + auto seq_end = std::min(seq_start + 32, query_seq_len); + auto seq_cout = seq_end - seq_start; + // q: [batch, head_num, query_seq_len, head_size] + // k: [batch, head_num, key_seq_len, head_size] + // v: [batch, head_num, value_seq_len, head_size] + auto q_sub = &q.at({i0, i1, seq_start}); + auto k_sub = &k.at({i0, i1}); + auto v_sub = &v.at({i0, i1}); + + auto mat0_out = reinterpret_cast(_buffer_mat0_out.get() + thread_id * _buffer_mat0_out_size); + auto mat1_out = reinterpret_cast(_buffer_mat1_out.get() + thread_id * _buffer_mat1_out_size); - tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), q.m_strides[2]); - tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); + tensor2D matQ(seq_cout, head_size, q_sub, q.m_strides[2]); + tensor2D matQK(seq_cout, key_seq_len, mat0_out, rndup(key_seq_len * sizeof(float), 64)); amx_kernel::PP::BiasGeluStore pp(matQK); - if (!_create_param.is_bloom) { - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), k.m_strides[2]); - (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); + if (!is_bloom) { + tensor2D matK(key_seq_len, head_size, k_sub, k.m_strides[2]); + (*qKtrGemm_ops[thread_id])(matQ, matK, 0, key_seq_len, pp, k_sub == prev_k); } else { - tensor2D matK(_create_param.head_size, param.key_seq_len, reinterpret_cast(pKIn0_aux), k.m_strides[3]); - (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); + tensor2D matK(head_size, key_seq_len, k_sub, k.m_strides[3]); + (*qKtrGemm_ops[thread_id])(matQ, matK, 0, key_seq_len, pp, k_sub == prev_k); } - prev_k = pKIn0_aux; - - auto pMatMul0Out = bufferMatMul0Out_local; - if (param.is_causal_in_attention) { - auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len * param.query_seq_len; - // loop along K dimension + prev_k = k_sub; + tensor2D softmax_dst(seq_cout, key_seq_len, reinterpret_cast(mat0_out), rndup(key_seq_len * sizeof(ov::bfloat16), 64)); + // no attention mask + size_t valid_softmax_items = std::min(causal_mask_offset_start + seq_start + 1, key_seq_len); + if (!attn_mask) { for (int m = 0; m < seq_cout; m++) { - float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); - ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - if (!_create_param.is_bloom) - mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + float* src = &matQK(m, 0); + ov::bfloat16* dst = &softmax_dst(m, 0); + if (!alibi) + mul_f32_avx512(src, src, normal_factor, valid_softmax_items); else // alibi shape: [batch, head_num, 1, key_seq_len] - mul_add2_f32_avx512(src, src, _create_param.normal_factor, - alibi + i0 * _create_param.num_heads * param.key_seq_len + i1 * param.key_seq_len, - pAddIn1_aux + (m + seq_start) * param.key_seq_len, - param.key_seq_len); - softmax_avx512(dst, src, param.key_seq_len, nullptr); - } - } else { - auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; - // loop along K dimension - size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; - for (int m = 0; m < seq_cout; m++) { - float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); - ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - if (!_create_param.is_bloom) - mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); - else - mul_add2_f32_avx512(src, src, _create_param.normal_factor, - alibi + i0 * _create_param.num_heads * param.key_seq_len + i1 * param.key_seq_len, - pAddIn1_aux, + mul_add_f32_avx512(src, src, normal_factor, + &alibi.at({i0, i1}), valid_softmax_items); softmax_avx512(dst, src, valid_softmax_items, nullptr); - // attn_scores = torch.where(causal_mask, attn_scores, mask_value) - if (param.key_seq_len > valid_softmax_items) { + if (key_seq_len > valid_softmax_items) { auto *invalidPtr = dst + valid_softmax_items; - memset(static_cast(invalidPtr), 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); - valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + memset(static_cast(invalidPtr), 0, (key_seq_len - valid_softmax_items) * sizeof(ov::bfloat16)); + valid_softmax_items = std::min(valid_softmax_items + 1, key_seq_len); } } - } - - auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn - + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; - tensor2D matQKBF16(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); - tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), v.m_strides[2]); - tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), head_size_aligned * sizeof(float)); - amx_kernel::PP::BiasGeluStore pp2(matQKV); - (*qKVGemm_ops[threadNum])(matQKBF16, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); - prev_v = pVIn0_aux; - memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, - _create_param.head_size, head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); - parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); - } - }); - } -} - -void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { - auto& q = param.q; - auto& k = param.k; - auto& v = param.v; - auto attn_masks = param.attention_mask.data(); - uint8_t* pout = param.attn_output.data(); - auto alibi = param.alibi.data(); - - auto outPrcSize = get_precision_size(_create_param.dst_precision); - auto& gemAvB_ops = gemAvB_i8xi8; - auto& qKtrGemm_ops = qKtrGemm_i8xi8; - auto& qKVGemm_ops = qKVGemm_u8xi8; - bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 64 && _create_param.head_size <= 64 * 6 && !_create_param.is_bloom; - // dequant param - auto mul_scales = _create_param.normal_factor * param.q_dequant * param.k_dequant; - // prepare for per channel - assert(param.qkv_quant.size() == 1 || param.qkv_quant.size() == _create_param.head_size); - for (size_t i = 0; i < param.qkv_quant.size(); i++) { - (qkvQuantBuf.get())[i] = param.qkv_quant[i] * param.v_dequant / param.qk_quant; - } - if (param.qkv_quant.size() == 1) { - std::fill(qkvQuantBuf.get() + 1, qkvQuantBuf.get() + _create_param.head_size, *qkvQuantBuf.get()); - } - size_t head_stride_in_attn = _create_param.head_size; - size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; - size_t causal_mask_offset_start = param.key_seq_len - param.query_seq_len; - - if (is_vector) { - parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { - auto pQIn0_aux = &q.at({i0, i1}); - auto pKIn0_aux = &k.at({i0, i1}); - auto pVIn0_aux = &v.at({i0, i1}); - - auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; - - auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); - auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); - - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), k.m_strides[2]); - // N: key_seq_len, K: head_size - // q[1, K] * transpose(k[N, K]) ==> - // k[N, K] * transpose(q[1, K]) ==> - // k[N, K] * q[K, 1] - (*gemAvB_ops[threadNum])(matK, reinterpret_cast(pQIn0_aux), reinterpret_cast(bufferMatMul0Out_local)); - cvt_i32_f32_avx512(reinterpret_cast(bufferMatMul0Out_local), reinterpret_cast(bufferMatMul0Out_local), param.key_seq_len); - - float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); - mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, mul_scales, pAddIn1_aux, param.key_seq_len); - softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, param.qk_quant); - auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; - tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); - tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), v.m_strides[2]); - tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), head_size_aligned * sizeof(float)); - amx_kernel::PP::BiasGeluStore pp(matQKV); - (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); - memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, - _create_param.head_size, head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); - }); - } else { - auto numThreads = getTotalThreads(); - int seq_cout_all = rndup(param.query_seq_len, 32) / 32; - int work_amount = param.batch * _create_param.num_heads * seq_cout_all; - parallel_for(numThreads, [&](int threadNum) { - int i0; - int i1; - int seq; - int start {0}, end {0}; - splitter(work_amount, static_cast(numThreads), threadNum, start, end); - if (start >= work_amount) return; - - parallel_it_init(start, i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); - uint8_t* prev_k = nullptr; - uint8_t* prev_v = nullptr; - for (int iwork = start; iwork < end; ++iwork) { - int seq_start = seq * 32; - int seq_end = std::min(static_cast(seq_start) + 32, param.query_seq_len); - int seq_cout = seq_end - seq_start; - // q: [batch, num_heads, query_seq_len, head_size] - // k: [batch, num_heads, key_seq_len, head_size] - // v: [batch, num_heads, value_seq_len, head_size] - auto pQIn0_aux = &q.at({static_cast(i0), static_cast(i1), static_cast(seq_start)}); - auto pKIn0_aux = &k.at({static_cast(i0), static_cast(i1)}); - auto pVIn0_aux = &v.at({static_cast(i0), static_cast(i1)}); - - auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); - auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); - - tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), q.m_strides[2]); - tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); - amx_kernel::PP::BiasGeluStore pp(matQK); - if (!_create_param.is_bloom) { - tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), k.m_strides[2]); - (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); } else { - tensor2D matK(_create_param.head_size, param.key_seq_len, reinterpret_cast(pKIn0_aux), k.m_strides[3]); - (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); - } - prev_k = pKIn0_aux; - - auto pMatMul0Out = bufferMatMul0Out_local; - if (param.is_causal_in_attention) { - auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len * param.query_seq_len; - // loop along K dimension + // [batch, 1, 1, key_seq_len] or [batch, 1, query_seq_len, key_seq_len] for (int m = 0; m < seq_cout; m++) { - float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); - uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); - mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); - if (!_create_param.is_bloom) - mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + auto attn_sub = &attn_mask.at({i0, 0, attn_mask.m_dims[2] == 1 ? 0 : m + seq_start}); // s + i0 * key_seq_len * query_seq_len; + float* src = &matQK(m, 0); + ov::bfloat16* dst = &softmax_dst(m, 0); + if (!alibi) + mul_add_f32_avx512(src, src, normal_factor, attn_sub, valid_softmax_items); else // alibi shape: [batch, head_num, 1, key_seq_len] - mul_add2_f32_avx512(src, src, mul_scales, - alibi + i0 * _create_param.num_heads * param.key_seq_len + i1 * param.key_seq_len, - pAddIn1_aux + (m + seq_start) * param.key_seq_len, - param.key_seq_len); - softmax_avx512(dst, src, param.key_seq_len, param.qk_quant); - } - } else { - auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; - // loop along K dimension - size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; - for (int m = 0; m < seq_cout; m++) { - float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); - uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); - if (!_create_param.is_bloom) - mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); - else - mul_add2_f32_avx512(src, src, mul_scales, - alibi + i0 * _create_param.num_heads * param.key_seq_len + i1 * param.key_seq_len, - pAddIn1_aux, + mul_add2_f32_avx512(src, src, normal_factor, + &alibi.at({i0, i1}), + attn_sub, valid_softmax_items); - softmax_avx512(dst, src, valid_softmax_items, param.qk_quant); - // attn_scores = torch.where(causal_mask, attn_scores, mask_value) - if (param.key_seq_len > valid_softmax_items) { + softmax_avx512(dst, src, valid_softmax_items, nullptr); + if (key_seq_len > valid_softmax_items) { auto *invalidPtr = dst + valid_softmax_items; - memset(invalidPtr, 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); - valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + memset(static_cast(invalidPtr), 0, (key_seq_len - valid_softmax_items) * sizeof(ov::bfloat16)); + valid_softmax_items = std::min(valid_softmax_items + 1, key_seq_len); } } } - auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn - + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; - tensor2D matQKI8(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); - tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), v.m_strides[2]); - tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), head_size_aligned * sizeof(float)); + + auto out_sub = out + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + + seq_start * head_stride_in_attn * head_num) * sizeof(ov::bfloat16); + tensor2D matQKBF16(seq_cout, key_seq_len, softmax_dst.data, softmax_dst.stride); + tensor2D matV(key_seq_len, head_size, v_sub, v.m_strides[2]); + tensor2D matQKV(seq_cout, head_size, mat1_out, _head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp2(matQKV); - (*qKVGemm_ops[threadNum])(matQKI8, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); - prev_v = pVIn0_aux; - // matmul1: [batch, num_heads, query_seq_len, head_size] - // attn_output: [batch, query_seq_len, num_heads * head_size] - memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, - _create_param.head_size, head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); - parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + (*qKVGemm_ops[thread_id])(matQKBF16, matV, 0, head_size, pp2, prev_v == v_sub); + prev_v = v_sub; + memcpy2d_stride_avx512(reinterpret_cast(out_sub), mat1_out, seq_cout, + head_size, _head_size_aligned * sizeof(float), head_num * head_size * sizeof(ov::bfloat16), nullptr); + parallel_it_step(i0, batch, i1, head_num, seq, seq_cout_all); } }); } } -void mha_gpt_impl_amx::exec(const mha_gpt::exec_param& param) { - if (param.q.m_rank != 4 || param.k.m_rank != 4 || param.v.m_rank != 4) { +void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask) { + if (q.m_rank != 4 || k.m_rank != 4 || v.m_rank != 4) { std::cout << "q,k,v rank does not equal 4.\n"; return; } - auto b = param.q.m_dims[0]; - auto hn = param.q.m_dims[1]; - auto qs = param.q.m_dims[2]; - auto hs = param.q.m_dims[3]; - auto ks = param.k.m_dims[2]; - - if (!(b == param.k.m_dims[0] && b == param.v.m_dims[0] && - hn == param.k.m_dims[1] && hn == param.v.m_dims[1] && - ks == param.v.m_dims[2] && - hs == param.k.m_dims[3] && hs == param.v.m_dims[3])) { + if (output.m_rank != 3) { + std::cout << "output rank should be 3.\n"; + } + if (attn_mask) { + if (attn_mask.m_rank != 4) { + std::cout << "attn_mask rank should be 4.\n"; + return; + } + if (attn_mask.m_dims[1] != 1) { + std::cout << "attn_mask dim 1 should be 1.\n"; + return; + } + } + if (alibi) { + if (alibi.m_rank != 4) { + std::cout << "alibi rank should be 4.\n"; + return; + } + if (alibi.m_dims[1] != k.m_dims[1]) { + std::cout << "alibi dim 1 should be equal to k dim 1.\n"; + return; + } + if (alibi.m_dims[2] != 1) { + std::cout << "alibi dim 2 should be 1.\n"; + return; + } + } + auto batch = q.m_dims[0]; + auto head_num = q.m_dims[1]; + auto query_seq_len = q.m_dims[2]; + auto head_size = q.m_dims[3]; + auto key_seq_len = k.m_dims[2]; + + if (!(batch == k.m_dims[0] && batch == v.m_dims[0] && + head_num == k.m_dims[1] && head_num == v.m_dims[1] && + key_seq_len == v.m_dims[2] && + head_size == k.m_dims[3] && head_size == v.m_dims[3])) { std::cout << "dim of q,k,v is error.\n"; return; } - if (_create_param.qkv_precision == dnnl_f32) { - assert(false); - } else if (_create_param.qkv_precision == dnnl_bf16) { - mha_bf16(param); - } else if (_create_param.qkv_precision == dnnl_s8) { - mha_i8(param); + bool is_bloom = k.m_strides[3] > k.m_strides[2]; + + auto in_dtype = q.m_dtype; + auto out_dtype = output.m_dtype; + + if (in_dtype == dnnl_bf16 && out_dtype == dnnl_bf16) { + create(in_dtype, key_seq_len, head_size, is_bloom); + mha_bf16(q, k, v, output, attn_mask, alibi, normal_factor, use_causal_mask); } else { - assert(false && "doesn't support provided input precisions"); + std::cout << "doesn't support provided input precisions.\n"; } } diff --git a/src/mha_gpt_api.cpp b/src/mha_gpt_api.cpp index cf67cf3..9ba35d5 100644 --- a/src/mha_gpt_api.cpp +++ b/src/mha_gpt_api.cpp @@ -17,12 +17,8 @@ mha_gpt::~mha_gpt() { delete _impl; } -bool mha_gpt::create(const create_param& param) { - return _impl->create(param); -} - -void mha_gpt::exec(const exec_param& param) { - _impl->exec(param); +void mha_gpt::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask) { + _impl->exec(q, k, v, output, attn_mask, alibi, normal_factor, use_causal_mask); } } \ No newline at end of file diff --git a/src/utility_kernel_avx512.hpp b/src/utility_kernel_avx512.hpp index 109888c..1743253 100644 --- a/src/utility_kernel_avx512.hpp +++ b/src/utility_kernel_avx512.hpp @@ -91,6 +91,23 @@ inline void cvt_i32_f32_avx512(float* dst, int32_t* src, size_t ele_num) { } } +inline void mul_f32_avx512(float* dst, float* src, float mul, int ele_num) { + auto mul_f = _mm512_set1_ps(mul); + int i; + auto tail = ele_num % 16; + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + for (i = 0; i < ele_num - tail; i += 16) { + auto a_f = _mm512_loadu_ps(src); + _mm512_storeu_ps(dst, _mm512_mul_ps(a_f, mul_f)); + src += 16; + dst += 16; + } + if (tail) { + auto a_f = _mm512_maskz_loadu_ps(msk, src); + _mm512_mask_storeu_ps(dst, msk, _mm512_mul_ps(a_f, mul_f)); + } +} + inline void mul_add_f32_avx512(float* dst, float* src, float mul, float* add, int ele_num) { auto mul_f = _mm512_set1_ps(mul); int i; diff --git a/tests/script/ext/attn_gpt.cpp b/tests/script/ext/attn_gpt.cpp deleted file mode 100644 index 10aa193..0000000 --- a/tests/script/ext/attn_gpt.cpp +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include -#include -#include "alloca.h" -#include "module.hpp" -#include "common/utility.hpp" -#include "utility_kernel_amx.hpp" -#include "llm_emb_gpt.hpp" -#include "llm_mha_gpt.hpp" -#include "test_common.hpp" - -using namespace torch::indexing; - -class attn_gpt { -public: - struct create_param { - size_t num_heads; - size_t head_size; - size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv - size_t max_seq_len; // max seq length for computing the size of matmul tmp result - // supported (qkv, dst): (bf16, bf16) - llmdnn::data_type_t qkv_precision; - llmdnn::data_type_t dst_precision; - size_t rotary_dims; - float normal_factor; - bool use_position2d; - }; - struct exec_param { - size_t batch; - size_t query_seq_len; - size_t past_seq_len; - bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. - uint8_t* q; // shape: [batch, query_seq_len, hidden size], inner stride is ldq - uint8_t* k; // shape: [batch, query_seq_len, hidden size], inner stride is ldk - uint8_t* v; // shape: [batch, query_seq_len, hidden size], inner stride is ldv - size_t ldq; // inner stride of q - size_t ldk; // inner stride of k - size_t ldv; // inner stride of v - uint8_t** layer_past_key_dst; - uint8_t** layer_past_value_dst; - int* position2d_ids; // shape: [batch, 2, query_seq_len] - float* attention_mask; // attention mask, attention_mask[0] shape: - // [batch, 1, 1, key_seq_len], when is_causal_in_attention is false - // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true - uint8_t* attn_output; - size_t head_stride_in_kv; - float* cos; - float* sin; - }; - - attn_gpt(); - bool create(const create_param& param); - void exec(const exec_param& param); - -private: - create_param _create_param; - std::shared_ptr _emb_gpt; - std::shared_ptr _mha_gpt; - std::shared_ptr _query_dst; - size_t _query_cached_batch = 0; -}; - -attn_gpt::attn_gpt(): _emb_gpt(std::make_shared()), - _mha_gpt(std::make_shared()) { - -} - -bool attn_gpt::create(const attn_gpt::create_param& param) { - _create_param = param; - llmdnn::emb_gpt::create_param emb_param = {0}; - emb_param.num_heads = param.num_heads; - emb_param.head_size = param.head_size; - emb_param.head_size_aligned = param.head_size_aligned; - emb_param.qkv_precision = param.qkv_precision; - emb_param.dst_precision = param.dst_precision; - emb_param.rotary_dims = param.rotary_dims; - emb_param.use_position2d = param.use_position2d; - - if (!_emb_gpt->create(emb_param)) - return false; - - llmdnn::mha_gpt::create_param mha_param = {0}; - mha_param.num_heads = param.num_heads; - mha_param.head_size = param.head_size; - mha_param.normal_factor = param.normal_factor; - mha_param.qkv_precision = param.qkv_precision; - mha_param.dst_precision = param.dst_precision; - mha_param.max_seq_len = param.max_seq_len; - - return _mha_gpt->create(mha_param); -} - -void attn_gpt::exec(const attn_gpt::exec_param& param) { - if (_query_cached_batch < param.batch) { - auto capacity = param.batch * _create_param.max_seq_len * (_create_param.num_heads * _create_param.head_size_aligned) * - llmdnn::get_precision_size(_create_param.qkv_precision); - _query_dst = std::shared_ptr(reinterpret_cast(aligned_alloc(64, capacity)), - [](void * p) { ::free(p); }); - memset(_query_dst.get(), 0, capacity); - _query_cached_batch = param.batch; - } - - llmdnn::emb_gpt::exec_param emb_param = {0}; - emb_param.batch = param.batch; - emb_param.query_seq_len = param.query_seq_len; - emb_param.past_seq_len = param.past_seq_len; - emb_param.q = param.q; - emb_param.k = param.k; - emb_param.v = param.v; - emb_param.ldq = param.ldq; - emb_param.ldk = param.ldk; - emb_param.ldv = param.ldv; - emb_param.query_dst = _query_dst.get(); - emb_param.layer_past_key_src = param.layer_past_key_dst; - emb_param.layer_past_value_src = param.layer_past_value_dst; - emb_param.layer_past_key_dst = param.layer_past_key_dst; - emb_param.layer_past_value_dst = param.layer_past_value_dst; - emb_param.position2d_ids = param.position2d_ids; - emb_param.head_stride_in_kv = param.head_stride_in_kv; - emb_param.cos = param.cos; - emb_param.sin = param.sin; - _emb_gpt->exec(emb_param); - - llmdnn::mha_gpt::exec_param mha_param = {0}; - mha_param.batch = param.batch; - mha_param.query_seq_len = param.query_seq_len; - mha_param.key_seq_len = param.query_seq_len + param.past_seq_len; - mha_param.q.resize({param.batch, _create_param.num_heads, param.query_seq_len, _create_param.head_size_aligned}, reinterpret_cast(emb_param.query_dst)); - mha_param.attn_output.resize({param.batch, param.query_seq_len, _create_param.num_heads * _create_param.head_size}, reinterpret_cast(param.attn_output)); - mha_param.is_causal_in_attention = param.is_causal_in_attention; - mha_param.attention_mask.resize({param.batch, 1, param.query_seq_len, mha_param.key_seq_len}, static_cast(param.attention_mask)); - mha_param.k.resize({param.batch, _create_param.num_heads, _create_param.max_seq_len, _create_param.head_size_aligned}, reinterpret_cast(param.layer_past_key_dst[0])); - mha_param.v.resize({param.batch, _create_param.num_heads, _create_param.max_seq_len, _create_param.head_size_aligned}, reinterpret_cast(param.layer_past_value_dst[0])); - _mha_gpt->exec(mha_param); -} - -void regclass_attn_gpt(pybind11::module m) { - py::class_> cls(m, "attn_gpt"); - cls.def(py::init<>()); - cls.def("create", [] (attn_gpt& self, - const size_t num_heads, - const size_t head_size, - const size_t head_size_aligned, - float normal_factor, - const std::string qkv_precision_name, - const std::string dst_precision_name, - const size_t max_seq_len, - const size_t rotary_dims, - bool use_position2d) { - attn_gpt::create_param param = {0}; - param.num_heads = num_heads; - param.head_size = head_size; - param.head_size_aligned = head_size_aligned; - param.normal_factor = normal_factor; - param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); - param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); - param.max_seq_len = max_seq_len; - param.rotary_dims = rotary_dims; - param.use_position2d = use_position2d; - if (param.qkv_precision == llmdnn::dnnl_data_type_undef) - throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); - if (param.dst_precision == llmdnn::dnnl_data_type_undef) - throw pybind11::type_error("Incorrect dst type " + dst_precision_name); - if (!self.create(param)) - throw pybind11::type_error("Incorrect param"); - }, - py::arg("num_heads"), - py::arg("head_size"), - py::arg("head_size_aligned"), - py::arg("normal_factor"), - py::arg("qkv_precision_name"), - py::arg("dst_precision_name"), - py::arg("max_seq_len"), - py::arg("rotary_dims"), - py::arg("use_position2d") = false, - R"( - Create emb - - :param num_heads: heads number. - :type num_heads: int - )"); - cls.def("exec_position", [] (attn_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_dst, - const torch::Tensor& layer_past_value_dst, int64_t past_seq_len, const torch::Tensor& attn_mask, - const torch::Tensor& position2d_ids, const torch::Tensor& cos, const torch::Tensor& sin) { - // qkv: [batch, seq_len, (num_heads * 3 * head_size)] - // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - // past_seq_len: past_seq_len==layer_past.shape[-2] - // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] - // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] - AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && attn_mask.dim() == 4 && - qkv.size(0) == layer_past_key_dst.size(0) && - layer_past_key_dst.dim() == layer_past_value_dst.dim()); - auto batch = qkv.size(0); - auto num_heads = layer_past_key_dst.size(1); - auto query_seq_len = qkv.size(1); - auto head_size = qkv.size(2) / 3 / num_heads; - auto head_size_aligned = layer_past_key_dst.size(3); - auto max_seq_len = layer_past_key_dst.size(2); - AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && - query_seq_len <= layer_past_key_dst.size(2)); - - attn_gpt::exec_param param = {0}; - param.batch = batch; - param.query_seq_len = query_seq_len; - param.past_seq_len = past_seq_len; - param.q = reinterpret_cast(qkv.data_ptr()); - param.k = param.q + head_size * sizeof(ov::bfloat16); - param.v = param.k + head_size * sizeof(ov::bfloat16); - param.ldq = param.ldk = param.ldv = head_size * 3; - param.layer_past_key_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_value_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - for (int i = 0; i < batch; i++) { - param.layer_past_key_dst[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); - param.layer_past_value_dst[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); - } - param.position2d_ids = reinterpret_cast(position2d_ids.data_ptr()); - - param.is_causal_in_attention = attn_mask.size(2) != 1; - param.attention_mask = attn_mask.data_ptr(); - param.head_stride_in_kv = max_seq_len * head_size_aligned; - auto out = qkv.new_empty({batch, query_seq_len, num_heads * head_size}); - param.attn_output = reinterpret_cast(out.data_ptr()); - param.cos = reinterpret_cast(cos.data_ptr()); - param.sin = reinterpret_cast(sin.data_ptr()); - - self.exec(param); - - // auto options = torch::TensorOptions().dtype(torch::kBFloat16); - // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); - return out; - }, - py::arg("qkv"), - py::arg("layer_past_key_dst"), - py::arg("layer_past_value_dst"), - py::arg("past_seq_len"), - py::arg("attn_mask"), - py::arg("position2d_ids"), - py::arg("cos"), - py::arg("sin"), - R"( - exec emb - - :param num_heads: heads number. - :type num_heads: int - )"); -} \ No newline at end of file diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp index 669ac44..1576934 100644 --- a/tests/script/ext/emb_gpt.cpp +++ b/tests/script/ext/emb_gpt.cpp @@ -4,7 +4,9 @@ #include #include +#include #include "alloca.h" +#include "llm_tensor.hpp" #include "module.hpp" #include "common/utility.hpp" #include "utility_kernel_amx.hpp" @@ -14,167 +16,59 @@ using namespace torch::indexing; void regclass_emb_gpt(pybind11::module m) { - py::class_> cls(m, "emb_gpt"); - cls.def(py::init<>()); - cls.def("create", [] (llmdnn::emb_gpt& self, - const size_t num_heads, - const size_t head_size, - const size_t head_size_aligned, - const std::string qkv_precision_name, - const std::string dst_precision_name, - const size_t rotary_dims, - bool use_position2d) { - llmdnn::emb_gpt::create_param param = {0}; - param.num_heads = num_heads; - param.head_size = head_size; - param.head_size_aligned = head_size_aligned; - param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); - param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); - param.rotary_dims = rotary_dims; - param.use_position2d = use_position2d; - if (param.qkv_precision == llmdnn::dnnl_data_type_undef) - throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); - if (param.dst_precision == llmdnn::dnnl_data_type_undef) - throw pybind11::type_error("Incorrect dst type " + dst_precision_name); - if (!self.create(param)) - throw pybind11::type_error("Incorrect param"); - }, - py::arg("num_heads"), - py::arg("head_size"), - py::arg("head_size_aligned"), - py::arg("qkv_precision_name"), - py::arg("dst_precision_name"), - py::arg("rotary_dims"), - py::arg("use_position2d") = false, - R"( - Create emb - - :param num_heads: heads number. - :type num_heads: int - )"); - // torch::List - cls.def("exec", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_dst, - const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor& cos, const torch::Tensor& sin) { + m.def("emb_gpt", [] ( + const torch::Tensor& qkv, + const torch::Tensor& k_past, + const torch::Tensor& v_past, + const torch::Tensor& cos, + const torch::Tensor& sin, + const torch::Tensor& position2d_ids) { // qkv: [batch, seq_len, (num_heads * 3 * head_size)] - // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - // past_seq_len: past_seq_len==layer_past.shape[-2] - // query_padded: [batch, num_heads, query_seq_len, head_size_aligned] - // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] - AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && - qkv.size(0) == layer_past_key_dst.size(0) && - layer_past_key_dst.dim() == layer_past_value_dst.dim()); - AT_ASSERT(query_padded.dim() == 4 && query_padded.size(0) == qkv.size(0) && - query_padded.size(1) == layer_past_key_dst.size(1) && query_padded.size(2) == qkv.size(1) && - query_padded.size(3) == layer_past_key_dst.size(3)); + // k_past: [batch, head_num, past_seq_len, head_size] + // q_dst: [batch, head_num, query_seq_len, head_size] + // k_dst: [batch, head_num, query_seq_len+past_seq_len, head_size] + // cos: [max_seq_len, rotary_dims] + // position2d_ids: [batch, 2, query_seq_len] + AT_ASSERT(qkv.dim() == 3 && k_past.dim() == 4 && v_past.dim() == 4); auto batch = qkv.size(0); - auto num_heads = layer_past_key_dst.size(1); auto query_seq_len = qkv.size(1); - auto head_size = qkv.size(2) / 3 / num_heads; - auto head_size_aligned = layer_past_key_dst.size(3); - auto max_seq_len = layer_past_key_dst.size(2); - AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && - query_seq_len <= layer_past_key_dst.size(2)); - - llmdnn::emb_gpt::exec_param param = {0}; - param.batch = batch; - param.query_seq_len = query_seq_len; - param.past_seq_len = past_seq_len; - param.q = reinterpret_cast(qkv.data_ptr()); - param.k = param.q + head_size * sizeof(ov::bfloat16); - param.v = param.k + head_size * sizeof(ov::bfloat16); - param.ldq = param.ldk = param.ldv = head_size * 3; - param.query_dst = reinterpret_cast(query_padded.data_ptr()); - param.layer_past_key_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_value_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_key_dst = param.layer_past_key_src; - param.layer_past_value_dst = param.layer_past_value_src; - param.head_stride_in_kv = max_seq_len * head_size_aligned; - param.cos = reinterpret_cast(cos.data_ptr()); - param.sin = reinterpret_cast(sin.data_ptr()); - for (int i = 0; i < batch; i++) { - param.layer_past_key_src[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); - param.layer_past_value_src[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); - } + auto head_num = k_past.size(1); + auto head_size = k_past.size(3); + auto past_seq_len = k_past.size(2); + auto kv_seq_len = query_seq_len + past_seq_len; - self.exec(param); + torch::Tensor q_dst = qkv.new_empty({batch, head_num, query_seq_len, head_size}); + torch::Tensor k_dst = qkv.new_empty({batch, head_num, kv_seq_len, head_size}); + torch::Tensor v_dst = qkv.new_empty({batch, head_num, kv_seq_len, head_size}); + llmdnn::tensor q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_; + q_.resize({batch, query_seq_len, head_num, head_size * 3}, reinterpret_cast(qkv.data_ptr()) + head_size * 0); + q_.m_dims[3] = head_size; + k_.resize({batch, query_seq_len, head_num, head_size * 3}, reinterpret_cast(qkv.data_ptr()) + head_size * 1); + k_.m_dims[3] = head_size; + v_.resize({batch, query_seq_len, head_num, head_size * 3}, reinterpret_cast(qkv.data_ptr()) + head_size * 2); + v_.m_dims[3] = head_size; + k_past_.resize({batch, head_num, past_seq_len, head_size}, reinterpret_cast(k_past.data_ptr())); + v_past_.resize({batch, head_num, past_seq_len, head_size}, reinterpret_cast(v_past.data_ptr())); + q_dst_.resize({batch, head_num, query_seq_len, head_size}, reinterpret_cast(q_dst.data_ptr())); + k_dst_.resize({batch, head_num, kv_seq_len, head_size}, reinterpret_cast(k_dst.data_ptr())); + v_dst_.resize({batch, head_num, kv_seq_len, head_size}, reinterpret_cast(v_dst.data_ptr())); + cos_.resize({cos.size(0), cos.size(1), cos.size(2), cos.size(3)}, cos.data_ptr()); + sin_.resize({sin.size(0), sin.size(1), sin.size(2), sin.size(3)}, sin.data_ptr()); + if (position2d_ids.numel()) + position2d_ids_.resize({batch, 2, query_seq_len}, position2d_ids.data_ptr()); + + llmdnn::emb_gpt(q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_); + return std::make_tuple(q_dst, k_dst, v_dst); // auto options = torch::TensorOptions().dtype(torch::kBFloat16); // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); }, py::arg("qkv"), - py::arg("layer_past_key_dst"), - py::arg("layer_past_value_dst"), - py::arg("query_padded"), - py::arg("past_seq_len"), + py::arg("k_past"), + py::arg("v_past"), py::arg("cos"), py::arg("sin"), - R"( - exec emb - - :param num_heads: heads number. - :type num_heads: int - )"); - cls.def("exec_position", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_src, const torch::Tensor& layer_past_value_src, - const torch::Tensor& layer_past_key_dst, const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor position2d_ids, const torch::Tensor& cos, const torch::Tensor& sin) { - // qkv: [batch, seq_len, (num_heads * 3 * head_size)] - // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - // past_seq_len: past_seq_len==layer_past.shape[-2] - // query_padded: [batch, num_heads, query_seq_len, head_size_aligned] - // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] - AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && - qkv.size(0) == layer_past_key_dst.size(0) && - layer_past_key_dst.dim() == layer_past_value_dst.dim()); - AT_ASSERT(query_padded.dim() == 4 && query_padded.size(0) == qkv.size(0) && - query_padded.size(1) == layer_past_key_dst.size(1) && query_padded.size(2) == qkv.size(1) && - query_padded.size(3) == layer_past_key_dst.size(3)); - auto batch = qkv.size(0); - auto num_heads = layer_past_key_dst.size(1); - auto query_seq_len = qkv.size(1); - auto head_size = qkv.size(2) / 3 / num_heads; - auto head_size_aligned = layer_past_key_dst.size(3); - auto max_seq_len = layer_past_key_dst.size(2); - AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && - query_seq_len <= layer_past_key_dst.size(2)); - - llmdnn::emb_gpt::exec_param param = {0}; - param.batch = batch; - param.query_seq_len = query_seq_len; - param.past_seq_len = past_seq_len; - param.q = reinterpret_cast(qkv.data_ptr()); - param.k = param.q + head_size * sizeof(ov::bfloat16); - param.v = param.k + head_size * sizeof(ov::bfloat16); - param.ldq = param.ldk = param.ldv = head_size * 3; - param.query_dst = reinterpret_cast(query_padded.data_ptr()); - param.layer_past_key_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_value_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_key_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.layer_past_value_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); - param.head_stride_in_kv = max_seq_len * head_size_aligned; - param.cos = reinterpret_cast(cos.data_ptr()); - param.sin = reinterpret_cast(sin.data_ptr()); - for (int i = 0; i < batch; i++) { - param.layer_past_key_src[i] = past_seq_len == 0 ? nullptr : reinterpret_cast(layer_past_key_src[i].data_ptr()); - param.layer_past_value_src[i] = past_seq_len == 0 ? nullptr : reinterpret_cast(layer_past_value_src[i].data_ptr()); - param.layer_past_key_dst[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); - param.layer_past_value_dst[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); - } - param.position2d_ids = reinterpret_cast(position2d_ids.data_ptr()); - - self.exec(param); - - // auto options = torch::TensorOptions().dtype(torch::kBFloat16); - // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); - }, - py::arg("qkv"), - py::arg("layer_past_key_src"), - py::arg("layer_past_value_src"), - py::arg("layer_past_key_dst"), - py::arg("layer_past_value_dst"), - py::arg("query_padded"), - py::arg("past_seq_len"), py::arg("position2d_ids"), - py::arg("cos"), - py::arg("sin"), R"( exec emb diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index eb734fc..844550b 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -7,6 +7,7 @@ #include #include "alloca.h" #include "common/bf16.hpp" +#include "llm_tensor.hpp" #include "module.hpp" #include "common/utility.hpp" #include "utility_kernel_amx.hpp" @@ -16,80 +17,40 @@ void regclass_mha_gpt(pybind11::module m) { py::class_> cls(m, "mha_gpt"); cls.def(py::init<>()); - cls.def("create", [] (llmdnn::mha_gpt& self, - const size_t num_heads, - const size_t head_size, - const float normal_factor, - const std::string qkv_precision_name, - const std::string dst_precision_name, - const size_t max_seq_len, - bool is_bloom) { - llmdnn::mha_gpt::create_param param = {0}; - param.num_heads = num_heads; - param.head_size = head_size; - param.normal_factor = normal_factor; - param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); - param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); - param.max_seq_len = max_seq_len; - param.is_bloom = is_bloom; - if (param.qkv_precision == llmdnn::dnnl_data_type_undef) - throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); - if (param.dst_precision == llmdnn::dnnl_data_type_undef) - throw pybind11::type_error("Incorrect dst type " + dst_precision_name); - if (!self.create(param)) - throw pybind11::type_error("Incorrect param"); - }, - py::arg("num_heads"), - py::arg("head_size"), - py::arg("normal_factor"), - py::arg("qkv_precision_name"), - py::arg("dst_precision_name"), - py::arg("max_seq_len"), - py::arg("is_bloom") = false, - R"( - Create mha - - :param num_heads: heads number. - :type num_heads: int - )"); - cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& alibi, const torch::Tensor& attn_mask, int64_t head_size, int64_t key_seq_len) { - // q: [batch, num_heads, query_seq_len, head_size_aligned] - // k: [batch, num_heads, head_size_aligned, max_seq_len] valid in max_seq_len: key_seq_len - // v: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& alibi, + const torch::Tensor& attn_mask, float normal_factor, bool use_causal) { + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, key_seq_len, head_size] + // v: [batch, num_heads, key_seq_len, head_size] // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] // out: [batch, query_seq_len, num_heads * head_size] + // alibi: [batch, num_heads, 1, key_seq_len] AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 4); auto batch = q.size(0); auto num_heads = q.size(1); auto query_seq_len = q.size(2); - auto head_size_aligned = q.size(3); - auto max_seq_len = v.size(2); - auto attn_len = attn_mask.size(3); - AT_ASSERT(max_seq_len == v.size(2) && - batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && + auto head_size = q.size(3); + AT_ASSERT(batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && num_heads == k.size(1) && num_heads == v.size(1) && - head_size_aligned == v.size(3)); + head_size == v.size(3)); - llmdnn::mha_gpt::exec_param param = {0}; - param.batch = batch; - param.query_seq_len = query_seq_len; - param.key_seq_len = key_seq_len == 0 ? max_seq_len : key_seq_len; - head_size = head_size == 0 ? head_size_aligned : head_size; - auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}); - AT_ASSERT((int64_t)param.key_seq_len == attn_len); - param.q.resize({q.size(0), q.size(1), q.size(2), q.size(3)}, reinterpret_cast(q.data_ptr())); - param.attn_output.resize({batch, query_seq_len, num_heads * head_size}, reinterpret_cast(out.data_ptr())); - param.is_causal_in_attention = attn_mask.size(2) != 1; - param.attention_mask.resize({attn_mask.size(0), attn_mask.size(1), attn_mask.size(2), attn_mask.size(3)}, attn_mask.data_ptr()); - param.k.resize({k.size(0), k.size(1), k.size(2), k.size(3)}, reinterpret_cast(k.data_ptr())); - if (alibi.dim() == 3) { - std::swap(param.k.m_dims[2], param.k.m_dims[3]); - std::swap(param.k.m_strides[2], param.k.m_strides[3]); - param.alibi.resize({alibi.size(0), alibi.size(1), alibi.size(2)}, alibi.data_ptr()); + llmdnn::tensor q_, k_, v_, out_, attn_mask_, alibi_; + q_.resize({q.size(0), q.size(1), q.size(2), q.size(3)}, reinterpret_cast(q.data_ptr())); + k_.resize({k.size(0), k.size(1), k.size(2), k.size(3)}, reinterpret_cast(k.data_ptr())); + if (k.size(2) != v.size(2)) { + // bloom k shape: [batch, num_heads, head_size, key_seq_len] + std::swap(k_.m_dims[2], k_.m_dims[3]); + std::swap(k_.m_strides[2], k_.m_strides[3]); } - param.v.resize({v.size(0), v.size(1), v.size(2), v.size(3)}, reinterpret_cast(v.data_ptr())); + v_.resize({v.size(0), v.size(1), v.size(2), v.size(3)}, reinterpret_cast(v.data_ptr())); + auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}); + out_.resize({batch, query_seq_len, num_heads * head_size}, reinterpret_cast(out.data_ptr())); + if (attn_mask.numel()) + attn_mask_.resize({attn_mask.size(0), attn_mask.size(1), attn_mask.size(2), attn_mask.size(3)}, attn_mask.data_ptr()); + if (alibi.numel()) + alibi_.resize({alibi.size(0), alibi.size(1), alibi.size(2), alibi.size(3)}, alibi.data_ptr()); + self.exec(q_, k_, v_, out_, attn_mask_, alibi_, normal_factor, use_causal); - self.exec(param); return out; }, py::arg("q"), @@ -97,71 +58,11 @@ void regclass_mha_gpt(pybind11::module m) { py::arg("v"), py::arg("alibi"), py::arg("attn_mask"), - py::arg("head_size") = 0, - py::arg("key_seq_len") = 0, + py::arg("normal_factor"), + py::arg("use_causal"), R"( exec mha - :param num_heads: heads number. - :type num_heads: int - )"); - cls.def("exec_quant", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask, - float q_dequant, float k_dequant, float v_dequant, float qk_quant, const std::vector& qkv_quant, int64_t head_size, int64_t key_seq_len) { - // q: [batch, num_heads, query_seq_len, head_size_aligned] - // k: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len - // v: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len - // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] - // out: [batch, query_seq_len, num_heads * head_size] - AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 4); - auto batch = q.size(0); - auto num_heads = q.size(1); - auto query_seq_len = q.size(2); - auto head_size_aligned = q.size(3); - auto max_seq_len = k.size(2); - auto attn_len = attn_mask.size(3); - AT_ASSERT(max_seq_len == v.size(2) && - batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && - num_heads == k.size(1) && num_heads == v.size(1) && - head_size_aligned == k.size(3) && head_size_aligned == v.size(3)); - - llmdnn::mha_gpt::exec_param param = {0}; - param.batch = batch; - param.query_seq_len = query_seq_len; - param.key_seq_len = key_seq_len == 0 ? max_seq_len : key_seq_len; - head_size = head_size == 0 ? head_size_aligned : head_size; - auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}, torch::TensorOptions(torch::kInt8)); - AT_ASSERT((int64_t)param.key_seq_len == attn_len); - - param.q.resize({q.size(0), q.size(1), q.size(2), q.size(3)}, reinterpret_cast(q.data_ptr())); - param.attn_output.resize({batch, query_seq_len, num_heads * head_size}, reinterpret_cast(out.data_ptr())); - param.is_causal_in_attention = attn_mask.size(2) != 1; - param.attention_mask.resize({attn_mask.size(0), attn_mask.size(1), attn_mask.size(2), attn_mask.size(3)}, attn_mask.data_ptr()); - param.k.resize({k.size(0), k.size(1), k.size(2), k.size(3)}, reinterpret_cast(k.data_ptr())); - param.v.resize({v.size(0), v.size(1), v.size(2), v.size(3)}, reinterpret_cast(v.data_ptr())); - - param.q_dequant = q_dequant; - param.k_dequant = k_dequant; - param.v_dequant = v_dequant; - param.qk_quant = qk_quant; - param.qkv_quant = qkv_quant; - - self.exec(param); - return out; - }, - py::arg("q"), - py::arg("k"), - py::arg("v"), - py::arg("attn_mask"), - py::arg("q_dequant"), - py::arg("k_dequant"), - py::arg("v_dequant"), - py::arg("qk_quant"), - py::arg("qkv_quant"), - py::arg("head_size") = 0, - py::arg("key_seq_len") = 0, - R"( - exec mha quant - :param num_heads: heads number. :type num_heads: int )"); diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp index 58c0ba1..d016470 100644 --- a/tests/script/ext/module.cpp +++ b/tests/script/ext/module.cpp @@ -16,5 +16,4 @@ PYBIND11_MODULE(llmdnn, m) { } regclass_mha_gpt(m); regclass_emb_gpt(m); - regclass_attn_gpt(m); } \ No newline at end of file diff --git a/tests/script/ext/module.hpp b/tests/script/ext/module.hpp index 12efeed..f7a704a 100644 --- a/tests/script/ext/module.hpp +++ b/tests/script/ext/module.hpp @@ -8,4 +8,3 @@ void regclass_mha_gpt(pybind11::module m); void regclass_emb_gpt(pybind11::module m); -void regclass_attn_gpt(pybind11::module m); \ No newline at end of file diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py index 3ea5aa9..1087a1e 100644 --- a/tests/script/ext/setup.py +++ b/tests/script/ext/setup.py @@ -33,7 +33,7 @@ ['module.cpp', f'../../src/test_common.cpp', 'mha_gpt.cpp', 'emb_gpt.cpp', - 'attn_gpt.cpp', + #'attn_gpt.cpp', ], extra_compile_args=extra_args, include_dirs=[f'{os.getcwd()}/../../src', diff --git a/tests/script/test_attn_chatglm.py b/tests/script/test_attn_chatglm.py deleted file mode 100644 index 2caf02a..0000000 --- a/tests/script/test_attn_chatglm.py +++ /dev/null @@ -1,452 +0,0 @@ -# Copyright (C) 2018-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import math -import sys -import torch -import numpy as np -import llmdnn as ld -from torch import nn -import torch.nn.functional as F -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -# copy from chatglm-6b/modeling_chatglm.py -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings, base=10000, precision=torch.half, learnable=False): - super().__init__() - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - #inv_freq = inv_freq.half() - self.learnable = learnable - if learnable: - self.inv_freq = torch.nn.Parameter(inv_freq) - self.max_seq_len_cached = None - else: - self.register_buffer('inv_freq', inv_freq) - self.max_seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - # use f32 to pass accuracy test - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cached = emb.cos()[:, None, :] - self.sin_cached = emb.sin()[:, None, :] - - self.precision = precision - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - pass - - def forward(self, x, seq_dim=1, seq_len=None): - if seq_len is None: - seq_len = x.shape[seq_dim] - if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): - self.max_seq_len_cached = None if self.learnable else seq_len - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - if self.precision == torch.bfloat16: - emb = emb.float() - - # [sx, 1 (b * np), hn] - cos_cached = emb.cos()[:, None, :] - sin_cached = emb.sin()[:, None, :] - if self.precision == torch.bfloat16: - cos_cached = cos_cached.bfloat16() - sin_cached = sin_cached.bfloat16() - if self.learnable: - return cos_cached, sin_cached - self.cos_cached, self.sin_cached = cos_cached, sin_cached - return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] - - def _apply(self, fn): - if self.cos_cached is not None: - self.cos_cached = fn(self.cos_cached) - if self.sin_cached is not None: - self.sin_cached = fn(self.sin_cached) - return super()._apply(fn) - - -def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions - - -@torch.jit.script -def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): - # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] - cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ - F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) - q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - return q, k - -class SelfAttention(torch.nn.Module): - def __init__(self, max_position_embeddings, hidden_size, num_attention_heads, - layer_id, hidden_size_per_attention_head=None, bias=True, - params_dtype=torch.float, position_encoding_2d=True, empty_init=True): - super(SelfAttention, self).__init__() - - self.layer_id = layer_id - self.hidden_size = hidden_size - self.hidden_size_per_partition = hidden_size - self.num_attention_heads = num_attention_heads - self.head_size = self.hidden_size // self.num_attention_heads - self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) - self.num_attention_heads_per_partition = num_attention_heads - self.position_encoding_2d = position_encoding_2d - self.rotary_emb = RotaryEmbedding( - self.hidden_size // (self.num_attention_heads * 2) - if position_encoding_2d - else self.hidden_size // self.num_attention_heads, - max_position_embeddings, - base=10000, - precision=torch.half, - learnable=False, - ) - - self.scale_mask_softmax = None - - if hidden_size_per_attention_head is None: - self.hidden_size_per_attention_head = hidden_size // num_attention_heads - else: - self.hidden_size_per_attention_head = hidden_size_per_attention_head - - self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head - - @staticmethod - def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - def split_tensor_along_last_dim(self, tensor, num_partitions, - contiguous_split_chunks=False): - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - def attention_fn( - self, - query_layer, - key_layer, - value_layer, - layer_past=None, - attention_mask=None - ): - # batch, seqlen, num_attention_heads, hidden_size_per_attention_head - b, seq_len, nh, hidden_size = key_layer.shape - - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) - - # # [sq, b, np, hn] -> [sq, b * np, hn] - # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # # [sk, b, np, hn] -> [sk, b * np, hn] - # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # [b, sq, np, hn] -> [b, np, sq, hn] - query_layer = query_layer.transpose(1, 2) - # [b, sk, np, hn] -> [b, np, sk, hn] - key_layer = key_layer.transpose(1, 2) - # [b, np, sq, hn] -> [b * np, sq, hn] - #query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) - # [b, np, sk, hn] -> [b * np, sk, hn] - #key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - - # value_layer -> context layer. - # [b, sk, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) - - # # change view [sk, b * np, hn] - # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # [b, sk, np, hn] -> [b, np, sk, hn] - value_layer = value_layer.transpose(1, 2) - - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - key_layer = torch.cat((past_key, key_layer), dim=2) - value_layer = torch.cat((past_value, value_layer), dim=2) - - # [b, np, sk, hn] -> [b * np, sk, hn] - #value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) - # return query_layer, key_layer, value_layer - past = (key_layer, value_layer) - - # from test_mha_chatglm.py/forward - query_layer = query_layer / self.norm_factor - - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) - - # [b, np, sq, hn] -> [b * np, sq, hn] - query_layer = query_layer.reshape(output_size[0] * output_size[1], output_size[2], -1) - # [b, np, sk, hn] -> [b * np, sk, hn] - key_layer = key_layer.reshape(output_size[0] * output_size[1], output_size[3], -1) - - matmul_result = torch.zeros( - 1, 1, 1, - dtype=query_layer.dtype, - device=query_layer.device, - ) - - matmul_result = torch.baddbmm( - matmul_result, - query_layer, # [b * np, sq, hn] - key_layer.transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=1.0, - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # if not (attention_mask == 0).all(): - # # if auto-regressive, skip - # attention_scores.masked_fill_(attention_mask, -10000.0) - attention_scores = attention_scores + attention_mask - dtype = attention_scores.dtype - attention_scores = attention_scores.float() - - attention_probs = F.softmax(attention_scores, dim=-1) - - attention_probs = attention_probs.type(dtype) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [b, sk, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3)) - - # [b, np, sk, hn] -> [b * np, sk, hn] - value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [b, sq, np, hn] - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - - # [b, sq, np, hn] --> [b, sq, hp] - new_context_layer_shape = (output_size[0], output_size[2], output_size[1] * output_size[3]) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer, past - - - def forward( - self, - qkv: torch.Tensor, # [batch, seq_len, 3 * hidden_size] - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]], - attention_mask: torch.Tensor, # [batch, 1, query_seq_len, key_seq_len] - position_ids # [batch, 2, query_seq_len] - ): - """ - hidden_states: [batch, seq_len, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # [batch, seq_len, 3 * hidden_size] - mixed_raw_layer = qkv # self.query_key_value(hidden_states) - - # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] - new_tensor_shape = mixed_raw_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) - - # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] - (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) - - if self.position_encoding_2d: - q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) - k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) - cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) - # xxx - position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ - position_ids[:, 1, :].contiguous() - q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) - q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) - query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) - key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) - else: - position_ids = position_ids.transpose(0, 1) - cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) - # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] - query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) - - query_layer = query_layer.to(dtype=torch.bfloat16) - key_layer = key_layer.to(dtype=torch.bfloat16) - - # [batch, seq_len, hidden_size] - attn, past = self.attention_fn( - query_layer=query_layer, - key_layer=key_layer, - value_layer=value_layer, - layer_past=layer_past, - attention_mask=attention_mask - ) - - return attn, past - - -class GPTAttentionExt: - def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): - self.attn = ld.attn_gpt() - num_heads = num_attention_heads - head_size = hidden_size // num_attention_heads - max_seq_len = max_position_embeddings - normal_factor = 1.0 / math.sqrt(head_size) - - qkv_precision_name = 's8' if is_int8 else 'bf16' - dst_precision_name = 's8' if is_int8 else 'bf16' - rotary_ndims = int(head_size * rotary_pct) - self.attn.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, - dst_precision_name, max_seq_len, rotary_ndims, True) - - inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) - #inv_freq = inv_freq.half() - self.max_seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - # use f32 to pass accuracy test - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cached = emb.cos()[:, None, :] - self.sin_cached = emb.sin()[:, None, :] - - # qkv: [batch, seq_len, (num_heads * 3 * head_size)] - # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - # past_seq_len: past_seq_len==layer_past.shape[-2] - # attn_mask: [batch, 1, query_seq_len, key_seq_len] - # position_ids: [batch, 2, query_seq_len] - # return: - # 0: qkv [batch, seq_len, (num_heads * 3 * head_size)] - # 1: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - # 2: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids): - return self.attn.exec_position(qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids, self.cos_cached, self.sin_cached) - - -HEAD_NUM = 32 -SIZE_PER_HEAD = 80 -SIZE_PER_HEAD_ALIGN = 96 -HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD -MAX_POSITION_EMBEDDINGS = 1024 #2048 -ROTARY_EMB_BASE = 10000 -ROTARY_PCT = 0.5 -MAX_SEQ_LEN = 1024 -def get_ref_model(): - ref_net = SelfAttention(MAX_POSITION_EMBEDDINGS, HIDDEN_SIZE, HEAD_NUM, 0, None) - ref_net = ref_net.to(dtype=torch.bfloat16) - return ref_net - -def test_attn(): - inputs = [ - # qkv: [batch, seq_len, (num_heads * 3 * head_size)] - # layer_past: [batch, num_attention_heads, past_seq_len, head_size] - # attn: [batch, 1, query_seq_len, key_seq_len] - (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([2, 1, 200, 200], dtype=np.float32)), - (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([2, 1, 1, 201], dtype=np.float32)), - ] - ref_net = get_ref_model() - net = GPTAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) - with torch.cpu.amp.autocast(): - for (i, input) in enumerate(inputs): - qkv, layer_past_key, layer_past_value, attn_mask = input - qkv = torch.from_numpy(qkv).to(torch.bfloat16) - layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) - layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) - attn_mask = torch.from_numpy(attn_mask) - attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min - past_seq_len = layer_past_key.shape[-2] - shape = list(layer_past_key.shape) - - if qkv.size(1) > 1: - seq_batch1 = torch.arange(end=qkv.size(1) - 1, dtype=torch.int32) - seq_batch1 = torch.concat((seq_batch1, seq_batch1[-1:])) - block_batch1 = torch.concat((torch.zeros(qkv.size(1) -1, dtype=torch.int32), torch.ones(1, dtype=torch.int32))) - else: - seq_batch1 = torch.tensor([3], dtype=torch.int32) - block_batch1 = torch.tensor([5], dtype=torch.int32) - seq_ids = torch.empty((qkv.size(0), 2, qkv.size(1)), dtype=torch.int32) - seq_ids[:, 0, :] = seq_batch1 - seq_ids[:, 1, :] = block_batch1 - output_ref, (key_ref, value_ref) = ref_net.forward(qkv, (layer_past_key, layer_past_value), attn_mask, seq_ids) - - shape[-2] = MAX_SEQ_LEN - shape[-1] = SIZE_PER_HEAD_ALIGN - layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) - layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) - layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key - layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value - key_ref = key_ref.to(dtype=torch.bfloat16) - output = net.forward(qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, seq_ids) - key, value = layer_past_key_padded, layer_past_value_padded - # check output - if not torch.allclose(output_ref, output, rtol=0.001, atol=0.01): - print(f"error at index {i} ref:\n{output_ref} \ncur:\n {output} ") - assert(False) - # check key - if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): - print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") - assert(False) - # check value - if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): - print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") - assert(False) - - print('done.') - return - -if __name__ == "__main__": - test_attn() \ No newline at end of file diff --git a/tests/script/test_mha_bloom.py b/tests/script/test_mha_bloom.py index fee6f39..cef9573 100644 --- a/tests/script/test_mha_bloom.py +++ b/tests/script/test_mha_bloom.py @@ -163,21 +163,11 @@ def forward( return context_layer class BloomAttentionExt: - def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, is_int8=False): + def __init__(self): self.mha = ld.mha_gpt() - num_heads = num_attention_heads - head_size = hidden_size // num_attention_heads - max_seq_len = max_position_embeddings - head_size_aligned = head_size - normal_factor = 1.0 / math.sqrt(head_size) - qkv_precision_name = 's8' if is_int8 else 'bf16' - dst_precision_name = 's8' if is_int8 else 'bf16' - self.mha.create(num_heads, head_size, normal_factor, qkv_precision_name, - dst_precision_name, max_seq_len, True) - - def forward(self, query, key, value, alibi, attention_mask): - return self.mha.exec(query, key, value, alibi, attention_mask) + def forward(self, query, key, value, alibi, attention_mask, normal_factor): + return self.mha.exec(query, key, value, alibi, attention_mask, normal_factor, False) HEAD_NUM = 32 SIZE_PER_HEAD = 80 @@ -195,7 +185,7 @@ def test_bloom(): # k: [batch, num_heads, head_size, key_seq_len] # v: [batch, num_heads, value_seq_len, head_size] # alibi: [batch, num_heads, 1, key_seq_len] - # attn: [2, 1, 1, key_seq_len] + # attn: [2, 1, query_seq_len, key_seq_len] (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), np.random.random(size=[2, HEAD_NUM, SIZE_PER_HEAD, 32]).astype(np.float32), np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), @@ -213,7 +203,7 @@ def test_bloom(): np.zeros([2, 1, 1, 200], dtype=np.float32)), ] ref_net = get_ref_model() - net = BloomAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS) + net = BloomAttentionExt() with torch.cpu.amp.autocast(): for (i, input) in enumerate(inputs): q, k, v, alibi, attn_mask = input @@ -221,11 +211,11 @@ def test_bloom(): k = torch.from_numpy(k).to(torch.bfloat16) v = torch.from_numpy(v).to(torch.bfloat16) alibi = torch.from_numpy(alibi) # to(torch.bfloat16) - alibi = alibi.view(-1, alibi.size(2), alibi.size(3)) attn_mask = torch.from_numpy(attn_mask) attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + output = net.forward(q, k, v, alibi, attn_mask, normal_factor = 1.0 / math.sqrt(SIZE_PER_HEAD)) + alibi = alibi.view(-1, alibi.size(2), alibi.size(3)) ref_output = ref_net.forward(q, k, v, alibi, attn_mask) - output = net.forward(q, k, v, alibi, attn_mask) if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") assert(False) diff --git a/tests/script/test_mha_chatglm.py b/tests/script/test_mha_chatglm.py deleted file mode 100644 index 7aa2d3b..0000000 --- a/tests/script/test_mha_chatglm.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright (C) 2018-2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import math -import sys -import torch -import numpy as np -import llmdnn as ld -from torch import nn -import torch.nn.functional as F - -# copy from chatglm-6b/modeling_chatglm.py -class GPTNeoXAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size - self.head_size = self.hidden_size // self.num_attention_heads - self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) - - def forward(self, - query_layer, # [b, np, sq, hn] - key_layer, # [b, np, sk, hn] - value_layer, # [b, np, sk, hn] - attention_mask # [b, np, s, s] - ): - return self.attention_fn(query_layer, key_layer, value_layer, attention_mask) - - def attention_fn( - self, - query_layer, - key_layer, - value_layer, - attention_mask - ): - query_layer = query_layer / self.norm_factor - - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) - - # [b, np, sq, hn] -> [b * np, sq, hn] - query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) - # [b, np, sk, hn] -> [b * np, sk, hn] - key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - - matmul_result = torch.zeros( - 1, 1, 1, - dtype=query_layer.dtype, - device=query_layer.device, - ) - - matmul_result = torch.baddbmm( - matmul_result, - query_layer, # [b * np, sq, hn] - key_layer.transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=1.0, - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # if not (attention_mask == 0).all(): - # # if auto-regressive, skip - # attention_scores.masked_fill_(attention_mask, -10000.0) - attention_scores = attention_scores + attention_mask - dtype = attention_scores.dtype - attention_scores = attention_scores.float() - - attention_probs = F.softmax(attention_scores, dim=-1) - - attention_probs = attention_probs.type(dtype) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [b, sk, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3)) - - # [b, np, sk, hn] -> [b * np, sk, hn] - value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [b, sq, np, hn] - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - - # [b, sq, np, hn] --> [b, sq, hp] - new_context_layer_shape = (output_size[0], output_size[2], output_size[1] * output_size[3]) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - -class GPTNeoXAttentionExt: - def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, is_int8=False): - self.mha = ld.mha_gpt() - num_heads = num_attention_heads - head_size = hidden_size // num_attention_heads - max_seq_len = max_position_embeddings - - head_size_aligned = head_size_aligned - normal_factor = 1.0 / math.sqrt(head_size) - qkv_precision_name = 's8' if is_int8 else 'bf16' - dst_precision_name = 's8' if is_int8 else 'bf16' - self.mha.create(num_heads, head_size, normal_factor, qkv_precision_name, - dst_precision_name, max_seq_len) - - def forward(self, query, key, value, attention_mask, head_size, key_seq_len): - return self.mha.exec(query, key, value, torch.tensor([]), attention_mask, head_size, key_seq_len) - - def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): - # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant - return self.mha.exec_quant(query, key, value, attention_mask, 1.0 / q_quant, 1.0 / k_quant, 1.0 / v_quant, qk_quant, requant) - -HEAD_NUM = 32 -SIZE_PER_HEAD = 80 -SIZE_PER_HEAD_ALIGN = 96 -HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD -MAX_POSITION_EMBEDDINGS = 1024 #2048 -def get_ref_model(): - class FakeConfig: - def __init__(self): - self.num_attention_heads = HEAD_NUM - self.hidden_size = HIDDEN_SIZE - self.max_position_embeddings = MAX_POSITION_EMBEDDINGS - config = FakeConfig() - ref_net = GPTNeoXAttention(config) - ref_net = ref_net.to(dtype=torch.bfloat16) - return ref_net - -def test_chatglm_neox(): - inputs = [ - # q, k, v, attn_mask - # q: [batch, num_heads, query_seq_len, head_size] - # k: [batch, num_heads, key_seq_len, head_size] - # v: [batch, num_heads, value_seq_len, head_size] - # attn: [batch, 1, query_seq_len, key_seq_len] - (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([2, 1, 2, 32], dtype=np.float32)), - (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([2, 1, 200, 200], dtype=np.float32)), - (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), - np.zeros([2, 1, 1, 200], dtype=np.float32)), - ] - ref_net = get_ref_model() - net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS) - with torch.cpu.amp.autocast(): - for (i, input) in enumerate(inputs): - q, k, v, attn_mask = input - q = torch.from_numpy(q).to(torch.bfloat16) - k = torch.from_numpy(k).to(torch.bfloat16) - v = torch.from_numpy(v).to(torch.bfloat16) - attn_mask = torch.from_numpy(attn_mask) - attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min - ref_output = ref_net.forward(q, k, v, attn_mask) - - shape = list(k.shape) - shape[-2] = MAX_POSITION_EMBEDDINGS - shape[-1] = SIZE_PER_HEAD_ALIGN - key_padded = torch.zeros(shape, dtype=torch.bfloat16) - value_padded = torch.zeros(shape, dtype=torch.bfloat16) - query_shape = list(q.shape) - query_shape[-1] = SIZE_PER_HEAD_ALIGN - query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) - key_padded[:,:,:k.shape[-2],:k.shape[-1]] = k - value_padded[:,:,:k.shape[-2],:k.shape[-1]] = v - query_padded[:,:,:,:q.shape[-1]] = q - - output = net.forward(query_padded, key_padded, value_padded, attn_mask, k.size(3), k.size(2)) - if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): - print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") - assert(False) - - print('done.') - return - -def todotest_gpt_neox_int8(): - low = -4 - high = 4 - range_ = high - low - q_quant = 127.0 / high - qs = [ - np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - ] - low = -2 - high = 2 - range_ = high - low - k_quant = 127.0 / high - ks = [ - np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - ] - low = -8 - high = 8 - range_ = high - low - v_quant = 127.0 / high - vs = [ - np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - ] - # q, k, v, attn_mask - # q: [batch, num_heads, query_seq_len, head_size] - # k: [batch, num_heads, key_seq_len, head_size] - # v: [batch, num_heads, value_seq_len, head_size] - # attn: [1, MAX_POSITION_EMBEDDINGS] - inputs = [] - for i in range(len(qs)): - inputs.append((qs[i], ks[i], vs[i], np.zeros([1, ks[i].shape[-2]], dtype=np.float32))) - - ref_net = get_ref_model() - net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, True) - qk_quant, requant = 255.0, 10.0 - with torch.cpu.amp.autocast(): - for (i, input) in enumerate(inputs): - q, k, v, attn_mask = input - q = torch.from_numpy(q) - k = torch.from_numpy(k) - v = torch.from_numpy(v) - attn_mask = torch.from_numpy(attn_mask) - q = (q * q_quant).round().clamp(-128, 127).to(torch.int8) - k = (k * k_quant).round().clamp(-128, 127).to(torch.int8) - v = (v * v_quant).round().clamp(-128, 127).to(torch.int8) - ref_output = ref_net.forward(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, requant) - output = net.forward_quant(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, [requant,]) - if (torch.abs(ref_output- output) > 2).any(): - print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") - assert(False) - - print('done.') - return - -if __name__ == "__main__": - test_chatglm_neox() \ No newline at end of file diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py index b3e7c8c..23b880a 100644 --- a/tests/script/test_mha_gpt.py +++ b/tests/script/test_mha_gpt.py @@ -105,25 +105,11 @@ def _attn(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v return attn_output, attn_weights class GPTNeoXAttentionExt: - def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, is_int8=False): + def __init__(self): self.mha = ld.mha_gpt() - num_heads = num_attention_heads - head_size = hidden_size // num_attention_heads - max_seq_len = max_position_embeddings - head_size_aligned = head_size - normal_factor = 1.0 / math.sqrt(head_size) - qkv_precision_name = 's8' if is_int8 else 'bf16' - dst_precision_name = 's8' if is_int8 else 'bf16' - self.mha.create(num_heads, head_size, normal_factor, qkv_precision_name, - dst_precision_name, max_seq_len) - - def forward(self, query, key, value, attention_mask): - return self.mha.exec(query, key, value, torch.tensor([]), attention_mask) - - def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): - # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant - return self.mha.exec_quant(query, key, value, attention_mask, 1.0 / q_quant, 1.0 / k_quant, 1.0 / v_quant, qk_quant, requant) + def forward(self, query, key, value, attention_mask, normal_factor): + return self.mha.exec(query, key, value, torch.tensor([]), attention_mask, normal_factor, True) HEAD_NUM = 32 SIZE_PER_HEAD = 80 @@ -161,7 +147,7 @@ def test_gpt_neox(): np.zeros([2, 1, 1, 200], dtype=np.float32)), ] ref_net = get_ref_model() - net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS) + net = GPTNeoXAttentionExt() with torch.cpu.amp.autocast(): for (i, input) in enumerate(inputs): q, k, v, attn_mask = input @@ -171,7 +157,7 @@ def test_gpt_neox(): attn_mask = torch.from_numpy(attn_mask) attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min ref_output = ref_net.forward(q, k, v, attn_mask) - output = net.forward(q, k, v, attn_mask) + output = net.forward(q, k, v, attn_mask, normal_factor = 1.0 / math.sqrt(SIZE_PER_HEAD)) if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") assert(False) @@ -179,65 +165,5 @@ def test_gpt_neox(): print('done.') return -def test_gpt_neox_int8(): - low = -4 - high = 4 - range_ = high - low - q_quant = 127.0 / high - qs = [ - np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - ] - low = -2 - high = 2 - range_ = high - low - k_quant = 127.0 / high - ks = [ - np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - ] - low = -8 - high = 8 - range_ = high - low - v_quant = 127.0 / high - vs = [ - np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, - ] - # q, k, v, attn_mask - # q: [batch, num_heads, query_seq_len, head_size] - # k: [batch, num_heads, key_seq_len, head_size] - # v: [batch, num_heads, value_seq_len, head_size] - # attn: [batch, 1, 1, key_seq_len] - inputs = [] - for i in range(len(qs)): - inputs.append((qs[i], ks[i], vs[i], np.zeros([ks[i].shape[0], 1, 1, ks[i].shape[-2]], dtype=np.float32))) - - ref_net = get_ref_model() - net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, True) - qk_quant, requant = 255.0, 10.0 - with torch.cpu.amp.autocast(): - for (i, input) in enumerate(inputs): - q, k, v, attn_mask = input - q = torch.from_numpy(q) - k = torch.from_numpy(k) - v = torch.from_numpy(v) - attn_mask = torch.from_numpy(attn_mask) - q = (q * q_quant).round().clamp(-128, 127).to(torch.int8) - k = (k * k_quant).round().clamp(-128, 127).to(torch.int8) - v = (v * v_quant).round().clamp(-128, 127).to(torch.int8) - ref_output = ref_net.forward(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, requant) - output = net.forward_quant(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, [requant,]) - if (torch.abs(ref_output- output) > 2).any(): - print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") - assert(False) - - print('done.') - return - if __name__ == "__main__": test_gpt_neox() - test_gpt_neox_int8() \ No newline at end of file diff --git a/tests/script/test_rotary_pastkv.py b/tests/script/test_rotary_pastkv.py index 402720f..edb1866 100644 --- a/tests/script/test_rotary_pastkv.py +++ b/tests/script/test_rotary_pastkv.py @@ -108,16 +108,11 @@ def forward(self, qkv, layer_past): class GPTNeoXAttentionExt: def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): - self.emd = ld.emb_gpt() num_heads = num_attention_heads head_size = hidden_size // num_attention_heads max_seq_len = max_position_embeddings - qkv_precision_name = 's8' if is_int8 else 'bf16' - dst_precision_name = 's8' if is_int8 else 'bf16' rotary_ndims = int(head_size * rotary_pct) - self.emd.create(num_heads, head_size, head_size_aligned, qkv_precision_name, - dst_precision_name, rotary_ndims) inv_freq = 1.0 / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) @@ -138,13 +133,13 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len): - self.emd.exec(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, self.cos_cached, self.sin_cached) + def forward(self, qkv, k_past, v_past): + return ld.emb_gpt(qkv, k_past, v_past, self.cos_cached, self.sin_cached, torch.tensor([])) HEAD_NUM = 32 SIZE_PER_HEAD = 80 -SIZE_PER_HEAD_ALIGN = 96 +SIZE_PER_HEAD_ALIGN = 80 HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD MAX_POSITION_EMBEDDINGS = 1024 #2048 ROTARY_EMB_BASE = 10000 @@ -180,42 +175,26 @@ def test_gpt_neox(): for (i, input) in enumerate(inputs): qkv, layer_past_key, layer_past_value = input qkv = torch.from_numpy(qkv).to(torch.bfloat16) - past_seq_len = layer_past_key.shape[-2] - shape = list(layer_past_key.shape) - shape[-2] = MAX_SEQ_LEN - shape[-1] = SIZE_PER_HEAD_ALIGN - layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) - layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) - query_shape = list(shape) - query_shape[2] = qkv.shape[1] - query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) - layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key - layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value - ref_output, query_ref, key_ref, value_ref = ref_net.forward(qkv, (layer_past_key, layer_past_value)) + + _, query_ref, key_ref, value_ref = ref_net.forward(qkv, (layer_past_key, layer_past_value)) query_ref = query_ref.to(dtype=torch.bfloat16) key_ref = key_ref.to(dtype=torch.bfloat16) - net.forward(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len) - output, query, key, value = (layer_past_key_padded, layer_past_value_padded), query_padded, layer_past_key_padded, layer_past_value_padded - # check output - if not torch.allclose(ref_output[0].to(dtype=torch.bfloat16), output[0][:,:,:ref_output[0].shape[-2],:ref_output[0].shape[-1]], rtol=0.001, atol=0.01): - print(f"error at past key index {i} ref:\n{ref_output[0]} \ncur:\n {output[0]} ") - assert(False) - if not torch.allclose(ref_output[1], output[1][:,:,:ref_output[1].shape[-2],:ref_output[1].shape[-1]], rtol=0.001, atol=0.01): - print(f"error at past value index {i} ref:\n{ref_output[1]} \ncur:\n {output[1]} ") - assert(False) + + # no prealloc past kv + query, key, value = net.forward(qkv, layer_past_key, layer_past_value) # check query - if not torch.allclose(query_ref, query[:,:,:,:query_ref.shape[-1]], rtol=0.001, atol=0.01): - print(f"error at query index {i} ref:\n{query_ref} \ncur:\n {query} ") + if not torch.allclose(query_ref, query, rtol=0.001, atol=0.01): + print(f"error at sequence query index {i} ref:\n{query_ref} \ncur:\n {query} ") assert(False) # check key - if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): - print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") + if not torch.allclose(key_ref, key, rtol=0.001, atol=0.01): + print(f"error at sequence key index {i} ref:\n{key_ref} \ncur:\n {key} ") assert(False) # check value - if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): - print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") + if not torch.allclose(value_ref, value, rtol=0.001, atol=0.01): + print(f"error at sequence value index {i} ref:\n{value_ref} \ncur:\n {value} ") assert(False) print('done.') diff --git a/tests/script/test_rotary_pastkv_chatglm.py b/tests/script/test_rotary_pastkv_chatglm.py index 612c7d4..92f25f7 100644 --- a/tests/script/test_rotary_pastkv_chatglm.py +++ b/tests/script/test_rotary_pastkv_chatglm.py @@ -250,16 +250,11 @@ def forward( class GPTNeoXAttentionExt: def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): - self.emd = ld.emb_gpt() num_heads = num_attention_heads head_size = hidden_size // num_attention_heads max_seq_len = max_position_embeddings - qkv_precision_name = 's8' if is_int8 else 'bf16' - dst_precision_name = 's8' if is_int8 else 'bf16' rotary_ndims = int(head_size * rotary_pct) - self.emd.create(num_heads, head_size, head_size_aligned, qkv_precision_name, - dst_precision_name, rotary_ndims, True) inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) #inv_freq = inv_freq.half() @@ -273,8 +268,8 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi freqs = torch.einsum("i,j->ij", t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cached = emb.cos()[:, None, :] - self.sin_cached = emb.sin()[:, None, :] + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] # qkv: [batch, seq_len, (num_heads * 3 * head_size)] # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] @@ -284,13 +279,13 @@ def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_posi # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] - def forward(self, qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids): - self.emd.exec_position(qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids, self.cos_cached, self.sin_cached) + def forward(self, qkv, k_past, v_past, position_ids): + return ld.emb_gpt(qkv, k_past, v_past, self.cos_cached, self.sin_cached, position_ids) HEAD_NUM = 32 SIZE_PER_HEAD = 80 -SIZE_PER_HEAD_ALIGN = 96 +SIZE_PER_HEAD_ALIGN = 80 HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD MAX_POSITION_EMBEDDINGS = 1024 #2048 ROTARY_EMB_BASE = 10000 @@ -321,8 +316,6 @@ def test_chatglm(): qkv = torch.from_numpy(qkv).to(torch.bfloat16) layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) - past_seq_len = layer_past_key.shape[-2] - shape = list(layer_past_key.shape) if qkv.size(1) > 1: seq_batch1 = torch.arange(end=qkv.size(1) - 1, dtype=torch.int32) @@ -339,15 +332,11 @@ def test_chatglm(): key_ref = key_ref.to(dtype=torch.bfloat16) # no prealloc past kv - layer_past_key_seq = torch.zeros(key_ref.shape, dtype=torch.bfloat16) - layer_past_value_seq = torch.zeros(value_ref.shape, dtype=torch.bfloat16) - query_seq = torch.zeros(query_ref.shape, dtype=torch.bfloat16) - net_seq.forward(qkv, layer_past_key, layer_past_value, layer_past_key_seq, layer_past_value_seq, query_seq, past_seq_len, seq_ids) - query, key, value = query_seq, layer_past_key_seq, layer_past_value_seq + query, key, value = net_seq.forward(qkv, layer_past_key, layer_past_value, seq_ids) # check query if not torch.allclose(query_ref, query, rtol=0.001, atol=0.01): print(f"error at sequence query index {i} ref:\n{query_ref} \ncur:\n {query} ") - #assert(False) + assert(False) # check key if not torch.allclose(key_ref, key, rtol=0.001, atol=0.01): print(f"error at sequence key index {i} ref:\n{key_ref} \ncur:\n {key} ") @@ -357,31 +346,6 @@ def test_chatglm(): print(f"error at sequence value index {i} ref:\n{value_ref} \ncur:\n {value} ") assert(False) - # prealloc past kv - shape[-2] = MAX_SEQ_LEN - shape[-1] = SIZE_PER_HEAD_ALIGN - layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) - layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) - query_shape = list(shape) - query_shape[2] = qkv.shape[1] - query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) - layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key - layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value - net.forward(qkv, layer_past_key_padded, layer_past_value_padded, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, seq_ids) - query, key, value = query_padded, layer_past_key_padded, layer_past_value_padded - # check query - if not torch.allclose(query_ref, query[:,:,:,:query_ref.shape[-1]], rtol=0.001, atol=0.01): - print(f"error at query index {i} ref:\n{query_ref} \ncur:\n {query} ") - assert(False) - # check key - if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): - print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") - assert(False) - # check value - if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): - print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") - assert(False) - print('done.') return From 0db2ce884ef2a1051bb43c0fee7ad525866265af Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Sat, 15 Jul 2023 19:31:16 +0800 Subject: [PATCH 30/54] wa gcc9 could not find 'std::__throw_bad_array_new_length()' --- CMakeLists.txt | 1 - src/CMakeLists.txt | 1 + src/common/compatible.hpp | 39 +++++++++++++++ src/fc_kernel_amx.cpp | 25 ++++------ src/mha_gpt_amx.cpp | 101 +++++++++++++++++++------------------- src/mm_kernel_amx.cpp | 27 +++++----- 6 files changed, 111 insertions(+), 83 deletions(-) create mode 100644 src/common/compatible.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 831a44f..d499054 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,6 @@ message(INFO "--------------------------------") message(STATUS "Build with tests: ${CPU_EXTENSIONS_BUILD_TESTS}") message(INFO "--------------------------------") -set(CMAKE_CXX_STANDARD 17) if(MSVC) # TODO: validate if(MSVC_VERSION VERSION_LESS 1928) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 90594d5..1a4d8b8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,6 +14,7 @@ target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} PUBLIC $ $/${CMAKE_INSTALL_INCLUDEDIR}>) target_compile_options(${PROJECT_NAME} PRIVATE ${EXTRA_CXX_FLAGS}) +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17) set(CMAKE_DST lib/cmake/${PROJECT_NAME}) # header files diff --git a/src/common/compatible.hpp b/src/common/compatible.hpp new file mode 100644 index 0000000..f3d2af1 --- /dev/null +++ b/src/common/compatible.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +// gcc 9 does not recognize 'std::__throw_bad_array_new_length()' which is imported by +// gcc 11. The symbol exists in std::allocator::allocate, use custom to wa. +template +class custom_allocator { +public: + using value_type = T; + custom_allocator() noexcept = default; + template + custom_allocator (const custom_allocator&) noexcept {} + inline T* allocate(std::allocator::size_type cnt, typename std::allocator::const_pointer = 0) { + return static_cast(::operator new(cnt * sizeof(T))); + } + void deallocate (T* p, std::size_t n) { + ::operator delete(p); + } +}; + +template +bool operator==(custom_allocator const&, custom_allocator const&) noexcept { + return true; +} + +template +bool operator!=(custom_allocator const& x, custom_allocator const& y) noexcept { + return !(x == y); +} + +template +using llm_vector = std::vector>; diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index 9b91e89..c13c84d 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -1,18 +1,13 @@ // Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include #include #include #include -#include -#include -#include #include -#include -#include -#include -#include #include +#include #include "llm_fc.hpp" #include "mm_kernel_common_amx.hpp" @@ -23,10 +18,10 @@ namespace llmdnn { using ov::bfloat16; struct fc_kernel { - std::shared_ptr> bf16xbf16; - std::shared_ptr> bf16xi8; - std::shared_ptr> i8xi8; - std::shared_ptr> u8xi8; + std::unique_ptr> bf16xbf16; + std::unique_ptr> bf16xi8; + std::unique_ptr> i8xi8; + std::unique_ptr> u8xi8; data_type_t dt_a; data_type_t dt_b; @@ -83,13 +78,13 @@ bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { m = new fc_kernel; if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { - m->i8xi8 = std::make_shared>(true, param->b_is_trans); + m->i8xi8 = std::make_unique>(true, param->b_is_trans); } else if (param->dt_a == dnnl_u8 && param->dt_b == dnnl_s8) { - m->u8xi8 = std::make_shared>(true, param->b_is_trans); + m->u8xi8 = std::make_unique>(true, param->b_is_trans); } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { - m->bf16xbf16 = std::make_shared>(true, param->b_is_trans); + m->bf16xbf16 = std::make_unique>(true, param->b_is_trans); } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_s8) { - m->bf16xi8 = std::make_shared>(true, param->b_is_trans); + m->bf16xi8 = std::make_unique>(true, param->b_is_trans); m->bf16xi8->quant_scale_B = param->q; m->bf16xi8->dequant_scale_B = param->dq; } else { diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 846a89e..4d174d6 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -4,11 +4,11 @@ #include #include -#include #include "common/simple_parallel.hpp" #include "common/tensor2d.hpp" #include "common/utility.hpp" +#include "common/compatible.hpp" #include "llm_types.hpp" #include "utility_kernel_avx512.hpp" #include "mm_kernel_common_amx.hpp" @@ -22,6 +22,8 @@ using namespace ov::cpu; namespace llmdnn { struct mha_gpt_impl_amx : public mha_gpt::impl { + mha_gpt_impl_amx() = default; + ~mha_gpt_impl_amx(); void create(data_type_t in_type, size_t seq_len, size_t head_size, bool is_bloom); void exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask) override; @@ -32,18 +34,31 @@ struct mha_gpt_impl_amx : public mha_gpt::impl { size_t _buffer_mat1_out_size = 0; size_t _num_threads = 0; - std::shared_ptr _buffer_mat0_out; - std::shared_ptr _buffer_mat1_out; + uint8_t* _buffer_mat0_out = nullptr; + uint8_t* _buffer_mat1_out = nullptr; - std::vector>> gemAvB_BF16xBF16; - std::vector>> qKtrGemm_BF16xBF16; - std::vector>> qKVGemm_BF16xBF16; - - std::vector>> qKtrGemm_i8xi8; - std::vector>> qKVGemm_u8xi8; - std::vector>> gemAvB_i8xi8; + llm_vector*> gemAvB_BF16xBF16; + llm_vector*> qKtrGemm_BF16xBF16; + llm_vector*> qKVGemm_BF16xBF16; }; +mha_gpt_impl_amx::~mha_gpt_impl_amx() { + for (size_t i = 0; i < gemAvB_BF16xBF16.size(); i++) { + delete gemAvB_BF16xBF16[i]; + } + for (size_t i = 0; i < qKtrGemm_BF16xBF16.size(); i++) { + delete qKtrGemm_BF16xBF16[i]; + } + for (size_t i = 0; i < qKVGemm_BF16xBF16.size(); i++) { + delete qKVGemm_BF16xBF16[i]; + } + + if (_buffer_mat0_out) + free(_buffer_mat0_out); + if (_buffer_mat1_out) + free(_buffer_mat1_out); +} + void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_size, bool is_bloom) { // q: [batch, head_num, query_seq_len, head_size] // k: [batch, head_num, maxSeqLen(valid: key_seq_len), head_size] @@ -53,34 +68,18 @@ void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_s // attn_output: [batch, query_seq_len, head_num * head_size] if (_num_threads == 0) { _num_threads = getTotalThreads(); - if (in_type == dnnl_s8) { - _head_size_aligned = rndup(head_size, 64); - qKtrGemm_i8xi8.resize(_num_threads); - for (size_t i = 0; i < _num_threads; i++) { - qKtrGemm_i8xi8[i] = std::make_shared>(false, !is_bloom); - } - qKVGemm_u8xi8.resize(_num_threads); - for (size_t i = 0; i < _num_threads; i++) { - qKVGemm_u8xi8[i] = std::make_shared>(false, false); - } - gemAvB_i8xi8.resize(_num_threads); - for (size_t i = 0; i < _num_threads; i++) { - gemAvB_i8xi8[i] = std::make_shared>(); - } - } else { - _head_size_aligned = rndup(head_size, 32); - gemAvB_BF16xBF16.resize(_num_threads); - for (size_t i = 0; i < _num_threads; i++) { - gemAvB_BF16xBF16[i] = std::make_shared>(); - } - qKtrGemm_BF16xBF16.resize(_num_threads); - for (size_t i = 0; i < _num_threads; i++) { - qKtrGemm_BF16xBF16[i] = std::make_shared>(false, !is_bloom); - } - qKVGemm_BF16xBF16.resize(_num_threads); - for (size_t i = 0; i < _num_threads; i++) { - qKVGemm_BF16xBF16[i] = std::make_shared>(false, false); - } + _head_size_aligned = rndup(head_size, 32); + gemAvB_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + gemAvB_BF16xBF16[i] = new amx_kernel::MatmulVector(); + } + qKtrGemm_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + qKtrGemm_BF16xBF16[i] = new amx_kernel::Matmul(false, !is_bloom); + } + qKVGemm_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + qKVGemm_BF16xBF16[i] = new amx_kernel::Matmul(false, false); } } @@ -88,15 +87,15 @@ void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_s if (buffer_mat0_out_size > _buffer_mat0_out_size) { _buffer_mat0_out_size = seq_len * rndup(seq_len * sizeof(float), 64) * 3 / 2; _buffer_mat1_out_size = seq_len * _head_size_aligned * sizeof(float) * 3 / 2; + if (_buffer_mat0_out) + free(_buffer_mat0_out); + if (_buffer_mat1_out) + free(_buffer_mat1_out); - _buffer_mat0_out = std::shared_ptr( - reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat0_out_size)), - [](void * p) { ::free(p); }); - memset(_buffer_mat0_out.get(), 0, _num_threads * _buffer_mat0_out_size); - _buffer_mat1_out = std::shared_ptr( - reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat1_out_size)), - [](void * p) { ::free(p); }); - memset(_buffer_mat1_out.get(), 0, _num_threads * _buffer_mat1_out_size); + _buffer_mat0_out = reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat0_out_size)); + memset(_buffer_mat0_out, 0, _num_threads * _buffer_mat0_out_size); + _buffer_mat1_out = reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat1_out_size)); + memset(_buffer_mat1_out, 0, _num_threads * _buffer_mat1_out_size); } } @@ -125,8 +124,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& auto k_sub = &k.at({i0, i1}); auto v_sub = &v.at({i0, i1}); - auto mat0_out = reinterpret_cast(_buffer_mat0_out.get() + thread_id * _buffer_mat0_out_size); - auto mat1_out = reinterpret_cast(_buffer_mat1_out.get() + thread_id * _buffer_mat1_out_size); + auto mat0_out = reinterpret_cast(_buffer_mat0_out + thread_id * _buffer_mat0_out_size); + auto mat1_out = reinterpret_cast(_buffer_mat1_out + thread_id * _buffer_mat1_out_size); tensor2D matK(key_seq_len, head_size, reinterpret_cast(k_sub), k.m_strides[2]); // N: key_seq_len, K: head_size @@ -172,8 +171,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& auto k_sub = &k.at({i0, i1}); auto v_sub = &v.at({i0, i1}); - auto mat0_out = reinterpret_cast(_buffer_mat0_out.get() + thread_id * _buffer_mat0_out_size); - auto mat1_out = reinterpret_cast(_buffer_mat1_out.get() + thread_id * _buffer_mat1_out_size); + auto mat0_out = reinterpret_cast(_buffer_mat0_out + thread_id * _buffer_mat0_out_size); + auto mat1_out = reinterpret_cast(_buffer_mat1_out + thread_id * _buffer_mat1_out_size); tensor2D matQ(seq_cout, head_size, q_sub, q.m_strides[2]); tensor2D matQK(seq_cout, key_seq_len, mat0_out, rndup(key_seq_len * sizeof(float), 64)); @@ -210,7 +209,7 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& } else { // [batch, 1, 1, key_seq_len] or [batch, 1, query_seq_len, key_seq_len] for (int m = 0; m < seq_cout; m++) { - auto attn_sub = &attn_mask.at({i0, 0, attn_mask.m_dims[2] == 1 ? 0 : m + seq_start}); // s + i0 * key_seq_len * query_seq_len; + auto attn_sub = &attn_mask.at({i0, 0, attn_mask.m_dims[2] == 1 ? 0 : m + seq_start}); float* src = &matQK(m, 0); ov::bfloat16* dst = &softmax_dst(m, 0); if (!alibi) diff --git a/src/mm_kernel_amx.cpp b/src/mm_kernel_amx.cpp index d4adf8f..43df55b 100644 --- a/src/mm_kernel_amx.cpp +++ b/src/mm_kernel_amx.cpp @@ -1,18 +1,13 @@ // Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include #include -#include #include -#include #include #include #include -#include -#include #include -#include -#include #include "llm_mm.hpp" #include "mm_kernel_common_amx.hpp" @@ -23,12 +18,12 @@ namespace llmdnn { using ov::bfloat16; struct mm_kernel { - std::shared_ptr> bf16xbf16; - std::shared_ptr> i8xi8; - std::shared_ptr> u8xi8; + std::unique_ptr> bf16xbf16; + std::unique_ptr> i8xi8; + std::unique_ptr> u8xi8; - std::shared_ptr> i8xi8_gemv; - std::shared_ptr> bf16xbf16_gemv; + std::unique_ptr> i8xi8_gemv; + std::unique_ptr> bf16xbf16_gemv; data_type_t dt_a; data_type_t dt_b; @@ -46,20 +41,20 @@ bool mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param) { m = new mm_kernel; if (param->b_is_gemv) { if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { - m->i8xi8_gemv = std::make_shared>(); + m->i8xi8_gemv = std::make_unique>(); } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { - m->bf16xbf16_gemv = std::make_shared>(); + m->bf16xbf16_gemv = std::make_unique>(); } else { std::cout << "mm_kernel_create: unsupport gemv input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; goto ERR; } } else { if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { - m->i8xi8 = std::make_shared>(false, param->b_is_trans); + m->i8xi8 = std::make_unique>(false, param->b_is_trans); } else if (param->dt_a == dnnl_u8 && param->dt_b == dnnl_s8) { - m->u8xi8 = std::make_shared>(false, param->b_is_trans); + m->u8xi8 = std::make_unique>(false, param->b_is_trans); } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { - m->bf16xbf16 = std::make_shared>(false, param->b_is_trans); + m->bf16xbf16 = std::make_unique>(false, param->b_is_trans); } else { std::cout << "mm_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; goto ERR; From ebecf77d11ff9b453e84e53cdf8e66db9a36e3dd Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Tue, 18 Jul 2023 01:01:05 +0800 Subject: [PATCH 31/54] wa gcc 7.5 does not like newer stringstream --- include/llm_tensor.hpp | 72 ------------------------------------------ src/common/tensor.cpp | 15 +++------ 2 files changed, 5 insertions(+), 82 deletions(-) diff --git a/include/llm_tensor.hpp b/include/llm_tensor.hpp index 10671d4..6ca311e 100644 --- a/include/llm_tensor.hpp +++ b/include/llm_tensor.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include "llm_types.hpp" @@ -177,77 +176,6 @@ struct tensor { } void assert_dims(const std::initializer_list& expect_dims) const; - - template - std::string repr(int max_total_lines = 16, int lines_per_row = 1) const { - std::stringstream ss; - ss << typeid(DT).name() << " shape=["; - const char* sep = ""; - size_t sz = 1; - for (size_t i = 0; i < m_rank; i++) { - ss << sep << m_dims[i]; - sz *= m_dims[i]; - sep = ","; - } - ss << "] strides=["; - sep = ""; - for (size_t i = 0; i < m_rank; i++) { - ss << sep << m_strides[i]; - sep = ","; - } - ss << "] {"; - if (m_rank > 1) - ss << "\n"; - auto last_dim_size = m_dims[m_rank - 1]; - int row_id = 0; - int cur_row_lines_left = lines_per_row; - int cur_line_elecnt = 0; - int cur_row_elecnt = 0; - size_t i; - auto* p = reinterpret_cast(m_ptr); - for (i = 0; i < sz && max_total_lines > 0; i++) { - if ((i % last_dim_size) == 0) { - ss << row_id << ":\t\t"; - row_id++; - cur_row_lines_left = lines_per_row; - } - - // display current element if we still have buget - if (cur_row_lines_left > 0) { - ss << p[i] << ","; - cur_line_elecnt++; - cur_row_elecnt++; - if ((cur_line_elecnt % 16) == 15 || (cur_row_elecnt == last_dim_size)) { - max_total_lines--; - cur_row_lines_left--; - if (cur_row_lines_left == 0) { - if (cur_row_elecnt == last_dim_size) - ss << ",\n"; - else - ss << "...\n"; - cur_row_elecnt = 0; - } else { - ss << "\n\t\t"; - } - cur_line_elecnt = 0; - } - } - } - if (i < sz) { - ss << "... ... ... ... \n"; - } - ss << "}"; - return ss.str(); - } - - template - friend std::ostream& operator<<(std::ostream& os, const tensor& dt); }; -template -std::ostream& operator<<(std::ostream& os, const tensor& dt) { - os << dt.repr(); - return os; -} - } // namespace llmdnn diff --git a/src/common/tensor.cpp b/src/common/tensor.cpp index b8e4258..2850e58 100644 --- a/src/common/tensor.cpp +++ b/src/common/tensor.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include "bf16.hpp" #include "llm_tensor.hpp" @@ -150,20 +149,16 @@ void tensor::resize(const size_t* new_dims, size_t dim_num, void* data, size_t e void tensor::assert_dims(const std::initializer_list& expect_dims) const { if (m_rank != expect_dims.size()) { - asm("int3"); std::cout << "dims not same\n"; } if (!std::equal(expect_dims.begin(), expect_dims.end(), m_dims)) { - std::stringstream ss; - ss << " m_dims=["; + std::cout << " m_dims=["; for (size_t i = 0; i < m_rank; i++) - ss << m_dims[i] << ","; - ss << "] expect_dims=["; + std::cout << m_dims[i] << ","; + std::cout << "] expect_dims=["; for (auto& i : expect_dims) - ss << i << ","; - ss << "]"; - asm("int3"); - std::cout << ss.str(); + std::cout << i << ","; + std::cout << "]"; } } From 7872bad3c98941472e744584f0d3e3c25966273e Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Wed, 19 Jul 2023 20:51:32 +0800 Subject: [PATCH 32/54] use custom allactor for map --- src/common/compatible.hpp | 4 ++++ src/fc_kernel_amx.cpp | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/common/compatible.hpp b/src/common/compatible.hpp index f3d2af1..f232763 100644 --- a/src/common/compatible.hpp +++ b/src/common/compatible.hpp @@ -6,6 +6,7 @@ #include #include +#include #include // gcc 9 does not recognize 'std::__throw_bad_array_new_length()' which is imported by @@ -37,3 +38,6 @@ bool operator!=(custom_allocator const& x, custom_allocator const& y) noex template using llm_vector = std::vector>; + +template > +using llm_map = std::map>>; \ No newline at end of file diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index c13c84d..a8bb698 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -13,6 +13,7 @@ #include "mm_kernel_common_amx.hpp" #include "utility_kernel_avx512.hpp" #include "fc_kernel_amx.hpp" +#include "common/compatible.hpp" namespace llmdnn { @@ -32,7 +33,7 @@ struct fc_kernel { using supported_key = std::tuple; using supported_value = std::pair; -static std::map supported_postops = { +static llm_map supported_postops = { { { dnnl_s8, dnnl_s8, dnnl_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, { { dnnl_s8, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, { { dnnl_s8, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, From 11833d8f6f234069be919087513fdfbc21a0bffa Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 24 Jul 2023 17:45:57 +0800 Subject: [PATCH 33/54] support external causal mask[opt] --- include/llm_mha_gpt.hpp | 6 + src/mha_gpt_amx.cpp | 79 ++++++------- src/mha_gpt_api.cpp | 4 +- src/utility_kernel_avx512.hpp | 144 +++++++++++++++-------- tests/script/ext/mha_gpt.cpp | 11 +- tests/script/test_mha_bloom.py | 2 +- tests/script/test_mha_gpt.py | 59 +++++++++- tests/src/test_utility_kernel_avx512.cpp | 9 +- 8 files changed, 212 insertions(+), 102 deletions(-) diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp index 9cc98f5..3df2db6 100644 --- a/include/llm_mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -25,6 +25,10 @@ class mha_gpt { // [batch, 1, 1, key_seq_len], // [batch, 1, query_seq_len, key_seq_len] const tensor& alibi, // alibi[opt] shape: [batch, num_heads, 1, key_seq_len] + const tensor& causal_mask, // [opt] use_causal_mask must be false, u8, shape: + // [1, 1, query_seq_len, key_seq_len] + // [batch, 1, query_seq_len, key_seq_len] + bool select_nfltmax_at_0, // used when causal_mask is not null. true: causal_mask=0 use -FLT_MAX float normal_factor, bool use_causal_mask = false);// add causal mask @@ -36,6 +40,8 @@ class mha_gpt { const tensor& output, const tensor& attn_mask, const tensor& alibi, + const tensor& causal_mask, + bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask = false) = 0; }; diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 4d174d6..3605712 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -25,9 +25,11 @@ struct mha_gpt_impl_amx : public mha_gpt::impl { mha_gpt_impl_amx() = default; ~mha_gpt_impl_amx(); void create(data_type_t in_type, size_t seq_len, size_t head_size, bool is_bloom); - void exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask) override; + void exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, + const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) override; - void mha_bf16(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask); + void mha_bf16(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, + const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask); size_t _head_size_aligned = 0; size_t _buffer_mat0_out_size = 0; @@ -100,7 +102,7 @@ void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_s } void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, - const tensor& alibi, float normal_factor, bool use_causal_mask) { + const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { auto batch = q.m_dims[0]; auto head_num = q.m_dims[1]; auto query_seq_len = q.m_dims[2]; @@ -113,7 +115,7 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& auto& gemAvB_ops = gemAvB_BF16xBF16; auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; auto& qKVGemm_ops = qKVGemm_BF16xBF16; - bool use_vector = query_seq_len == 1 && head_size >= 32 && head_size <= 32 * 6 && !is_bloom && !alibi && attn_mask; + bool use_vector = query_seq_len == 1 && head_size >= 32 && head_size <= 32 * 6 && !is_bloom && !alibi && attn_mask && !causal_mask; size_t head_stride_in_attn = head_size; size_t batch_stride_in_attn = head_size * head_num * query_seq_len; size_t causal_mask_offset_start = use_causal_mask ? key_seq_len - query_seq_len : key_seq_len; @@ -135,7 +137,7 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& (*gemAvB_ops[thread_id])(matK, reinterpret_cast(q_sub), reinterpret_cast(mat0_out)); float* pMatMul0Out = reinterpret_cast(mat0_out); - mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, normal_factor, &attn_mask.at({i0}), key_seq_len); + mul_add2_select_f32_avx512(pMatMul0Out, pMatMul0Out, normal_factor, nullptr, &attn_mask.at({i0}), nullptr, false, key_seq_len); softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, key_seq_len, nullptr); auto out_sub = out + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * sizeof(ov::bfloat16); tensor2D matQK(query_seq_len, key_seq_len, reinterpret_cast(mat0_out), rndup(key_seq_len * sizeof(ov::bfloat16), 64)); @@ -188,44 +190,21 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& tensor2D softmax_dst(seq_cout, key_seq_len, reinterpret_cast(mat0_out), rndup(key_seq_len * sizeof(ov::bfloat16), 64)); // no attention mask size_t valid_softmax_items = std::min(causal_mask_offset_start + seq_start + 1, key_seq_len); - if (!attn_mask) { - for (int m = 0; m < seq_cout; m++) { - float* src = &matQK(m, 0); - ov::bfloat16* dst = &softmax_dst(m, 0); - if (!alibi) - mul_f32_avx512(src, src, normal_factor, valid_softmax_items); - else - // alibi shape: [batch, head_num, 1, key_seq_len] - mul_add_f32_avx512(src, src, normal_factor, - &alibi.at({i0, i1}), - valid_softmax_items); - softmax_avx512(dst, src, valid_softmax_items, nullptr); - if (key_seq_len > valid_softmax_items) { - auto *invalidPtr = dst + valid_softmax_items; - memset(static_cast(invalidPtr), 0, (key_seq_len - valid_softmax_items) * sizeof(ov::bfloat16)); - valid_softmax_items = std::min(valid_softmax_items + 1, key_seq_len); - } - } - } else { - // [batch, 1, 1, key_seq_len] or [batch, 1, query_seq_len, key_seq_len] - for (int m = 0; m < seq_cout; m++) { - auto attn_sub = &attn_mask.at({i0, 0, attn_mask.m_dims[2] == 1 ? 0 : m + seq_start}); - float* src = &matQK(m, 0); - ov::bfloat16* dst = &softmax_dst(m, 0); - if (!alibi) - mul_add_f32_avx512(src, src, normal_factor, attn_sub, valid_softmax_items); - else - // alibi shape: [batch, head_num, 1, key_seq_len] - mul_add2_f32_avx512(src, src, normal_factor, - &alibi.at({i0, i1}), - attn_sub, - valid_softmax_items); - softmax_avx512(dst, src, valid_softmax_items, nullptr); - if (key_seq_len > valid_softmax_items) { - auto *invalidPtr = dst + valid_softmax_items; - memset(static_cast(invalidPtr), 0, (key_seq_len - valid_softmax_items) * sizeof(ov::bfloat16)); - valid_softmax_items = std::min(valid_softmax_items + 1, key_seq_len); - } + // attn: [batch, 1, 1, key_seq_len] or [batch, 1, query_seq_len, key_seq_len] + // alibi: [batch, num_heads, 1, key_seq_len] + // causal: [batch/1, 1, query_seq_len, key_seq_len] + for (int m = 0; m < seq_cout; m++) { + auto attn_sub = attn_mask ? &attn_mask.at({i0, 0, attn_mask.m_dims[2] == 1 ? 0 : m + seq_start}) : nullptr; + auto alibi_sub = alibi ? &alibi.at({i0, i1}) : nullptr; + auto causal_mask_sub = causal_mask ? &causal_mask.at({causal_mask.m_dims[0] == 1 ? 0 : i0, 0, m + seq_start}) : nullptr; + float* src = &matQK(m, 0); + ov::bfloat16* dst = &softmax_dst(m, 0); + mul_add2_select_f32_avx512(src, src, normal_factor, alibi_sub, attn_sub, causal_mask_sub, select_nfltmax_at_0, valid_softmax_items); + softmax_avx512(dst, src, valid_softmax_items, nullptr); + if (key_seq_len > valid_softmax_items) { + auto *invalidPtr = dst + valid_softmax_items; + memset(static_cast(invalidPtr), 0, (key_seq_len - valid_softmax_items) * sizeof(ov::bfloat16)); + valid_softmax_items = std::min(valid_softmax_items + 1, key_seq_len); } } @@ -245,7 +224,7 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& } } -void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask) { +void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { if (q.m_rank != 4 || k.m_rank != 4 || v.m_rank != 4) { std::cout << "q,k,v rank does not equal 4.\n"; return; @@ -277,6 +256,16 @@ void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, c return; } } + if (causal_mask) { + if (causal_mask.m_rank != 4) { + std::cout << "causal_mask rank should be 4.\n"; + return; + } + if (use_causal_mask) { + std::cout << "use_causal_mask must be false to disable builtin causal mask.\n"; + return; + } + } auto batch = q.m_dims[0]; auto head_num = q.m_dims[1]; auto query_seq_len = q.m_dims[2]; @@ -298,7 +287,7 @@ void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, c if (in_dtype == dnnl_bf16 && out_dtype == dnnl_bf16) { create(in_dtype, key_seq_len, head_size, is_bloom); - mha_bf16(q, k, v, output, attn_mask, alibi, normal_factor, use_causal_mask); + mha_bf16(q, k, v, output, attn_mask, alibi, causal_mask, select_nfltmax_at_0, normal_factor, use_causal_mask); } else { std::cout << "doesn't support provided input precisions.\n"; } diff --git a/src/mha_gpt_api.cpp b/src/mha_gpt_api.cpp index 9ba35d5..40d0038 100644 --- a/src/mha_gpt_api.cpp +++ b/src/mha_gpt_api.cpp @@ -17,8 +17,8 @@ mha_gpt::~mha_gpt() { delete _impl; } -void mha_gpt::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, float normal_factor, bool use_causal_mask) { - _impl->exec(q, k, v, output, attn_mask, alibi, normal_factor, use_causal_mask); +void mha_gpt::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { + _impl->exec(q, k, v, output, attn_mask, alibi, causal_mask, select_nfltmax_at_0, normal_factor, use_causal_mask); } } \ No newline at end of file diff --git a/src/utility_kernel_avx512.hpp b/src/utility_kernel_avx512.hpp index 1743253..2aab313 100644 --- a/src/utility_kernel_avx512.hpp +++ b/src/utility_kernel_avx512.hpp @@ -4,6 +4,7 @@ #pragma once +#include #include #include "common/bf16.hpp" #ifdef _WIN32 @@ -91,63 +92,114 @@ inline void cvt_i32_f32_avx512(float* dst, int32_t* src, size_t ele_num) { } } -inline void mul_f32_avx512(float* dst, float* src, float mul, int ele_num) { +enum mul_add2_select_flag { + mul_add2_select_flag_none, + mul_add2_select_flag_add1 = 1, + mul_add2_select_flag_add2 = 2, + mul_add2_select_flag_select = 4 +}; +template +inline void _mul_add2_select_f32_avx512(float* dst, float* src, float mul, float* add1, float* add2, uint8_t* select, int ele_num) { auto mul_f = _mm512_set1_ps(mul); int i; auto tail = ele_num % 16; __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + auto zero_i32 = _mm512_setzero_si512(); + auto nfltmax = _mm512_set1_ps(-__FLT_MAX__); for (i = 0; i < ele_num - tail; i += 16) { - auto a_f = _mm512_loadu_ps(src); - _mm512_storeu_ps(dst, _mm512_mul_ps(a_f, mul_f)); - src += 16; - dst += 16; - } - if (tail) { - auto a_f = _mm512_maskz_loadu_ps(msk, src); - _mm512_mask_storeu_ps(dst, msk, _mm512_mul_ps(a_f, mul_f)); - } -} + auto a_f = _mm512_loadu_ps(src + i); + __m512 result; + if constexpr ((flag & (mul_add2_select_flag_add1 | mul_add2_select_flag_add2)) == mul_add2_select_flag_none) + result = _mm512_mul_ps(a_f, mul_f); + else if constexpr ((flag & (mul_add2_select_flag_add1 | mul_add2_select_flag_add2)) == mul_add2_select_flag_add2) + result = _mm512_fmadd_ps(a_f, mul_f, _mm512_loadu_ps(add2 + i)); + else { + result = _mm512_fmadd_ps(a_f, mul_f, _mm512_loadu_ps(add1 + i)); + if constexpr (flag & mul_add2_select_flag_add2) + result = _mm512_add_ps(result, _mm512_loadu_ps(add2 + i)); + } + if constexpr (flag & mul_add2_select_flag_select) { + auto r_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i*>(select + i)); + auto r_maski32 = _mm512_cvtepi8_epi32(r_maski8); + r_maski32 = _mm512_sub_epi32(zero_i32, r_maski32); + auto r_maskps = _mm512_movepi32_mask(r_maski32); // -FLT_MAX if mask == 0 + if constexpr (select_nfltmax_at_0) + result = _mm512_mask_blend_ps(r_maskps, nfltmax, result); + else + result = _mm512_mask_blend_ps(r_maskps, result, nfltmax); + } -inline void mul_add_f32_avx512(float* dst, float* src, float mul, float* add, int ele_num) { - auto mul_f = _mm512_set1_ps(mul); - int i; - auto tail = ele_num % 16; - __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - tail)); - for (i = 0; i < ele_num - tail; i += 16) { - auto a_f = _mm512_loadu_ps(src); - auto add_f = _mm512_loadu_ps(add); - _mm512_storeu_ps(dst, _mm512_fmadd_ps(a_f, mul_f, add_f)); - src += 16; - dst += 16; - add += 16; + _mm512_storeu_ps(dst + i, result); } if (tail) { - auto a_f = _mm512_maskz_loadu_ps(msk, src); - auto add_f = _mm512_maskz_loadu_ps(msk, add); - _mm512_mask_storeu_ps(dst, msk, _mm512_fmadd_ps(a_f, mul_f, add_f)); + auto a_f = _mm512_maskz_loadu_ps(msk, src + i); + __m512 result; + if constexpr ((flag & (mul_add2_select_flag_add1 | mul_add2_select_flag_add2)) == mul_add2_select_flag_none) + result = _mm512_mul_ps(a_f, mul_f); + else if constexpr ((flag & (mul_add2_select_flag_add1 | mul_add2_select_flag_add2)) == mul_add2_select_flag_add2) + result = _mm512_fmadd_ps(a_f, mul_f, _mm512_maskz_loadu_ps(msk, add2 + i)); + else { + result = _mm512_fmadd_ps(a_f, mul_f, _mm512_maskz_loadu_ps(msk, add1 + i)); + if constexpr (flag & mul_add2_select_flag_add2) + result = _mm512_add_ps(result, _mm512_maskz_loadu_ps(msk, add2 + i)); + } + if constexpr (flag & mul_add2_select_flag_select) { + auto r_maski8 = _mm512_castsi512_si128(_mm512_maskz_loadu_epi8(msk, select + i)); + auto r_maski32 = _mm512_cvtepi8_epi32(r_maski8); + r_maski32 = _mm512_sub_epi32(zero_i32, r_maski32); + auto r_maskps = _mm512_movepi32_mask(r_maski32); // -FLT_MAX if mask == 0 + if constexpr (select_nfltmax_at_0) + result = _mm512_mask_blend_ps(r_maskps, nfltmax, result); + else + result = _mm512_mask_blend_ps(r_maskps, result, nfltmax); + } + + _mm512_mask_storeu_ps(dst + i, msk, result); } } -inline void mul_add2_f32_avx512(float* dst, float* src, float mul, float* add1, float* add2, int ele_num) { - auto mul_f = _mm512_set1_ps(mul); - int i; - auto tail = ele_num % 16; - __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - tail)); - for (i = 0; i < ele_num - tail; i += 16) { - auto a_f = _mm512_loadu_ps(src); - auto add1_f = _mm512_loadu_ps(add1); - auto add2_f = _mm512_loadu_ps(add2); - _mm512_storeu_ps(dst, _mm512_add_ps(_mm512_fmadd_ps(a_f, mul_f, add1_f), add2_f)); - src += 16; - dst += 16; - add1 += 16; - add2 += 16; - } - if (tail) { - auto a_f = _mm512_maskz_loadu_ps(msk, src); - auto add1_f = _mm512_maskz_loadu_ps(msk, add1); - auto add2_f = _mm512_maskz_loadu_ps(msk, add2); - _mm512_mask_storeu_ps(dst, msk, _mm512_add_ps(_mm512_fmadd_ps(a_f, mul_f, add1_f), add2_f)); +inline void mul_add2_select_f32_avx512(float* dst, float* src, float mul, float* add1, float* add2, uint8_t* select, bool select_nfltmax_at_0, int ele_num) { + if (add1) { + if (add2) { + if (select) { + if (select_nfltmax_at_0) + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_add2 | mul_add2_select_flag_select), true>(dst, src, mul, add1, add2, select, ele_num); + else + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_add2 | mul_add2_select_flag_select)>(dst, src, mul, add1, add2, select, ele_num);; + } else { + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_add2)>(dst, src, mul, add1, add2, select, ele_num); + } + } else { + if (select) { + if (select_nfltmax_at_0) + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_select), true>(dst, src, mul, add1, add2, select, ele_num); + else + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_select)>(dst, src, mul, add1, add2, select, ele_num); + } else { + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1)>(dst, src, mul, add1, add2, select, ele_num); + } + } + } else { + if (add2) { + if (select) { + if (select_nfltmax_at_0) + _mul_add2_select_f32_avx512(mul_add2_select_flag_add2 | mul_add2_select_flag_select), true>(dst, src, mul, add1, add2, select, ele_num); + else + _mul_add2_select_f32_avx512(mul_add2_select_flag_add2 | mul_add2_select_flag_select)>(dst, src, mul, add1, add2, select, ele_num);; + } else { + _mul_add2_select_f32_avx512(mul_add2_select_flag_add2)>(dst, src, mul, add1, add2, select, ele_num); + } + } else { + if (select) { + if (select_nfltmax_at_0) + _mul_add2_select_f32_avx512(mul_add2_select_flag_select), true>(dst, src, mul, add1, add2, select, ele_num); + else + _mul_add2_select_f32_avx512(mul_add2_select_flag_select)>(dst, src, mul, add1, add2, select, ele_num); + } else { + _mul_add2_select_f32_avx512(mul_add2_select_flag_none)>(dst, src, mul, add1, add2, select, ele_num); + } + } } } + } \ No newline at end of file diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp index 844550b..813f891 100644 --- a/tests/script/ext/mha_gpt.cpp +++ b/tests/script/ext/mha_gpt.cpp @@ -18,13 +18,14 @@ void regclass_mha_gpt(pybind11::module m) { py::class_> cls(m, "mha_gpt"); cls.def(py::init<>()); cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& alibi, - const torch::Tensor& attn_mask, float normal_factor, bool use_causal) { + const torch::Tensor& attn_mask, const torch::Tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal) { // q: [batch, num_heads, query_seq_len, head_size] // k: [batch, num_heads, key_seq_len, head_size] // v: [batch, num_heads, key_seq_len, head_size] // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] // out: [batch, query_seq_len, num_heads * head_size] // alibi: [batch, num_heads, 1, key_seq_len] + // causal_mask: [batch, 1, query_seq_len, key_seq_len] AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 4); auto batch = q.size(0); auto num_heads = q.size(1); @@ -34,7 +35,7 @@ void regclass_mha_gpt(pybind11::module m) { num_heads == k.size(1) && num_heads == v.size(1) && head_size == v.size(3)); - llmdnn::tensor q_, k_, v_, out_, attn_mask_, alibi_; + llmdnn::tensor q_, k_, v_, out_, attn_mask_, alibi_, causal_mask_; q_.resize({q.size(0), q.size(1), q.size(2), q.size(3)}, reinterpret_cast(q.data_ptr())); k_.resize({k.size(0), k.size(1), k.size(2), k.size(3)}, reinterpret_cast(k.data_ptr())); if (k.size(2) != v.size(2)) { @@ -49,7 +50,9 @@ void regclass_mha_gpt(pybind11::module m) { attn_mask_.resize({attn_mask.size(0), attn_mask.size(1), attn_mask.size(2), attn_mask.size(3)}, attn_mask.data_ptr()); if (alibi.numel()) alibi_.resize({alibi.size(0), alibi.size(1), alibi.size(2), alibi.size(3)}, alibi.data_ptr()); - self.exec(q_, k_, v_, out_, attn_mask_, alibi_, normal_factor, use_causal); + if (causal_mask.numel()) + causal_mask_.resize({causal_mask.size(0), causal_mask.size(1), causal_mask.size(2), causal_mask.size(3)}, causal_mask.data_ptr()); + self.exec(q_, k_, v_, out_, attn_mask_, alibi_, causal_mask_, select_nfltmax_at_0, normal_factor, use_causal); return out; }, @@ -58,6 +61,8 @@ void regclass_mha_gpt(pybind11::module m) { py::arg("v"), py::arg("alibi"), py::arg("attn_mask"), + py::arg("causal_mask"), + py::arg("select_nfltmax_at_0"), py::arg("normal_factor"), py::arg("use_causal"), R"( diff --git a/tests/script/test_mha_bloom.py b/tests/script/test_mha_bloom.py index cef9573..3c5dfee 100644 --- a/tests/script/test_mha_bloom.py +++ b/tests/script/test_mha_bloom.py @@ -167,7 +167,7 @@ def __init__(self): self.mha = ld.mha_gpt() def forward(self, query, key, value, alibi, attention_mask, normal_factor): - return self.mha.exec(query, key, value, alibi, attention_mask, normal_factor, False) + return self.mha.exec(query, key, value, alibi, attention_mask, torch.tensor([]), False, normal_factor, False) HEAD_NUM = 32 SIZE_PER_HEAD = 80 diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py index 23b880a..fc5a3aa 100644 --- a/tests/script/test_mha_gpt.py +++ b/tests/script/test_mha_gpt.py @@ -108,8 +108,8 @@ class GPTNeoXAttentionExt: def __init__(self): self.mha = ld.mha_gpt() - def forward(self, query, key, value, attention_mask, normal_factor): - return self.mha.exec(query, key, value, torch.tensor([]), attention_mask, normal_factor, True) + def forward(self, query, key, value, attention_mask, normal_factor, causal_mask = torch.tensor([]), select_nfltmax_at_0 = False): + return self.mha.exec(query, key, value, torch.tensor([]), attention_mask, causal_mask, select_nfltmax_at_0, normal_factor, False if causal_mask.numel() > 0 else True) HEAD_NUM = 32 SIZE_PER_HEAD = 80 @@ -165,5 +165,58 @@ def test_gpt_neox(): print('done.') return +def test_gpt_neox_with_causal(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [2, 1, 1, key_seq_len] + # causal: [2, 1, query_seq_len, key_seq_len] + # (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + # np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + # np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + # np.zeros([2, 1, 1, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt() + causal_mask = torch.tril(torch.ones((MAX_POSITION_EMBEDDINGS, MAX_POSITION_EMBEDDINGS), dtype=torch.uint8)).view( + 1, 1, MAX_POSITION_EMBEDDINGS, MAX_POSITION_EMBEDDINGS + ) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + + batch_size, num_attention_heads, query_length, attn_head_size = q.size() + key_length = k.size(-2) + causal_mask_sub = causal_mask[:, :, key_length - query_length : key_length, :key_length].contiguous() + # 0 means -fltmax + select_nfltmax_at_0 = True + if i == 0: + # 1 means -fltmax + causal_mask_sub = 1 - causal_mask_sub + select_nfltmax_at_0 = False + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + ref_output = ref_net.forward(q, k, v, attn_mask) + output = net.forward(q, k, v, attn_mask, normal_factor = 1.0 / math.sqrt(SIZE_PER_HEAD), causal_mask = causal_mask_sub, select_nfltmax_at_0 = select_nfltmax_at_0) + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + if __name__ == "__main__": - test_gpt_neox() + test_gpt_neox_with_causal() diff --git a/tests/src/test_utility_kernel_avx512.cpp b/tests/src/test_utility_kernel_avx512.cpp index 615a5b0..c0806e1 100644 --- a/tests/src/test_utility_kernel_avx512.cpp +++ b/tests/src/test_utility_kernel_avx512.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include #include #include @@ -24,12 +25,16 @@ TEST(smoke_Utility, muladd) { float normal_factor = 1.2f; for (size_t len = 1; len < 129; len++) { std::vector x(len), x_out(len), bias(len), ref(len); + std::vector mask(len, 1); + mask[0] = 0; for (size_t i = 0; i < x.size(); i++) { x[i] = -10.0f + i; bias[i] = -100.0f + i; - ref[i] = x[i] * normal_factor + bias[i]; + ref[i] = x[i] * normal_factor + bias[i] + bias[i]; + if (mask[i] == 0) + ref[i] = -FLT_MAX; } - mul_add_f32_avx512(x_out.data(), x.data(), normal_factor, bias.data(), len); + mul_add2_select_f32_avx512(x_out.data(), x.data(), normal_factor, bias.data(), bias.data(), mask.data(), true, len); for (size_t i = 0; i < x.size(); i++) { ASSERT_TRUE(std::abs(x_out[i] - ref[i]) < 0.0001f) << " length: " << len << " pos: " << i << " cur: " << x[i] << " ref: " << ref[i]; } From 7beabd66bbad0f4f2bc77e24a7297885a80f5096 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Mon, 24 Jul 2023 09:48:44 +0300 Subject: [PATCH 34/54] fix coverity scan errors --- src/common/tensor2d.hpp | 2 +- src/mm_kernel_common_amx.hpp | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index ce83671..b406645 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -23,7 +23,7 @@ struct tensor2D { int64_t capacity = 0; int stride = 0; bool force_compact = false; - bool own; + bool own = false; int padded_dim1 = 0; tensor2D() = default; diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index 0bb22ce..2c6c412 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -459,6 +459,9 @@ namespace functional { re = _mm512_setzero_epi32(); rf = _mm512_setzero_epi32(); break; + default: + assert(false); + return; } transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); @@ -802,6 +805,9 @@ namespace functional { re = _mm512_setzero_epi32(); rf = _mm512_setzero_epi32(); break; + default: + assert(false); + return; } transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); _mm512_storeu_epi32(dst, r0); From 2c0ae5ba733851db29fedb2bbb2c1195765dba4e Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Tue, 25 Jul 2023 07:07:30 +0300 Subject: [PATCH 35/54] remove ov namespace --- src/common/simple_parallel.hpp | 22 ++++++++++------------ src/emb_gpt_avx512.cpp | 2 +- src/mha_gpt_amx.cpp | 4 ++-- tests/src/test_common.cpp | 10 ++++------ 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/common/simple_parallel.hpp b/src/common/simple_parallel.hpp index b0548f4..139ac8f 100644 --- a/src/common/simple_parallel.hpp +++ b/src/common/simple_parallel.hpp @@ -8,11 +8,10 @@ #include #include -namespace ov { -namespace cpu { +namespace utility { -size_t getTotalThreads(); -void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function& fn); +size_t get_total_threads(); +void simple_parallel_for(const std::ptrdiff_t total, const std::function& fn); // copy from openvino/core/parallel.hpp template @@ -101,13 +100,13 @@ void for_1d(const int& ithr, const int& nthr, const T0& D0, const F& func) { template void parallel_for(const T0& D0, const F& func) { auto work_amount = static_cast(D0); - int nthr = static_cast(getTotalThreads()); + int nthr = static_cast(get_total_threads()); if (static_cast(nthr) > work_amount) nthr = static_cast(work_amount); if (nthr == 1) { for_1d(0, 1, D0, func); } else { - TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + simple_parallel_for(static_cast(nthr), [&](size_t ithr) { for_1d(static_cast(ithr), nthr, D0, func); }); } @@ -133,13 +132,13 @@ void for_2d(const int& ithr, const int& nthr, const T0& D0, const T1& D1, const template void parallel_for2d(const T0& D0, const T1& D1, const F& func) { auto work_amount = static_cast(D0 * D1); - int nthr = static_cast(getTotalThreads()); + int nthr = static_cast(get_total_threads()); if (static_cast(nthr) > work_amount) nthr = static_cast(work_amount); if (nthr == 1) { for_2d(0, 1, D0, D1, func); } else { - TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + simple_parallel_for(static_cast(nthr), [&](size_t ithr) { for_2d(static_cast(ithr), nthr, D0, D1, func); }); } @@ -166,17 +165,16 @@ void for_3d(const int& ithr, const int& nthr, const T0& D0, const T1& D1, const template void parallel_for3d(const T0& D0, const T1& D1, const T2& D2, const F& func) { auto work_amount = static_cast(D0 * D1 * D2); - int nthr = static_cast(getTotalThreads()); + int nthr = static_cast(get_total_threads()); if (static_cast(nthr) > work_amount) nthr = static_cast(work_amount); if (nthr == 1) { for_3d(0, 1, D0, D1, D2, func); } else { - TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + simple_parallel_for(static_cast(nthr), [&](size_t ithr) { for_3d(static_cast(ithr), nthr, D0, D1, D2, func); }); } } -}; // namespace cpu -}; // namespace ov \ No newline at end of file +} // namespace utility \ No newline at end of file diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index e0f0f67..1487ff2 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -16,7 +16,7 @@ #include "emb_gpt_avx512.hpp" #include "rotary_kernel_avx512.hpp" -using namespace ov::cpu; +using namespace utility; namespace llmdnn { diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 3605712..eb29af5 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -17,7 +17,7 @@ #include "llm_mha_gpt.hpp" #include "mha_gpt_amx.hpp" -using namespace ov::cpu; +using namespace utility; namespace llmdnn { @@ -69,7 +69,7 @@ void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_s // matmul1: [batch, head_num, query_seq_len, head_size] // attn_output: [batch, query_seq_len, head_num * head_size] if (_num_threads == 0) { - _num_threads = getTotalThreads(); + _num_threads = get_total_threads(); _head_size_aligned = rndup(head_size, 32); gemAvB_BF16xBF16.resize(_num_threads); for (size_t i = 0; i < _num_threads; i++) { diff --git a/tests/src/test_common.cpp b/tests/src/test_common.cpp index a96a268..be669f8 100644 --- a/tests/src/test_common.cpp +++ b/tests/src/test_common.cpp @@ -62,19 +62,17 @@ bool initXTILE() { return true; } -namespace ov { -namespace cpu { +namespace utility { -size_t getTotalThreads() { +size_t get_total_threads() { return omp_get_max_threads(); } -void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function& fn) { +void simple_parallel_for(const size_t total, const std::function& fn) { #pragma omp parallel for - for(std::ptrdiff_t i = 0; i < total; i++) { + for(size_t i = 0; i < total; i++) { fn(i); } } -} } \ No newline at end of file From e8455b7143542db4feb7861787eb69fa2f7ace73 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Tue, 25 Jul 2023 14:45:20 +0300 Subject: [PATCH 36/54] remove c++ global vars --- src/CMakeLists.txt | 1 + src/fc_kernel_amx.cpp | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1a4d8b8..6dd7b49 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -32,6 +32,7 @@ install(TARGETS ${PROJECT_NAME} # config file install(EXPORT ${PROJECT_NAME}Targets + NAMESPACE ${PROJECT_NAME}:: FILE ${PROJECT_NAME}Config.cmake DESTINATION ${CMAKE_DST}) diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index a8bb698..75ec107 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -33,17 +33,17 @@ struct fc_kernel { using supported_key = std::tuple; using supported_value = std::pair; -static llm_map supported_postops = { - { { dnnl_s8, dnnl_s8, dnnl_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, - { { dnnl_s8, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, - { { dnnl_s8, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, - { { dnnl_bf16, dnnl_bf16, dnnl_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, - { { dnnl_bf16, dnnl_bf16, dnnl_f32 }, { 0, BIAS | GELU | GELU_TANH } }, - { { dnnl_bf16, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, - { { dnnl_bf16, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, -}; - static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b, data_type_t dt_c) { + llm_map supported_postops = { + { { dnnl_s8, dnnl_s8, dnnl_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_s8, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_s8, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_bf16, dnnl_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_bf16, dnnl_f32 }, { 0, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + }; + auto it = supported_postops.find(std::make_tuple(dt_a, dt_b, dt_c)); if (it == supported_postops.end()) { return false; From 31faf591f5e43b5cf561f9fbbdc4782514c4c170 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Wed, 26 Jul 2023 20:14:01 +0800 Subject: [PATCH 37/54] fix simple_parallel_for type --- src/common/simple_parallel.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/common/simple_parallel.hpp b/src/common/simple_parallel.hpp index 139ac8f..ef5534e 100644 --- a/src/common/simple_parallel.hpp +++ b/src/common/simple_parallel.hpp @@ -4,6 +4,7 @@ #pragma once +#include #include #include #include @@ -11,7 +12,7 @@ namespace utility { size_t get_total_threads(); -void simple_parallel_for(const std::ptrdiff_t total, const std::function& fn); +void simple_parallel_for(const size_t total, const std::function& fn); // copy from openvino/core/parallel.hpp template From 2326d60a0ba53a24862d9be70d64574ed53a9efe Mon Sep 17 00:00:00 2001 From: "Li, Tingqian" Date: Thu, 27 Jul 2023 23:41:44 +0800 Subject: [PATCH 38/54] optimize mha_gpt_impl_amx::create --- src/mha_gpt_amx.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index eb29af5..72a02e1 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -39,6 +39,8 @@ struct mha_gpt_impl_amx : public mha_gpt::impl { uint8_t* _buffer_mat0_out = nullptr; uint8_t* _buffer_mat1_out = nullptr; + bool _is_bloom; + llm_vector*> gemAvB_BF16xBF16; llm_vector*> qKtrGemm_BF16xBF16; llm_vector*> qKVGemm_BF16xBF16; @@ -70,7 +72,6 @@ void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_s // attn_output: [batch, query_seq_len, head_num * head_size] if (_num_threads == 0) { _num_threads = get_total_threads(); - _head_size_aligned = rndup(head_size, 32); gemAvB_BF16xBF16.resize(_num_threads); for (size_t i = 0; i < _num_threads; i++) { gemAvB_BF16xBF16[i] = new amx_kernel::MatmulVector(); @@ -79,25 +80,32 @@ void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_s for (size_t i = 0; i < _num_threads; i++) { qKtrGemm_BF16xBF16[i] = new amx_kernel::Matmul(false, !is_bloom); } + _is_bloom = is_bloom; qKVGemm_BF16xBF16.resize(_num_threads); for (size_t i = 0; i < _num_threads; i++) { qKVGemm_BF16xBF16[i] = new amx_kernel::Matmul(false, false); } } + // correct transposeB + if (_is_bloom != is_bloom) { + for (auto& mm : qKtrGemm_BF16xBF16) { + mm->transposeB = !is_bloom; + } + _is_bloom = is_bloom; + } + auto buffer_mat0_out_size = seq_len * rndup(seq_len * sizeof(float), 64); if (buffer_mat0_out_size > _buffer_mat0_out_size) { + _head_size_aligned = rndup(head_size, 32); _buffer_mat0_out_size = seq_len * rndup(seq_len * sizeof(float), 64) * 3 / 2; _buffer_mat1_out_size = seq_len * _head_size_aligned * sizeof(float) * 3 / 2; if (_buffer_mat0_out) free(_buffer_mat0_out); if (_buffer_mat1_out) free(_buffer_mat1_out); - _buffer_mat0_out = reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat0_out_size)); - memset(_buffer_mat0_out, 0, _num_threads * _buffer_mat0_out_size); _buffer_mat1_out = reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat1_out_size)); - memset(_buffer_mat1_out, 0, _num_threads * _buffer_mat1_out_size); } } From 04b84838af32e969c2cdecdaede6fef4b9c09c8c Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 28 Jul 2023 01:10:24 +0800 Subject: [PATCH 39/54] add security.md and fix warnings. --- SECURITY.md | 12 ++++++++++++ src/common/tensor.cpp | 6 +++--- src/emb_gpt_avx512.cpp | 1 - src/mha_gpt_amx.cpp | 5 ++--- 4 files changed, 17 insertions(+), 7 deletions(-) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..eb482d9 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,12 @@ +# Security Policy + +## Report a Vulnerability + +Please report security issues or vulnerabilities to the [Intel® Security Center]. + +For more information on how Intel® works to resolve security issues, see +[Vulnerability Handling Guidelines]. + +[Intel® Security Center]:https://www.intel.com/security + +[Vulnerability Handling Guidelines]:https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html diff --git a/src/common/tensor.cpp b/src/common/tensor.cpp index 2850e58..5ea01b1 100644 --- a/src/common/tensor.cpp +++ b/src/common/tensor.cpp @@ -45,13 +45,13 @@ tensor tensor::index(const std::initializer_list& indices) const { } sub_tensor.m_rank = i_dst; // index may imply squeeze sub_tensor.m_ptr = reinterpret_cast(m_ptr) + off; - return std::move(sub_tensor); + return sub_tensor; } // slice: return a sub-view (w/o ownership/refcount to original data) tensor tensor::slice(int axis, int start, int end) const { tensor sub_tensor; - assert(axis < m_rank); + assert(static_cast(axis) < m_rank); sub_tensor.m_capacity = 0; sub_tensor.m_rank = m_rank; // slice dosen't change rank & strides @@ -65,7 +65,7 @@ tensor tensor::slice(int axis, int start, int end) const { auto* data = reinterpret_cast(m_ptr) + off; sub_tensor.m_ptr = reinterpret_cast(data); - return std::move(sub_tensor); + return sub_tensor; } bool tensor::is_dense() const { diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index 1487ff2..00c7be8 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -24,7 +24,6 @@ static void memcpy_past_kv(const tensor& k_past, const tensor& v_past, const ten auto batch = k_past.m_dims[0]; auto head_num = k_past.m_dims[1]; auto past_seq_len = k_past.m_dims[2]; - auto size = k_past.m_dims[3]; parallel_for3d(batch, head_num, past_seq_len, [&](size_t b, size_t h, size_t s) { memcpy(&k_dst.at({b, h, s}), &k_past.at({b, h, s}), k_past.m_strides[2]); memcpy(&v_dst.at({b, h, s}), &v_past.at({b, h, s}), v_past.m_strides[2]); diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index 72a02e1..e4ab763 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -170,7 +170,7 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& parallel_it_init(start, i0, batch, i1, head_num, seq, seq_cout_all); ov::bfloat16* prev_k = nullptr; ov::bfloat16* prev_v = nullptr; - for (int iwork = start; iwork < end; ++iwork) { + for (size_t iwork = start; iwork < end; ++iwork) { auto seq_start = seq * 32; auto seq_end = std::min(seq_start + 32, query_seq_len); auto seq_cout = seq_end - seq_start; @@ -201,7 +201,7 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& // attn: [batch, 1, 1, key_seq_len] or [batch, 1, query_seq_len, key_seq_len] // alibi: [batch, num_heads, 1, key_seq_len] // causal: [batch/1, 1, query_seq_len, key_seq_len] - for (int m = 0; m < seq_cout; m++) { + for (uint32_t m = 0; m < seq_cout; m++) { auto attn_sub = attn_mask ? &attn_mask.at({i0, 0, attn_mask.m_dims[2] == 1 ? 0 : m + seq_start}) : nullptr; auto alibi_sub = alibi ? &alibi.at({i0, i1}) : nullptr; auto causal_mask_sub = causal_mask ? &causal_mask.at({causal_mask.m_dims[0] == 1 ? 0 : i0, 0, m + seq_start}) : nullptr; @@ -276,7 +276,6 @@ void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, c } auto batch = q.m_dims[0]; auto head_num = q.m_dims[1]; - auto query_seq_len = q.m_dims[2]; auto head_size = q.m_dims[3]; auto key_seq_len = k.m_dims[2]; From 77ac0bcde69a1fa9aefd7b4edc536d173e18a4cb Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Tue, 1 Aug 2023 01:25:31 +0800 Subject: [PATCH 40/54] apply review comments --- CMakeLists.txt | 1 + include/llm_emb_gpt.hpp | 26 +++--- include/llm_fc.hpp | 4 +- include/llm_mha_gpt.hpp | 33 +++---- include/llm_mm.hpp | 6 +- include/llm_tensor.hpp | 14 +-- include/llm_types.hpp | 28 +++--- src/CMakeLists.txt | 3 + src/common/log.hpp | 11 +++ src/common/simple_parallel.hpp | 4 +- src/common/tensor.cpp | 14 +-- src/common/tensor2d.hpp | 3 +- src/common/tensor2d_helper.hpp | 19 ++-- src/common/utility.hpp | 36 ++++---- src/emb_gpt_api.cpp | 26 +++--- src/emb_gpt_avx512.cpp | 44 ++++----- src/emb_gpt_avx512.hpp | 24 ++--- src/fc_kernel_amx.cpp | 51 +++++----- src/fc_kernel_amx.hpp | 4 +- src/fc_kernel_api.cpp | 4 +- src/gelu_kernel_avx512.hpp | 2 +- src/mha_gpt_amx.cpp | 84 ++++++++--------- src/mha_gpt_amx.hpp | 2 +- src/mha_gpt_api.cpp | 6 +- src/mm_kernel_amx.cpp | 33 ++++--- src/mm_kernel_amx.hpp | 7 +- src/mm_kernel_api.cpp | 9 +- src/mm_kernel_common_amx.hpp | 2 +- src/rotary_kernel_avx2.hpp | 2 +- src/rotary_kernel_avx512.hpp | 2 +- src/softmax_kernel_avx512.hpp | 2 +- src/transpose_kernel_avx512.hpp | 2 +- src/utility_kernel_amx.hpp | 14 --- src/utility_kernel_avx2.hpp | 16 +--- src/utility_kernel_avx512.hpp | 2 +- tests/src/test_common.cpp | 18 ++-- tests/src/test_fc_kernel_amx.cpp | 92 +++++++++---------- tests/src/test_mm_kernel_amx.cpp | 15 +-- tests/src/test_rotary_kernel_avx2.cpp | 4 +- tests/src/test_rotary_kernel_avx512.cpp | 4 +- tests/src/test_softmax_kernel_avx512.cpp | 8 +- tests/src/test_transpose_kernel_avx512.cpp | 8 +- .../test_utility_kernel_repack1x2_avx512.cpp | 4 +- 43 files changed, 349 insertions(+), 344 deletions(-) create mode 100644 src/common/log.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index d499054..9e69cd8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,7 @@ project(root) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) option(CPU_EXTENSIONS_BUILD_TESTS "Build with tests" ON) +option(CPU_EXTENSIONS_ENABLE_LOG "Enable log" OFF) message(INFO "--------------------------------") message(STATUS "Build with tests: ${CPU_EXTENSIONS_BUILD_TESTS}") diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp index c642a2a..3074cb1 100644 --- a/include/llm_emb_gpt.hpp +++ b/include/llm_emb_gpt.hpp @@ -12,17 +12,17 @@ namespace llmdnn { -void emb_gpt(const tensor& q_src, // q shape: [batch, query_seq_len, head_num, head_size] - const tensor& k_src, // k shape: [batch, query_seq_len, head_num, head_size] - const tensor& v_src, // v shape: [batch, query_seq_len, head_num, head_size] - const tensor& k_past, // k_past shape: [batch, num_heads, past_seq_len, head_size] - const tensor& v_past, // v_past shape: [batch, num_heads, past_seq_len, head_size] - const tensor& q_dst, // q_dst, shape: [batch, num_heads, query_seq_len, head_size] - const tensor& k_dst, // k_past shape: [batch, num_heads, query_seq_len+past_seq_len, head_size] - // if k_past!=k_past_dst, will copy k_past to k_past_dst - const tensor& v_dst, // v_past shape: [batch, num_heads, query_seq_len+past_seq_len, head_size] - const tensor& cos, // cos lookup table, shape: [1, 1, max_seq_len, rotary_dims] - const tensor& sin, // sin lookup table, shape: [1, 1, max_seq_len, rotary_dims] - const tensor& position2d_ids); // shape: [batch, 2, query_seq_len] +status_t emb_gpt(const tensor& q_src, // q shape: [batch, query_seq_len, head_num, head_size] + const tensor& k_src, // k shape: [batch, query_seq_len, head_num, head_size] + const tensor& v_src, // v shape: [batch, query_seq_len, head_num, head_size] + const tensor& k_past, // k_past shape: [batch, num_heads, past_seq_len, head_size] + const tensor& v_past, // v_past shape: [batch, num_heads, past_seq_len, head_size] + const tensor& q_dst, // q_dst, shape: [batch, num_heads, query_seq_len, head_size] + const tensor& k_dst, // k_past shape: [batch, num_heads, query_seq_len+past_seq_len, head_size] + // if k_past!=k_past_dst, will copy k_past to k_past_dst + const tensor& v_dst, // v_past shape: [batch, num_heads, query_seq_len+past_seq_len, head_size] + const tensor& cos, // cos lookup table, shape: [1, 1, max_seq_len, rotary_dims] + const tensor& sin, // sin lookup table, shape: [1, 1, max_seq_len, rotary_dims] + const tensor& position2d_ids); // shape: [batch, 2, query_seq_len] -} +} // namespace llmdnn diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp index 1272c1b..afa4c31 100644 --- a/include/llm_fc.hpp +++ b/include/llm_fc.hpp @@ -60,7 +60,7 @@ struct fc_kernel; /// fc: (bf16,s8,f32),dq,[bias],[gelu] /// fc: (bf16,s8,bf16),dq,[bias],[gelu] /// -bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param); +status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param); void fc_kernel_destroy(const fc_kernel* mm); void fc_kernel_execute(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, @@ -71,4 +71,4 @@ void fc_kernel_execute(const fc_kernel* mm, /// compute weight min/max once, set q, dq for each fc_kernel instance void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); -} +} // namespace llmdnn diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp index 3df2db6..68942bb 100644 --- a/include/llm_mha_gpt.hpp +++ b/include/llm_mha_gpt.hpp @@ -17,24 +17,25 @@ class mha_gpt { mha_gpt(); ~mha_gpt(); - void exec(const tensor& q, // q shape: [batch, num_heads, query_seq_len, head_size] - const tensor& k, // k shape: [batch, num_heads, key_seq_len, head_size] - const tensor& v, // v shape: [batch, num_heads, value_seq_len, head_size] - const tensor& output, // output, compact, shape: [batch, query_seq_len, num_heads * head_size] - const tensor& attn_mask, // attention mask[opt], shape: - // [batch, 1, 1, key_seq_len], - // [batch, 1, query_seq_len, key_seq_len] - const tensor& alibi, // alibi[opt] shape: [batch, num_heads, 1, key_seq_len] - const tensor& causal_mask, // [opt] use_causal_mask must be false, u8, shape: - // [1, 1, query_seq_len, key_seq_len] - // [batch, 1, query_seq_len, key_seq_len] - bool select_nfltmax_at_0, // used when causal_mask is not null. true: causal_mask=0 use -FLT_MAX - float normal_factor, - bool use_causal_mask = false);// add causal mask + status_t exec(const tensor& q, // q shape: [batch, num_heads, query_seq_len, head_size] + const tensor& k, // k shape: [batch, num_heads, key_seq_len, head_size] + const tensor& v, // v shape: [batch, num_heads, value_seq_len, head_size] + const tensor& output, // output, compact, shape: [batch, query_seq_len, num_heads * head_size] + const tensor& attn_mask, // attention mask[opt], shape: + // [batch, 1, 1, key_seq_len], + // [batch, 1, query_seq_len, key_seq_len] + const tensor& alibi, // alibi[opt] shape: [batch, num_heads, 1, key_seq_len] + const tensor& causal_mask, // [opt] use_causal_mask must be false, u8, shape: + // [1, 1, query_seq_len, key_seq_len] + // [batch, 1, query_seq_len, key_seq_len] + bool select_nfltmax_at_0, // used when causal_mask is not null. true means causal_mask[i]==0 use -FLT_MAX + // false means causal_mask[i]==1 use -FLT_MAX + float normal_factor, + bool use_causal_mask = false);// add causal mask struct impl { virtual ~impl() {} - virtual void exec(const tensor& q, + virtual status_t exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, @@ -49,4 +50,4 @@ class mha_gpt { impl* _impl; }; -} +} // namespace llmdnn diff --git a/include/llm_mm.hpp b/include/llm_mm.hpp index bb187f7..52dc420 100644 --- a/include/llm_mm.hpp +++ b/include/llm_mm.hpp @@ -26,10 +26,10 @@ struct mm_kernel; /// matmul: (bf16,bf16,f32) /// gemv: (bf16,bf16,f32) /// -bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param); +status_t mm_kernel_create(mm_kernel** mm, const mm_create_param* param); void mm_kernel_destroy(const mm_kernel* mm); -void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +status_t mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, size_t M, size_t N, size_t K); -} +} // namespace llmdnn diff --git a/include/llm_tensor.hpp b/include/llm_tensor.hpp index 6ca311e..615d32f 100644 --- a/include/llm_tensor.hpp +++ b/include/llm_tensor.hpp @@ -23,32 +23,32 @@ namespace llmdnn { template struct precision_of { - static constexpr data_type_t value = dnnl_data_type_undef; + static constexpr data_type_t value = llmdnn_data_type_undef; }; template <> struct precision_of { - static constexpr data_type_t value = dnnl_f32; + static constexpr data_type_t value = llmdnn_f32; }; template <> struct precision_of { - static constexpr data_type_t value = dnnl_s32; + static constexpr data_type_t value = llmdnn_s32; }; template <> struct precision_of { - static constexpr data_type_t value = dnnl_bf16; + static constexpr data_type_t value = llmdnn_bf16; }; template <> struct precision_of { - static constexpr data_type_t value = dnnl_u8; + static constexpr data_type_t value = llmdnn_u8; }; template <> struct precision_of { - static constexpr data_type_t value = dnnl_s8; + static constexpr data_type_t value = llmdnn_s8; }; @@ -61,7 +61,7 @@ struct tensor { void* m_ptr = nullptr; size_t m_capacity = 0; // 0 means not own m_ptr size_t m_element_size = 0; - data_type_t m_dtype = dnnl_data_type_undef; + data_type_t m_dtype = llmdnn_data_type_undef; tensor(); ~tensor(); diff --git a/include/llm_types.hpp b/include/llm_types.hpp index 3ace6d6..2a48041 100644 --- a/include/llm_types.hpp +++ b/include/llm_types.hpp @@ -8,29 +8,35 @@ namespace llmdnn { -// from oneDNN /// Data type specification typedef enum { /// Undefined data type, used for empty memory descriptors. - dnnl_data_type_undef = 0, + llmdnn_data_type_undef = 0, /// 16-bit/half-precision floating point. - dnnl_f16 = 1, + llmdnn_f16 = 1, /// non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point. - dnnl_bf16 = 2, + llmdnn_bf16 = 2, /// 32-bit/single-precision floating point. - dnnl_f32 = 3, + llmdnn_f32 = 3, /// 32-bit signed integer. - dnnl_s32 = 4, + llmdnn_s32 = 4, /// 8-bit signed integer. - dnnl_s8 = 5, + llmdnn_s8 = 5, /// 8-bit unsigned integer. - dnnl_u8 = 6, + llmdnn_u8 = 6, /// 64-bit/double-precision floating point. - dnnl_f64 = 7, + llmdnn_f64 = 7, /// Parameter to allow internal only data_types without undefined behavior. /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. - dnnl_data_type_max = 0x7fff, + llmdnn_data_type_max = 0x7fff, } data_type_t; -} \ No newline at end of file +typedef enum { + status_ok, + status_invalid_arguments, + status_unimplemented, + status_fail = 10 +} status_t; + +} // namespace llmdnn diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6dd7b49..deeabe4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -15,6 +15,9 @@ target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} $/${CMAKE_INSTALL_INCLUDEDIR}>) target_compile_options(${PROJECT_NAME} PRIVATE ${EXTRA_CXX_FLAGS}) target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17) +if(CPU_EXTENSIONS_ENABLE_LOG) + target_compile_definitions(${PROJECT_NAME} PRIVATE ENABLE_LOG) +endif() set(CMAKE_DST lib/cmake/${PROJECT_NAME}) # header files diff --git a/src/common/log.hpp b/src/common/log.hpp new file mode 100644 index 0000000..03d969e --- /dev/null +++ b/src/common/log.hpp @@ -0,0 +1,11 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#ifdef ENABLE_LOG + #define DEBUG_LOG std::cout +#else + #define DEBUG_LOG if (0) std::cout +#endif diff --git a/src/common/simple_parallel.hpp b/src/common/simple_parallel.hpp index ef5534e..1f7de95 100644 --- a/src/common/simple_parallel.hpp +++ b/src/common/simple_parallel.hpp @@ -9,7 +9,7 @@ #include #include -namespace utility { +namespace llmdnn { size_t get_total_threads(); void simple_parallel_for(const size_t total, const std::function& fn); @@ -178,4 +178,4 @@ void parallel_for3d(const T0& D0, const T1& D1, const T2& D2, const F& func) { } } -} // namespace utility \ No newline at end of file +} // namespace llmdnn \ No newline at end of file diff --git a/src/common/tensor.cpp b/src/common/tensor.cpp index 5ea01b1..95c37fe 100644 --- a/src/common/tensor.cpp +++ b/src/common/tensor.cpp @@ -8,6 +8,7 @@ #include #include +#include "common/log.hpp" #include "bf16.hpp" #include "llm_tensor.hpp" @@ -99,7 +100,6 @@ tensor tensor::reshape(const std::initializer_list& target_shape) const // only valid for dense memory tensor new_tensor_view; assert(is_dense()); - //assert(shape_size(target_shape) == shape_size(m_dims)); new_tensor_view.resize(std::vector(target_shape), m_ptr, m_element_size, m_dtype); return new_tensor_view; } @@ -149,16 +149,16 @@ void tensor::resize(const size_t* new_dims, size_t dim_num, void* data, size_t e void tensor::assert_dims(const std::initializer_list& expect_dims) const { if (m_rank != expect_dims.size()) { - std::cout << "dims not same\n"; + DEBUG_LOG << "dims not same\n"; } if (!std::equal(expect_dims.begin(), expect_dims.end(), m_dims)) { - std::cout << " m_dims=["; + DEBUG_LOG << " m_dims=["; for (size_t i = 0; i < m_rank; i++) - std::cout << m_dims[i] << ","; - std::cout << "] expect_dims=["; + DEBUG_LOG << m_dims[i] << ","; + DEBUG_LOG << "] expect_dims=["; for (auto& i : expect_dims) - std::cout << i << ","; - std::cout << "]"; + DEBUG_LOG << i << ","; + DEBUG_LOG << "]"; } } diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index b406645..ec7c975 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -12,6 +12,7 @@ #ifdef ENABLE_NUMA #include "numa.h" #endif +#include "log.hpp" #include "bf16.hpp" #define rndup(x, n) (((x + n - 1)/n)*n) @@ -122,7 +123,7 @@ struct tensor2D { if (is_const) memset(static_cast(data), 0, need_capacity); if (reinterpret_cast(data) % 64) - std::cout << "WARNING: resize(), data is not cache-line aligned!" << std::endl; + DEBUG_LOG << "WARNING: resize(), data is not cache-line aligned!" << std::endl; } // put a NaN at the end to test over-read // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format diff --git a/src/common/tensor2d_helper.hpp b/src/common/tensor2d_helper.hpp index 417b2a6..5221280 100644 --- a/src/common/tensor2d_helper.hpp +++ b/src/common/tensor2d_helper.hpp @@ -73,7 +73,7 @@ bool operator==(const tensor2D& lhs, const tensor2D& rhs) { float f0 = lhs(i0,i1); float f1 = rhs(i0,i1); if (isnan2(f1) || isnan2(f0)) { - std::cout << " nan is found: f0=" << f0 << ", f1=" << f1 << std::endl; + DEBUG_LOG << " nan is found: f0=" << f0 << ", f1=" << f1 << std::endl; return false; } if (std::abs(f0 - f1) <= 0.01) @@ -82,7 +82,7 @@ bool operator==(const tensor2D& lhs, const tensor2D& rhs) { if (lhs(i0,i1) == rhs(i0,i1)) continue; - std::cout << " operator== failed at (" << i0 << ", " << i1 << ") value " + DEBUG_LOG << " operator== failed at (" << i0 << ", " << i1 << ") value " << lhs(i0,i1) << "!=" << rhs(i0,i1) << std::endl; return false; } @@ -95,11 +95,11 @@ bool is_normal(const tensor2D& t) { for (int i1 = 0; i1 < t.dims[1]; i1++) { float f0 = t(i0,i1); if (isnan2(f0)) { - std::cout << " found nan at (" << i0 << "," << i1 << ")" << std::endl; + DEBUG_LOG << " found nan at (" << i0 << "," << i1 << ")" << std::endl; return false; } if (isinf2(f0)) { - std::cout << " found inf at (" << i0 << "," << i1 << ")" << std::endl; + DEBUG_LOG << " found inf at (" << i0 << "," << i1 << ")" << std::endl; return false; } } @@ -122,7 +122,7 @@ bool compare(const tensor2D& lhs, const tensor2D& rhs, float tolerance) { if (std::fabs(lhs(i0,i1) > 0) && diff > 0) max_rel_diff = std::max(max_rel_diff, rel_diff); } - std::cout << "max_abs_diff=" << max_abs_diff << " max_rel_diff=" << max_rel_diff << "\n"; + DEBUG_LOG << "max_abs_diff=" << max_abs_diff << " max_rel_diff=" << max_rel_diff << "\n"; return tolerance > max_abs_diff; } @@ -150,14 +150,13 @@ std::ostream& operator<<(std::ostream& out, const tensor2D& obj) { template inline void show(const T * data, int rows, int cols) { - std::ostream& out = std::cout; - out << "==============\n"; + DEBUG_LOG << "==============\n"; for(int i0=0; i0 < rows; i0++) { - out << "[" << i0 << "," << 0 << "]: "; + DEBUG_LOG << "[" << i0 << "," << 0 << "]: "; for(int i1=0; i1 name2type[] = { - { "f16", dnnl_f16 }, - { "bf16", dnnl_bf16 }, - { "f32", dnnl_f32 }, - { "s32", dnnl_s32 }, - { "i32", dnnl_s32 }, - { "s8", dnnl_s8 }, - { "i8", dnnl_s8 }, - { "u8", dnnl_u8 }, - { "f64", dnnl_f64 }, + { "f16", llmdnn_f16 }, + { "bf16", llmdnn_bf16 }, + { "f32", llmdnn_f32 }, + { "s32", llmdnn_s32 }, + { "i32", llmdnn_s32 }, + { "s8", llmdnn_s8 }, + { "i8", llmdnn_s8 }, + { "u8", llmdnn_u8 }, + { "f64", llmdnn_f64 }, }; for (size_t i = 0; i < sizeof(name2type) / sizeof(name2type[0]); i++) { if (name == name2type[i].first) return name2type[i].second; } - return dnnl_data_type_undef; + return llmdnn_data_type_undef; } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/emb_gpt_api.cpp b/src/emb_gpt_api.cpp index 8444c7d..792ff8c 100644 --- a/src/emb_gpt_api.cpp +++ b/src/emb_gpt_api.cpp @@ -9,18 +9,18 @@ namespace llmdnn { -void emb_gpt(const tensor& q_src, - const tensor& k_src, - const tensor& v_src, - const tensor& k_past, - const tensor& v_past, - const tensor& q_dst, - const tensor& k_dst, - const tensor& v_dst, - const tensor& cos, - const tensor& sin, - const tensor& position2d_ids) { - emb_gpt_avx512(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); +status_t emb_gpt(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids) { + return emb_gpt_avx512(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index 00c7be8..2cd527d 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -7,17 +7,17 @@ #include #include +#include "common/log.hpp" #include "common/bf16.hpp" #include "common/simple_parallel.hpp" #include "common/utility.hpp" +#include "llm_types.hpp" #include "utility_kernel_avx512.hpp" #include "transpose_kernel_avx512.hpp" #include "llm_emb_gpt.hpp" #include "emb_gpt_avx512.hpp" #include "rotary_kernel_avx512.hpp" -using namespace utility; - namespace llmdnn { static void memcpy_past_kv(const tensor& k_past, const tensor& v_past, const tensor& k_dst, const tensor& v_dst) { @@ -86,30 +86,30 @@ static void rotary_emb_position2d(const tensor& q_src, }); } -void emb_gpt_avx512(const tensor& q_src, - const tensor& k_src, - const tensor& v_src, - const tensor& k_past, - const tensor& v_past, - const tensor& q_dst, - const tensor& k_dst, - const tensor& v_dst, - const tensor& cos, - const tensor& sin, - const tensor& position2d_ids) { +status_t emb_gpt_avx512(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids) { if (q_src.m_rank != 4 || k_src.m_rank != 4 || v_src.m_rank != 4 || k_past.m_rank != 4 || v_past.m_rank != 4 || q_dst.m_rank != 4|| k_dst.m_rank != 4 || v_dst.m_rank != 4 || cos.m_rank != 4 || sin.m_rank != 4) { - std::cout << "emb_gpt_avx512: rank is not correct: should be 4\n"; - return; + DEBUG_LOG << "emb_gpt_avx512: rank is not correct: should be 4\n"; + return status_t::status_invalid_arguments; } if (position2d_ids) { if (position2d_ids.m_rank != 3) { - std::cout << "emb_gpt_avx512: position2d_ids rank should be 3\n"; - return; + DEBUG_LOG << "emb_gpt_avx512: position2d_ids rank should be 3\n"; + return status_t::status_invalid_arguments; } if (position2d_ids.m_dims[0] != q_src.m_dims[0] || position2d_ids.m_dims[1] != 2 || position2d_ids.m_dims[2] != q_src.m_dims[1]) { - std::cout << "emb_gpt_avx512: position2d_ids dims should be [batch, 2, seq_len]\n"; - return; + DEBUG_LOG << "emb_gpt_avx512: position2d_ids dims should be [batch, 2, seq_len]\n"; + return status_t::status_invalid_arguments; } } @@ -125,7 +125,7 @@ void emb_gpt_avx512(const tensor& q_src, // transpose: [batch, seq_len, head_hum, 3 * head_size] --> // 3 [batch, head_hum, seq_len, head_size] // rotary embbeding: part of key will write to past_key, part of query will write to tempory buffer - if (q_src.m_dtype == dnnl_s8) { + if (q_src.m_dtype == llmdnn_s8) { assert(false); } else { // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) @@ -133,6 +133,8 @@ void emb_gpt_avx512(const tensor& q_src, // value(pastKeys): value = torch.cat((past_value, value), dim=-2) rotary_emb_position2d(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); } + + return status_t::status_ok; } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/emb_gpt_avx512.hpp b/src/emb_gpt_avx512.hpp index 0ef4d43..c5bd134 100644 --- a/src/emb_gpt_avx512.hpp +++ b/src/emb_gpt_avx512.hpp @@ -12,15 +12,15 @@ namespace llmdnn { -void emb_gpt_avx512(const tensor& q_src, - const tensor& k_src, - const tensor& v_src, - const tensor& k_past, - const tensor& v_past, - const tensor& q_dst, - const tensor& k_dst, - const tensor& v_dst, - const tensor& cos, - const tensor& sin, - const tensor& position2d_ids); -} +status_t emb_gpt_avx512(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids); +} // namespace llmdnn diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index 75ec107..85c9064 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -10,6 +10,7 @@ #include #include "llm_fc.hpp" +#include "llm_types.hpp" #include "mm_kernel_common_amx.hpp" #include "utility_kernel_avx512.hpp" #include "fc_kernel_amx.hpp" @@ -35,13 +36,13 @@ using supported_key = std::tuple; using supported_value = std::pair; static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b, data_type_t dt_c) { llm_map supported_postops = { - { { dnnl_s8, dnnl_s8, dnnl_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, - { { dnnl_s8, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, - { { dnnl_s8, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, - { { dnnl_bf16, dnnl_bf16, dnnl_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, - { { dnnl_bf16, dnnl_bf16, dnnl_f32 }, { 0, BIAS | GELU | GELU_TANH } }, - { { dnnl_bf16, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, - { { dnnl_bf16, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_s8, llmdnn_s8, llmdnn_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_s8, llmdnn_s8, llmdnn_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_s8, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_bf16, llmdnn_f32 }, { 0, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_s8, llmdnn_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, }; auto it = supported_postops.find(std::make_tuple(dt_a, dt_b, dt_c)); @@ -64,32 +65,32 @@ static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b } // interface -bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { +status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { fc_kernel* m = nullptr; if (param == nullptr || mm == nullptr) { - std::cout << "fc_kernel_create: invalid input parameter.\n"; + DEBUG_LOG << "fc_kernel_create: invalid input parameter.\n"; goto ERR; } if (!check_valid_postops(static_cast(param->postops_type), param->dt_a, param->dt_b, param->dt_c)) { - std::cout << "fc_kernel_create: unsupported data type, a: " << param->dt_a <<", b: " << param->dt_b << ", c: " << param->dt_c << + DEBUG_LOG << "fc_kernel_create: unsupported data type, a: " << param->dt_a <<", b: " << param->dt_b << ", c: " << param->dt_c << ", postops type: " << param->postops_type << ".\n"; goto ERR; } m = new fc_kernel; - if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + if (param->dt_a == llmdnn_s8 && param->dt_b == llmdnn_s8) { m->i8xi8 = std::make_unique>(true, param->b_is_trans); - } else if (param->dt_a == dnnl_u8 && param->dt_b == dnnl_s8) { + } else if (param->dt_a == llmdnn_u8 && param->dt_b == llmdnn_s8) { m->u8xi8 = std::make_unique>(true, param->b_is_trans); - } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_bf16) { m->bf16xbf16 = std::make_unique>(true, param->b_is_trans); - } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_s8) { + } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_s8) { m->bf16xi8 = std::make_unique>(true, param->b_is_trans); m->bf16xi8->quant_scale_B = param->q; m->bf16xi8->dequant_scale_B = param->dq; } else { - std::cout << "fc_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + DEBUG_LOG << "fc_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; goto ERR; } @@ -100,10 +101,10 @@ bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { m->postops_type = param->postops_type; *mm = m; - return true; + return status_t::status_ok; ERR: delete m; - return false; + return status_t::status_invalid_arguments; } void fc_kernel_destroy_amx(const fc_kernel* mm) { @@ -123,7 +124,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* tensor2D a(M, K, reinterpret_cast(ptr_a), lda); tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); - if (mm->dt_c == dnnl_s8) { + if (mm->dt_c == llmdnn_s8) { tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { @@ -160,7 +161,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); } } - } else if (mm->dt_c == dnnl_bf16) { + } else if (mm->dt_c == llmdnn_bf16) { tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); if (!bias) { if (mm->postops_type & GELU) { @@ -191,7 +192,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); } } - } else if (mm->dt_c == dnnl_f32) { + } else if (mm->dt_c == llmdnn_f32) { tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); if (!bias) { if (mm->postops_type & GELU) { @@ -233,7 +234,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* tensor2D a(M, K, reinterpret_cast(ptr_a), lda); tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); - if (mm->dt_c == dnnl_bf16) { + if (mm->dt_c == llmdnn_bf16) { tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { @@ -258,7 +259,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); } } - } else if (mm->dt_c == dnnl_f32) { + } else if (mm->dt_c == llmdnn_f32) { tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { @@ -288,7 +289,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* tensor2D a(M, K, reinterpret_cast(ptr_a), lda); tensor2D b(N, K, reinterpret_cast(ptr_b), ldb); - if (mm->dt_c == dnnl_bf16) { + if (mm->dt_c == llmdnn_bf16) { tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { @@ -313,7 +314,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); } } - } else if (mm->dt_c == dnnl_f32) { + } else if (mm->dt_c == llmdnn_f32) { tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { @@ -351,4 +352,4 @@ void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, *dq = max / 127; } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/fc_kernel_amx.hpp b/src/fc_kernel_amx.hpp index 93bf47f..224c574 100644 --- a/src/fc_kernel_amx.hpp +++ b/src/fc_kernel_amx.hpp @@ -6,7 +6,7 @@ namespace llmdnn { -bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param); +status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param); void fc_kernel_destroy_amx(const fc_kernel* mm); @@ -15,4 +15,4 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp index d7b3e9c..b5174a0 100644 --- a/src/fc_kernel_api.cpp +++ b/src/fc_kernel_api.cpp @@ -27,7 +27,7 @@ static decltype(&fc_kernel_execute) fc_kernel_execute_ptr = fc_kernel_execute_am static decltype(&fc_kernel_bf16w8_get_q_dq) fc_kernel_bf16w8_get_q_dq_ptr = fc_kernel_bf16w8_get_q_dq_amx; // interface -bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { +status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { return fc_kernel_create_ptr(mm, param); } @@ -44,4 +44,4 @@ void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, flo fc_kernel_bf16w8_get_q_dq_ptr(K, N, stride, ptr, q, dq); } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/gelu_kernel_avx512.hpp b/src/gelu_kernel_avx512.hpp index c416ad5..de635b1 100644 --- a/src/gelu_kernel_avx512.hpp +++ b/src/gelu_kernel_avx512.hpp @@ -383,4 +383,4 @@ namespace llmdnn { dst = _mm512_mul_ps(dst, x); return dst; } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index e4ab763..e47c7bc 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -5,6 +5,7 @@ #include #include +#include "common/log.hpp" #include "common/simple_parallel.hpp" #include "common/tensor2d.hpp" #include "common/utility.hpp" @@ -17,15 +18,13 @@ #include "llm_mha_gpt.hpp" #include "mha_gpt_amx.hpp" -using namespace utility; - namespace llmdnn { struct mha_gpt_impl_amx : public mha_gpt::impl { mha_gpt_impl_amx() = default; ~mha_gpt_impl_amx(); void create(data_type_t in_type, size_t seq_len, size_t head_size, bool is_bloom); - void exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, + status_t exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) override; void mha_bf16(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, @@ -123,12 +122,12 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& auto& gemAvB_ops = gemAvB_BF16xBF16; auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; auto& qKVGemm_ops = qKVGemm_BF16xBF16; - bool use_vector = query_seq_len == 1 && head_size >= 32 && head_size <= 32 * 6 && !is_bloom && !alibi && attn_mask && !causal_mask; + bool use_gemv = query_seq_len == 1 && head_size >= 32 && head_size <= 32 * 6 && !is_bloom && !alibi && attn_mask && !causal_mask; size_t head_stride_in_attn = head_size; size_t batch_stride_in_attn = head_size * head_num * query_seq_len; size_t causal_mask_offset_start = use_causal_mask ? key_seq_len - query_seq_len : key_seq_len; - if (use_vector) { + if (use_gemv) { parallel_for2d(batch, head_num, [&](size_t thread_id, size_t i0, size_t i1) { auto q_sub = &q.at({i0, i1}); auto k_sub = &k.at({i0, i1}); @@ -157,8 +156,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& head_size, _head_size_aligned * sizeof(float), head_num * head_size * sizeof(ov::bfloat16), nullptr); }); } else { - size_t seq_cout_all = rndup(query_seq_len, 32) / 32; - auto work_amount = batch * head_num * seq_cout_all; + size_t seq_count_all = rndup(query_seq_len, 32) / 32; + auto work_amount = batch * head_num * seq_count_all; parallel_for(_num_threads, [&](size_t thread_id) { size_t i0; size_t i1; @@ -167,13 +166,13 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& splitter(work_amount, _num_threads, thread_id, start, end); if (start >= work_amount) return; - parallel_it_init(start, i0, batch, i1, head_num, seq, seq_cout_all); + parallel_it_init(start, i0, batch, i1, head_num, seq, seq_count_all); ov::bfloat16* prev_k = nullptr; ov::bfloat16* prev_v = nullptr; for (size_t iwork = start; iwork < end; ++iwork) { auto seq_start = seq * 32; auto seq_end = std::min(seq_start + 32, query_seq_len); - auto seq_cout = seq_end - seq_start; + auto seq_count = seq_end - seq_start; // q: [batch, head_num, query_seq_len, head_size] // k: [batch, head_num, key_seq_len, head_size] // v: [batch, head_num, value_seq_len, head_size] @@ -184,8 +183,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& auto mat0_out = reinterpret_cast(_buffer_mat0_out + thread_id * _buffer_mat0_out_size); auto mat1_out = reinterpret_cast(_buffer_mat1_out + thread_id * _buffer_mat1_out_size); - tensor2D matQ(seq_cout, head_size, q_sub, q.m_strides[2]); - tensor2D matQK(seq_cout, key_seq_len, mat0_out, rndup(key_seq_len * sizeof(float), 64)); + tensor2D matQ(seq_count, head_size, q_sub, q.m_strides[2]); + tensor2D matQK(seq_count, key_seq_len, mat0_out, rndup(key_seq_len * sizeof(float), 64)); amx_kernel::PP::BiasGeluStore pp(matQK); if (!is_bloom) { tensor2D matK(key_seq_len, head_size, k_sub, k.m_strides[2]); @@ -195,13 +194,12 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& (*qKtrGemm_ops[thread_id])(matQ, matK, 0, key_seq_len, pp, k_sub == prev_k); } prev_k = k_sub; - tensor2D softmax_dst(seq_cout, key_seq_len, reinterpret_cast(mat0_out), rndup(key_seq_len * sizeof(ov::bfloat16), 64)); - // no attention mask + tensor2D softmax_dst(seq_count, key_seq_len, reinterpret_cast(mat0_out), rndup(key_seq_len * sizeof(ov::bfloat16), 64)); size_t valid_softmax_items = std::min(causal_mask_offset_start + seq_start + 1, key_seq_len); // attn: [batch, 1, 1, key_seq_len] or [batch, 1, query_seq_len, key_seq_len] // alibi: [batch, num_heads, 1, key_seq_len] // causal: [batch/1, 1, query_seq_len, key_seq_len] - for (uint32_t m = 0; m < seq_cout; m++) { + for (uint32_t m = 0; m < seq_count; m++) { auto attn_sub = attn_mask ? &attn_mask.at({i0, 0, attn_mask.m_dims[2] == 1 ? 0 : m + seq_start}) : nullptr; auto alibi_sub = alibi ? &alibi.at({i0, i1}) : nullptr; auto causal_mask_sub = causal_mask ? &causal_mask.at({causal_mask.m_dims[0] == 1 ? 0 : i0, 0, m + seq_start}) : nullptr; @@ -218,60 +216,61 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& auto out_sub = out + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + seq_start * head_stride_in_attn * head_num) * sizeof(ov::bfloat16); - tensor2D matQKBF16(seq_cout, key_seq_len, softmax_dst.data, softmax_dst.stride); + tensor2D matQKBF16(seq_count, key_seq_len, softmax_dst.data, softmax_dst.stride); tensor2D matV(key_seq_len, head_size, v_sub, v.m_strides[2]); - tensor2D matQKV(seq_cout, head_size, mat1_out, _head_size_aligned * sizeof(float)); + tensor2D matQKV(seq_count, head_size, mat1_out, _head_size_aligned * sizeof(float)); amx_kernel::PP::BiasGeluStore pp2(matQKV); (*qKVGemm_ops[thread_id])(matQKBF16, matV, 0, head_size, pp2, prev_v == v_sub); prev_v = v_sub; - memcpy2d_stride_avx512(reinterpret_cast(out_sub), mat1_out, seq_cout, + memcpy2d_stride_avx512(reinterpret_cast(out_sub), mat1_out, seq_count, head_size, _head_size_aligned * sizeof(float), head_num * head_size * sizeof(ov::bfloat16), nullptr); - parallel_it_step(i0, batch, i1, head_num, seq, seq_cout_all); + parallel_it_step(i0, batch, i1, head_num, seq, seq_count_all); } }); } } -void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { +status_t mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { if (q.m_rank != 4 || k.m_rank != 4 || v.m_rank != 4) { - std::cout << "q,k,v rank does not equal 4.\n"; - return; + DEBUG_LOG << "q,k,v rank does not equal 4.\n"; + return status_t::status_invalid_arguments; } if (output.m_rank != 3) { - std::cout << "output rank should be 3.\n"; + DEBUG_LOG << "output rank should be 3.\n"; + return status_t::status_invalid_arguments; } if (attn_mask) { if (attn_mask.m_rank != 4) { - std::cout << "attn_mask rank should be 4.\n"; - return; + DEBUG_LOG << "attn_mask rank should be 4.\n"; + return status_t::status_invalid_arguments; } if (attn_mask.m_dims[1] != 1) { - std::cout << "attn_mask dim 1 should be 1.\n"; - return; + DEBUG_LOG << "attn_mask dim 1 should be 1.\n"; + return status_t::status_invalid_arguments; } } if (alibi) { if (alibi.m_rank != 4) { - std::cout << "alibi rank should be 4.\n"; - return; + DEBUG_LOG << "alibi rank should be 4.\n"; + return status_t::status_invalid_arguments; } if (alibi.m_dims[1] != k.m_dims[1]) { - std::cout << "alibi dim 1 should be equal to k dim 1.\n"; - return; + DEBUG_LOG << "alibi dim 1 should be equal to k dim 1.\n"; + return status_t::status_invalid_arguments; } if (alibi.m_dims[2] != 1) { - std::cout << "alibi dim 2 should be 1.\n"; - return; + DEBUG_LOG << "alibi dim 2 should be 1.\n"; + return status_t::status_invalid_arguments; } } if (causal_mask) { if (causal_mask.m_rank != 4) { - std::cout << "causal_mask rank should be 4.\n"; - return; + DEBUG_LOG << "causal_mask rank should be 4.\n"; + return status_t::status_invalid_arguments; } if (use_causal_mask) { - std::cout << "use_causal_mask must be false to disable builtin causal mask.\n"; - return; + DEBUG_LOG << "use_causal_mask must be false to disable builtin causal mask.\n"; + return status_t::status_invalid_arguments; } } auto batch = q.m_dims[0]; @@ -283,8 +282,8 @@ void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, c head_num == k.m_dims[1] && head_num == v.m_dims[1] && key_seq_len == v.m_dims[2] && head_size == k.m_dims[3] && head_size == v.m_dims[3])) { - std::cout << "dim of q,k,v is error.\n"; - return; + DEBUG_LOG << "dim of q,k,v is error.\n"; + return status_t::status_invalid_arguments; } bool is_bloom = k.m_strides[3] > k.m_strides[2]; @@ -292,16 +291,19 @@ void mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, c auto in_dtype = q.m_dtype; auto out_dtype = output.m_dtype; - if (in_dtype == dnnl_bf16 && out_dtype == dnnl_bf16) { + if (in_dtype == llmdnn_bf16 && out_dtype == llmdnn_bf16) { create(in_dtype, key_seq_len, head_size, is_bloom); mha_bf16(q, k, v, output, attn_mask, alibi, causal_mask, select_nfltmax_at_0, normal_factor, use_causal_mask); } else { - std::cout << "doesn't support provided input precisions.\n"; + DEBUG_LOG << "doesn't support provided input precisions.\n"; + return status_t::status_invalid_arguments; } + + return status_t::status_ok; } mha_gpt::impl* new_impl_amx() { return new mha_gpt_impl_amx(); } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/mha_gpt_amx.hpp b/src/mha_gpt_amx.hpp index c012ec2..f75df98 100644 --- a/src/mha_gpt_amx.hpp +++ b/src/mha_gpt_amx.hpp @@ -14,4 +14,4 @@ namespace llmdnn { mha_gpt::impl* new_impl_amx(); -} +} // namespace llmdnn diff --git a/src/mha_gpt_api.cpp b/src/mha_gpt_api.cpp index 40d0038..0e35085 100644 --- a/src/mha_gpt_api.cpp +++ b/src/mha_gpt_api.cpp @@ -17,8 +17,8 @@ mha_gpt::~mha_gpt() { delete _impl; } -void mha_gpt::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { - _impl->exec(q, k, v, output, attn_mask, alibi, causal_mask, select_nfltmax_at_0, normal_factor, use_causal_mask); +status_t mha_gpt::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { + return _impl->exec(q, k, v, output, attn_mask, alibi, causal_mask, select_nfltmax_at_0, normal_factor, use_causal_mask); } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/mm_kernel_amx.cpp b/src/mm_kernel_amx.cpp index 43df55b..8fbc5ed 100644 --- a/src/mm_kernel_amx.cpp +++ b/src/mm_kernel_amx.cpp @@ -10,6 +10,7 @@ #include #include "llm_mm.hpp" +#include "llm_types.hpp" #include "mm_kernel_common_amx.hpp" #include "utility_kernel_avx512.hpp" #include "mm_kernel_amx.hpp" @@ -31,32 +32,32 @@ struct mm_kernel { }; // interface -bool mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param) { +status_t mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param) { mm_kernel* m = nullptr; if (param == nullptr || mm == nullptr) { - std::cout << "mm_kernel_create: invalid input parameter.\n"; + DEBUG_LOG << "mm_kernel_create: invalid input parameter.\n"; goto ERR; } m = new mm_kernel; if (param->b_is_gemv) { - if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + if (param->dt_a == llmdnn_s8 && param->dt_b == llmdnn_s8) { m->i8xi8_gemv = std::make_unique>(); - } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_bf16) { m->bf16xbf16_gemv = std::make_unique>(); } else { - std::cout << "mm_kernel_create: unsupport gemv input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + DEBUG_LOG << "mm_kernel_create: unsupport gemv input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; goto ERR; } } else { - if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + if (param->dt_a == llmdnn_s8 && param->dt_b == llmdnn_s8) { m->i8xi8 = std::make_unique>(false, param->b_is_trans); - } else if (param->dt_a == dnnl_u8 && param->dt_b == dnnl_s8) { + } else if (param->dt_a == llmdnn_u8 && param->dt_b == llmdnn_s8) { m->u8xi8 = std::make_unique>(false, param->b_is_trans); - } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_bf16) { m->bf16xbf16 = std::make_unique>(false, param->b_is_trans); } else { - std::cout << "mm_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + DEBUG_LOG << "mm_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; goto ERR; } } @@ -65,10 +66,10 @@ bool mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param) { m->b_is_transpose = param->b_is_trans; *mm = m; - return true; + return status_t::status_ok; ERR: delete m; - return false; + return status_t::status_invalid_arguments; } void mm_kernel_destroy_amx(const mm_kernel* mm) { @@ -77,7 +78,7 @@ void mm_kernel_destroy_amx(const mm_kernel* mm) { } } -void mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +status_t mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, size_t M, size_t N, size_t K) { size_t b_d0 = K, b_d1 = N; if (mm->b_is_transpose) { @@ -110,9 +111,11 @@ void mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* amx_kernel::PP::BiasGeluStore pp(c); (*mm->bf16xbf16)(a, b, 0, N, pp); } else { - std::cout << "mm_kernel_execute: no valid kernel created, call create first.\n"; + DEBUG_LOG << "mm_kernel_execute: no valid kernel created, call create first.\n"; + return status_t::status_invalid_arguments; } -} + return status_t::status_ok; +} -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/mm_kernel_amx.hpp b/src/mm_kernel_amx.hpp index 1d03818..1950de2 100644 --- a/src/mm_kernel_amx.hpp +++ b/src/mm_kernel_amx.hpp @@ -15,14 +15,15 @@ #include #include "llm_mm.hpp" +#include "llm_types.hpp" namespace llmdnn { -bool mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param); +status_t mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param); void mm_kernel_destroy_amx(const mm_kernel* mm); -void mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +status_t mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, size_t M, size_t N, size_t K); -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/mm_kernel_api.cpp b/src/mm_kernel_api.cpp index 9b2ebff..ca52fe6 100644 --- a/src/mm_kernel_api.cpp +++ b/src/mm_kernel_api.cpp @@ -3,6 +3,7 @@ // #include "llm_mm.hpp" +#include "llm_types.hpp" #include "mm_kernel_amx.hpp" namespace llmdnn { @@ -12,7 +13,7 @@ static decltype(&mm_kernel_destroy) mm_kernel_destroy_ptr = mm_kernel_destroy_am static decltype(&mm_kernel_execute) mm_kernel_execute_ptr = mm_kernel_execute_amx; // interface -bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param) { +status_t mm_kernel_create(mm_kernel** mm, const mm_create_param* param) { return mm_kernel_create_ptr(mm, param); } @@ -20,9 +21,9 @@ void mm_kernel_destroy(const mm_kernel* mm) { mm_kernel_destroy_ptr(mm); } -void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +status_t mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, size_t M, size_t N, size_t K) { - mm_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, lda, ldb, ldc, M, N, K); + return mm_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, lda, ldb, ldc, M, N, K); } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index 2c6c412..b735764 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -2147,7 +2147,7 @@ struct Matmul { // this dynamic quantization of weight matrix using minmax // is time-consuming, should be used only for constB if (!constB) { - std::cout << "\t WANING: dynamic quantization of weight matrix for non-constB is time-consuming " << std::endl; + DEBUG_LOG << "\t WANING: dynamic quantization of weight matrix for non-constB is time-consuming " << std::endl; } // float min, max; // functional::get_min_max(_matB, min, max); diff --git a/src/rotary_kernel_avx2.hpp b/src/rotary_kernel_avx2.hpp index 8602a04..5bebd50 100644 --- a/src/rotary_kernel_avx2.hpp +++ b/src/rotary_kernel_avx2.hpp @@ -105,4 +105,4 @@ namespace llmdnn { _mm256_maskstore_ps(k_dst + i, x_mask, k_dst_f); } } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/rotary_kernel_avx512.hpp b/src/rotary_kernel_avx512.hpp index 9f1a6ce..0a4cb48 100644 --- a/src/rotary_kernel_avx512.hpp +++ b/src/rotary_kernel_avx512.hpp @@ -132,4 +132,4 @@ namespace llmdnn { _mm256_mask_storeu_epi16(k_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); } } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/softmax_kernel_avx512.hpp b/src/softmax_kernel_avx512.hpp index 81c52c5..38b4719 100644 --- a/src/softmax_kernel_avx512.hpp +++ b/src/softmax_kernel_avx512.hpp @@ -223,4 +223,4 @@ namespace llmdnn { } } } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/transpose_kernel_avx512.hpp b/src/transpose_kernel_avx512.hpp index abcd37d..76af1c6 100644 --- a/src/transpose_kernel_avx512.hpp +++ b/src/transpose_kernel_avx512.hpp @@ -108,4 +108,4 @@ namespace llmdnn { dst = reinterpret_cast(reinterpret_cast(dst) + dst_stride); } } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/utility_kernel_amx.hpp b/src/utility_kernel_amx.hpp index 9a9778f..3ba2443 100644 --- a/src/utility_kernel_amx.hpp +++ b/src/utility_kernel_amx.hpp @@ -92,18 +92,4 @@ struct tileconfig_t { void store() { _tile_storeconfig(this); } - - friend std::ostream& operator<<(std::ostream& out, const tileconfig_t& cfg) { - out << " palette_id=" << static_cast(cfg.palette_id); - out << " startRow=" << static_cast(cfg.startRow); - out << " row x colsb=("; - for (int i = 0; i < 16;i++) { - if (cfg.rows[i] == 0 && cfg.cols[i] == 0) - continue; - if (i > 0) out << ","; - out << static_cast(cfg.rows[i]) << "x" << static_cast(cfg.cols[i]); - } - out << ")"; - return out; - } } __attribute__ ((__packed__)); diff --git a/src/utility_kernel_avx2.hpp b/src/utility_kernel_avx2.hpp index 3c1775e..458d59e 100644 --- a/src/utility_kernel_avx2.hpp +++ b/src/utility_kernel_avx2.hpp @@ -15,20 +15,6 @@ namespace llmdnn { -#pragma once - -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#else -#include -#include -#endif - inline __m256i get_mask(int N7) { static __m256i mask[] = { _mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0, 0), @@ -56,4 +42,4 @@ static inline float _mm256_reduce_add_ps(__m256 x) { return _mm_cvtss_f32(x32); } -} \ No newline at end of file +} // namespace llmdnn diff --git a/src/utility_kernel_avx512.hpp b/src/utility_kernel_avx512.hpp index 2aab313..1f44ef9 100644 --- a/src/utility_kernel_avx512.hpp +++ b/src/utility_kernel_avx512.hpp @@ -202,4 +202,4 @@ inline void mul_add2_select_f32_avx512(float* dst, float* src, float mul, float* } } -} \ No newline at end of file +} // namespace llmdnn diff --git a/tests/src/test_common.cpp b/tests/src/test_common.cpp index be669f8..8f14d26 100644 --- a/tests/src/test_common.cpp +++ b/tests/src/test_common.cpp @@ -32,14 +32,14 @@ using namespace llmdnn; std::string dtype_to_str(data_type_t type) { switch (type) { - case dnnl_data_type_undef: return "undef"; - case dnnl_f16: return "f16"; - case dnnl_bf16: return "bf16"; - case dnnl_f32: return "f32"; - case dnnl_s32: return "s32"; - case dnnl_s8: return "s8"; - case dnnl_u8: return "u8"; - case dnnl_f64: return "f64"; + case llmdnn_data_type_undef: return "undef"; + case llmdnn_f16: return "f16"; + case llmdnn_bf16: return "bf16"; + case llmdnn_f32: return "f32"; + case llmdnn_s32: return "s32"; + case llmdnn_s8: return "s8"; + case llmdnn_u8: return "u8"; + case llmdnn_f64: return "f64"; default: return "unkown"; } } @@ -62,7 +62,7 @@ bool initXTILE() { return true; } -namespace utility { +namespace llmdnn { size_t get_total_threads() { return omp_get_max_threads(); diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp index 0046541..3ec7872 100644 --- a/tests/src/test_fc_kernel_amx.cpp +++ b/tests/src/test_fc_kernel_amx.cpp @@ -67,7 +67,7 @@ class FCKernelTest : public TestWithParam { _dt_a, _dt_b, _dt_c, _is_transpose, _postops_type }; - ASSERT_TRUE(fc_kernel_create(&fc, ¶m)); + ASSERT_TRUE(fc_kernel_create(&fc, ¶m) == llmdnn::status_ok); auto gemm = std::shared_ptr(fc, [](fc_kernel* p) { fc_kernel_destroy(p); }); tensor2D A(_M, _K, true); @@ -138,19 +138,19 @@ class FCKernelTest : public TestWithParam { }; TEST_P(FCKernelTest, Func) { - if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_s8) { + if (_dt_a == llmdnn_s8 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_s8) { do_test(); - } else if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_bf16) { + } else if (_dt_a == llmdnn_s8 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_bf16) { do_test(); - } else if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_f32) { + } else if (_dt_a == llmdnn_s8 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_f32) { do_test(); - } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_bf16 && _dt_c == dnnl_bf16) { + } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_bf16 && _dt_c == llmdnn_bf16) { do_test(); - } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_bf16 && _dt_c == dnnl_f32) { + } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_bf16 && _dt_c == llmdnn_f32) { do_test(); - } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_s8 && _dt_c == dnnl_f32) { + } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_f32) { do_test(); - } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_s8 && _dt_c == dnnl_bf16) { + } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_bf16) { do_test(); } else { ASSERT_TRUE(false); @@ -166,45 +166,45 @@ TEST_P(FCKernelTest, Func) { // (bf16,s8,f32),dq,[bias],[gelu] // (bf16,s8,bf16),dq,[bias],[gelu] const std::vector types = { - { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_QUANT }, - { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_QUANT }, - { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_GELU_QUANT }, - { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_GELU_QUANT }, - { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_GELU_TANH_QUANT }, - { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, - { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT }, - { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS }, - { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_GELU }, - { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU }, - { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_GELU_TANH }, - { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU_TANH }, - { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT }, - { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS }, - { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_GELU }, - { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU }, - { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_GELU_TANH }, - { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU_TANH }, - { dnnl_bf16, dnnl_bf16, dnnl_bf16, NONE }, - { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS }, - { dnnl_bf16, dnnl_bf16, dnnl_bf16, GELU }, - { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS_GELU }, - { dnnl_bf16, dnnl_bf16, dnnl_bf16, GELU_TANH }, - { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS_GELU_TANH }, - { dnnl_bf16, dnnl_bf16, dnnl_f32, NONE }, - { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS }, - { dnnl_bf16, dnnl_bf16, dnnl_f32, GELU }, - { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS_GELU }, - { dnnl_bf16, dnnl_bf16, dnnl_f32, GELU_TANH }, - { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU_TANH }, // TODO: support weight compression - // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT }, - // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_BIAS }, - // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_GELU }, - // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU }, - // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT }, - // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_BIAS }, - // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_GELU }, - // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU }, + // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT }, + // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS }, + // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT_GELU }, + // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS_GELU }, + // { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, DEQUANT }, + // { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS }, + // { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, DEQUANT_GELU }, + // { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS_GELU }, }; // M, N, K diff --git a/tests/src/test_mm_kernel_amx.cpp b/tests/src/test_mm_kernel_amx.cpp index ce0fcb2..da1ec6f 100644 --- a/tests/src/test_mm_kernel_amx.cpp +++ b/tests/src/test_mm_kernel_amx.cpp @@ -13,6 +13,7 @@ #include "llm_mm.hpp" #include "common/tensor2d.hpp" #include "common/tensor2d_helper.hpp" +#include "llm_types.hpp" #include "test_common.hpp" using namespace std; @@ -56,7 +57,7 @@ class GemmKernelTest : public TestWithParam { template void test() { - if (_N == 1 && (_is_transpose || _types.first == dnnl_u8)) { + if (_N == 1 && (_is_transpose || _types.first == llmdnn_u8)) { GTEST_SKIP() << "gemv does not support transpose or u8s8."; } mm_kernel* mm; @@ -64,7 +65,7 @@ class GemmKernelTest : public TestWithParam { _types.first, _types.second, _N == 1, _is_transpose }; - ASSERT_TRUE(mm_kernel_create(&mm, ¶m)); + ASSERT_TRUE(mm_kernel_create(&mm, ¶m) == llmdnn::status_ok); auto gemm = std::shared_ptr(mm, [](mm_kernel* p) { mm_kernel_destroy(p); }); tensor2D A(_M, _K, true); @@ -97,9 +98,9 @@ class GemmKernelTest : public TestWithParam { }; TEST_P(GemmKernelTest, Func) { - if (_types.first == dnnl_u8 && _types.second == dnnl_s8) { + if (_types.first == llmdnn_u8 && _types.second == llmdnn_s8) { test(); - } else if (_types.first == dnnl_s8 && _types.second == dnnl_s8) { + } else if (_types.first == llmdnn_s8 && _types.second == llmdnn_s8) { test(); } else { test(); @@ -107,9 +108,9 @@ TEST_P(GemmKernelTest, Func) { } const std::vector> types = { - { dnnl_u8, dnnl_s8 }, - { dnnl_s8, dnnl_s8 }, - { dnnl_bf16, dnnl_bf16 }, + { llmdnn_u8, llmdnn_s8 }, + { llmdnn_s8, llmdnn_s8 }, + { llmdnn_bf16, llmdnn_bf16 }, }; // M, N, K diff --git a/tests/src/test_rotary_kernel_avx2.cpp b/tests/src/test_rotary_kernel_avx2.cpp index 0ab1fca..4718162 100644 --- a/tests/src/test_rotary_kernel_avx2.cpp +++ b/tests/src/test_rotary_kernel_avx2.cpp @@ -93,7 +93,7 @@ class RotaryTestAVX2 : public TestWithParam { }; TEST_P(RotaryTestAVX2, rotary) { - if (_types == dnnl_s8) { + if (_types == llmdnn_s8) { ASSERT_TRUE(false); } else { test(0.01f); @@ -101,7 +101,7 @@ TEST_P(RotaryTestAVX2, rotary) { } const std::vector types = { - dnnl_f32 + llmdnn_f32 }; INSTANTIATE_TEST_SUITE_P(smoke_Rotary, RotaryTestAVX2, diff --git a/tests/src/test_rotary_kernel_avx512.cpp b/tests/src/test_rotary_kernel_avx512.cpp index 6be5829..6cd196b 100644 --- a/tests/src/test_rotary_kernel_avx512.cpp +++ b/tests/src/test_rotary_kernel_avx512.cpp @@ -93,7 +93,7 @@ class RotaryTest : public TestWithParam { }; TEST_P(RotaryTest, rotary) { - if (_types == dnnl_s8) { + if (_types == llmdnn_s8) { ASSERT_TRUE(false); } else { test(0.01f); @@ -101,7 +101,7 @@ TEST_P(RotaryTest, rotary) { } const std::vector types = { - dnnl_bf16 + llmdnn_bf16 }; INSTANTIATE_TEST_SUITE_P(smoke_Rotary, RotaryTest, diff --git a/tests/src/test_softmax_kernel_avx512.cpp b/tests/src/test_softmax_kernel_avx512.cpp index a78fa1f..34942ba 100644 --- a/tests/src/test_softmax_kernel_avx512.cpp +++ b/tests/src/test_softmax_kernel_avx512.cpp @@ -113,11 +113,11 @@ class SoftmaxTest : public TestWithParam { }; TEST_P(SoftmaxTest, Func) { - if (_types == dnnl_s8) { + if (_types == llmdnn_s8) { test(1.1f); - } else if (_types == dnnl_u8) { + } else if (_types == llmdnn_u8) { test(1.1f); - } else if (_types == dnnl_f32) { + } else if (_types == llmdnn_f32) { test(0.00001f); } else { test(0.01f); @@ -125,7 +125,7 @@ TEST_P(SoftmaxTest, Func) { } const std::vector types = { - dnnl_s8, dnnl_bf16, dnnl_u8, dnnl_f32 + llmdnn_s8, llmdnn_bf16, llmdnn_u8, llmdnn_f32 }; INSTANTIATE_TEST_SUITE_P(smoke_Softmax, SoftmaxTest, diff --git a/tests/src/test_transpose_kernel_avx512.cpp b/tests/src/test_transpose_kernel_avx512.cpp index 46035b5..702fdc3 100644 --- a/tests/src/test_transpose_kernel_avx512.cpp +++ b/tests/src/test_transpose_kernel_avx512.cpp @@ -111,11 +111,11 @@ class TransposeTest : public TestWithParam { }; TEST_P(TransposeTest, memcpy2d) { - if (_types == dnnl_s8) { + if (_types == llmdnn_s8) { test(1.1f); - } else if (_types == dnnl_u8) { + } else if (_types == llmdnn_u8) { test(1.1f); - } else if (_types == dnnl_f32) { + } else if (_types == llmdnn_f32) { test(0.00001f); } else { test(0.01f); @@ -123,7 +123,7 @@ TEST_P(TransposeTest, memcpy2d) { } const std::vector types = { - dnnl_s8, dnnl_bf16, dnnl_u8, dnnl_f32 + llmdnn_s8, llmdnn_bf16, llmdnn_u8, llmdnn_f32 }; INSTANTIATE_TEST_SUITE_P(smoke_Transpose, TransposeTest, diff --git a/tests/src/test_utility_kernel_repack1x2_avx512.cpp b/tests/src/test_utility_kernel_repack1x2_avx512.cpp index f140ba4..b5176ed 100644 --- a/tests/src/test_utility_kernel_repack1x2_avx512.cpp +++ b/tests/src/test_utility_kernel_repack1x2_avx512.cpp @@ -131,7 +131,7 @@ class RepackTest : public TestWithParam { }; TEST_P(RepackTest, Func) { - if (_types == dnnl_s8) { + if (_types == llmdnn_s8) { test(); } else { test(); @@ -139,7 +139,7 @@ TEST_P(RepackTest, Func) { } const std::vector types = { - dnnl_s8, dnnl_bf16 + llmdnn_s8, llmdnn_bf16 }; INSTANTIATE_TEST_SUITE_P(smoke_Repack, RepackTest, From 6de45ce1f97440fa1602c8b18d2653d072aeb716 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Tue, 1 Aug 2023 21:58:08 +0800 Subject: [PATCH 41/54] apply review comments --- src/mm_kernel_common_amx.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index b735764..63014d3 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -1280,7 +1280,7 @@ namespace PP { q_scale_per_oc = scale_per_oc; } - // source buffC can be i32 or f32 + // source buffC can be i32 or f32, buffC size is [32, 32], valid_m/valid_n is in [1, 32] template::value, bool>::type = true> void operator()(tensor2D & buffC, int m, int n, int valid_m, int valid_n) { auto * psrc = &buffC(0,0); From ca79cd9428f9a3e831408333d35eaffc874e17e1 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 4 Aug 2023 03:47:51 +0800 Subject: [PATCH 42/54] fc weight support f32, add weight pack api (cherry picked from commit e5d88c410abb0050ee2a3dee3f78378c031285d5) --- include/llm_fc.hpp | 9 ++-- src/fc_kernel_amx.cpp | 78 ++++++++++++++++++++++++-------- src/fc_kernel_amx.hpp | 6 ++- src/fc_kernel_api.cpp | 11 +++-- src/mm_kernel_common_amx.hpp | 31 +++++++++++++ tests/src/test_common.hpp | 4 +- tests/src/test_fc_kernel_amx.cpp | 21 ++++++++- 7 files changed, 130 insertions(+), 30 deletions(-) diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp index afa4c31..5b95415 100644 --- a/include/llm_fc.hpp +++ b/include/llm_fc.hpp @@ -55,15 +55,18 @@ struct fc_kernel; /// fc: (s8,s8,s8),dq,[bias],[gelu],q /// fc: (s8,s8,bf16),dq,[bias],[gelu] /// fc: (s8,s8,f32),dq,[bias],[gelu] +/// fc: (bf16,f32,bf16),[bias],[gelu] +/// fc: (bf16,f32,f32),[bias],[gelu] /// fc: (bf16,bf16,bf16),[bias],[gelu] /// fc: (bf16,bf16,f32),[bias],[gelu] /// fc: (bf16,s8,f32),dq,[bias],[gelu] /// fc: (bf16,s8,bf16),dq,[bias],[gelu] /// status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param); -void fc_kernel_destroy(const fc_kernel* mm); -void fc_kernel_execute(const fc_kernel* mm, - void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +void fc_kernel_destroy(fc_kernel* mm); +void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); +void fc_kernel_execute(fc_kernel* mm, + void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq=nullptr, float* q=nullptr, float* bias=nullptr); diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index 85c9064..f5c729d 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -28,6 +28,7 @@ struct fc_kernel { data_type_t dt_a; data_type_t dt_b; data_type_t dt_c; + size_t stride_b; postops_types postops_type; bool b_is_transpose; }; @@ -39,6 +40,8 @@ static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b { { llmdnn_s8, llmdnn_s8, llmdnn_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, { { llmdnn_s8, llmdnn_s8, llmdnn_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, { { llmdnn_s8, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_f32, llmdnn_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_f32, llmdnn_f32 }, { 0, BIAS | GELU | GELU_TANH } }, { { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, { { llmdnn_bf16, llmdnn_bf16, llmdnn_f32 }, { 0, BIAS | GELU | GELU_TANH } }, { { llmdnn_bf16, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, @@ -83,7 +86,7 @@ status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { m->i8xi8 = std::make_unique>(true, param->b_is_trans); } else if (param->dt_a == llmdnn_u8 && param->dt_b == llmdnn_s8) { m->u8xi8 = std::make_unique>(true, param->b_is_trans); - } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_bf16) { + } else if (param->dt_a == llmdnn_bf16 && (param->dt_b == llmdnn_bf16 || param->dt_b == llmdnn_f32)) { m->bf16xbf16 = std::make_unique>(true, param->b_is_trans); } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_s8) { m->bf16xi8 = std::make_unique>(true, param->b_is_trans); @@ -107,13 +110,50 @@ status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { return status_t::status_invalid_arguments; } -void fc_kernel_destroy_amx(const fc_kernel* mm) { +void fc_kernel_destroy_amx(fc_kernel* mm) { if (mm) { delete mm; } } -void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + mm->stride_b = stride_b; + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8) { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->i8xi8->internalB, true); + } else if (mm->u8xi8) { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->u8xi8->internalB, true); + } else if (mm->bf16xbf16) { + if (mm->dt_b == llmdnn_bf16) { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + } else { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + tensor2D internalTmpB; + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::functional::f32_to_bf16_tensor(internalTmpB, matB); + amx_kernel::repackB_1x2(internalTmpB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + } + } else { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + tensor2D internalTmpB; + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); + amx_kernel::functional::bf16_to_i8_tensor(mm->bf16xi8->internalBI8, internalTmpB, mm->bf16xi8->quant_scale_B); + } + +} + +void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { size_t b_d0 = K, b_d1 = N; if (mm->b_is_transpose) { @@ -121,11 +161,11 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* b_d1 = K; } if (mm->i8xi8) { - tensor2D a(M, K, reinterpret_cast(ptr_a), lda); - tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); + tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); if (mm->dt_c == llmdnn_s8) { - tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); @@ -162,7 +202,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* } } } else if (mm->dt_c == llmdnn_bf16) { - tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!bias) { if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); @@ -193,7 +233,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* } } } else if (mm->dt_c == llmdnn_f32) { - tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!bias) { if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); @@ -225,17 +265,17 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* } } } else if (mm->u8xi8) { - tensor2D a(M, K, reinterpret_cast(ptr_a), lda); - tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); - tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); + tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); amx_kernel::PP::BiasGeluStore pp(c); (*mm->u8xi8)(a, b, n_start, n_end, pp); } else if (mm->bf16xbf16) { - tensor2D a(M, K, reinterpret_cast(ptr_a), lda); - tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); + tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); if (mm->dt_c == llmdnn_bf16) { - tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); @@ -260,7 +300,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* } } } else if (mm->dt_c == llmdnn_f32) { - tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); @@ -286,11 +326,11 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* } } } else { - tensor2D a(M, K, reinterpret_cast(ptr_a), lda); - tensor2D b(N, K, reinterpret_cast(ptr_b), ldb); + tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); + tensor2D b(N, K, nullptr, mm->stride_b); if (mm->dt_c == llmdnn_bf16) { - tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); @@ -315,7 +355,7 @@ void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* } } } else if (mm->dt_c == llmdnn_f32) { - tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!(mm->postops_type & BIAS)) { if (mm->postops_type & GELU) { amx_kernel::PP::BiasGeluStore ppkernel(c); diff --git a/src/fc_kernel_amx.hpp b/src/fc_kernel_amx.hpp index 224c574..f6f2011 100644 --- a/src/fc_kernel_amx.hpp +++ b/src/fc_kernel_amx.hpp @@ -8,9 +8,11 @@ namespace llmdnn { status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param); -void fc_kernel_destroy_amx(const fc_kernel* mm); +void fc_kernel_destroy_amx(fc_kernel* mm); -void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); + +void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias); void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp index b5174a0..94ffbba 100644 --- a/src/fc_kernel_api.cpp +++ b/src/fc_kernel_api.cpp @@ -23,6 +23,7 @@ namespace llmdnn { static decltype(&fc_kernel_create) fc_kernel_create_ptr = fc_kernel_create_amx; static decltype(&fc_kernel_destroy) fc_kernel_destroy_ptr = fc_kernel_destroy_amx; +static decltype(&fc_kernel_pack_weight) fc_kernel_pack_weight_ptr = fc_kernel_pack_weight_amx; static decltype(&fc_kernel_execute) fc_kernel_execute_ptr = fc_kernel_execute_amx; static decltype(&fc_kernel_bf16w8_get_q_dq) fc_kernel_bf16w8_get_q_dq_ptr = fc_kernel_bf16w8_get_q_dq_amx; @@ -31,13 +32,17 @@ status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { return fc_kernel_create_ptr(mm, param); } -void fc_kernel_destroy(const fc_kernel* mm) { +void fc_kernel_destroy(fc_kernel* mm) { fc_kernel_destroy_ptr(mm); } -void fc_kernel_execute(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, +void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + fc_kernel_pack_weight_ptr(mm, ptr_b, N, K, stride_b, n_start, n_end); +} + +void fc_kernel_execute(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { - fc_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, lda, ldb, ldc, M, N, K, n_start, n_end, dq, q, bias); + fc_kernel_execute_ptr(mm, ptr_a, ptr_c, stride_a, stride_c, M, N, K, n_start, n_end, dq, q, bias); } void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index 63014d3..3e139a9 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -1222,6 +1222,37 @@ namespace functional { } } } + + inline void f32_to_bf16_tensor(tensor2D& dst, tensor2D& src) { + dst.resize(src.dims[0], src.dims[1]); + auto tail = src.dims[1] % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + int i; + for(i = 0; i < src.dims[1] / 32 * 32; i += 32) { + auto x0 = _mm512_loadu_ps(p_src + i); + auto x1 = _mm512_loadu_ps(p_src + i + 16); + auto out = _mm512_cvtne2ps_pbh(x1, x0); + _mm512_storeu_epi32(reinterpret_cast(p_dst) + i, (__m512i)out); + } + if (i < src.dims[1] - tail) { + auto x = _mm512_loadu_ps(p_src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(reinterpret_cast(p_dst) + i), + _mm512_extracti64x4_epi64(out, 0)); + i += 16; + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, p_src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_mask_storeu_epi16(reinterpret_cast<__m256i*>(reinterpret_cast(p_dst) + i), + x_mask, _mm512_extracti64x4_epi64(out, 0)); + } + } + } }; // 2x2 tiles post process kernels diff --git a/tests/src/test_common.hpp b/tests/src/test_common.hpp index f9acf1e..ac44735 100644 --- a/tests/src/test_common.hpp +++ b/tests/src/test_common.hpp @@ -99,9 +99,9 @@ inline void matmul(tensor2D & A, } } -template +template void matmul(tensor2D & A, - tensor2D & B, + tensor2D & B, tensor2D & C, float * dq = nullptr, float * bias = nullptr, diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp index 3ec7872..3ee6be1 100644 --- a/tests/src/test_fc_kernel_amx.cpp +++ b/tests/src/test_fc_kernel_amx.cpp @@ -95,7 +95,8 @@ class FCKernelTest : public TestWithParam { ptr_B = B.data; ldb = B.stride; } - fc_kernel_execute(gemm.get(), A.data, ptr_B, C.data, A.stride, ldb, + fc_kernel_pack_weight(gemm.get(), ptr_B, _N, _K, ldb, 0, _N); + fc_kernel_execute(gemm.get(), A.data, C.data, A.stride, C.stride, _M, _N, _K, 0, _N, dq.data, q.data, bias.data); C_Ref = 0; float* ptr_dq = nullptr; @@ -148,6 +149,10 @@ TEST_P(FCKernelTest, Func) { do_test(); } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_bf16 && _dt_c == llmdnn_f32) { do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_f32 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_f32 && _dt_c == llmdnn_f32) { + do_test(); } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_f32) { do_test(); } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_bf16) { @@ -163,6 +168,8 @@ TEST_P(FCKernelTest, Func) { // (s8,s8,f32),dq,[bias],[gelu] // (bf16,bf16,bf16),[bias],[gelu] // (bf16,bf16,f32),[bias],[gelu] +// (bf16,f32,bf16),[bias],[gelu] +// (bf16,f32,f32),[bias],[gelu] // (bf16,s8,f32),dq,[bias],[gelu] // (bf16,s8,bf16),dq,[bias],[gelu] const std::vector types = { @@ -196,6 +203,18 @@ const std::vector types = { { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU }, { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU_TANH }, { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_f32, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU_TANH }, // TODO: support weight compression // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT }, // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS }, From cbfe9ffc9b1f5cce65deb6196deb990c93e128b8 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Sat, 5 Aug 2023 01:56:48 +0800 Subject: [PATCH 43/54] opt f32 weight pack (cherry picked from commit fad8086f10060984414d598ac38a0d5694154d77) --- src/fc_kernel_amx.cpp | 4 +- src/mm_kernel_common_amx.hpp | 140 ++++++++++++++++-- .../test_utility_kernel_repack1x2_avx512.cpp | 38 +++-- 3 files changed, 153 insertions(+), 29 deletions(-) diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index f5c729d..460d98e 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -138,10 +138,8 @@ void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, size_t N, size_t K, s amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); } else { tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); - tensor2D internalTmpB; auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); - amx_kernel::functional::f32_to_bf16_tensor(internalTmpB, matB); - amx_kernel::repackB_1x2(internalTmpB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); } } else { tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index 3e139a9..79efc47 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -1223,7 +1223,7 @@ namespace functional { } } - inline void f32_to_bf16_tensor(tensor2D& dst, tensor2D& src) { + inline void f32_to_bf16_tensor(tensor2D& dst, const tensor2D& src) { dst.resize(src.dims[0], src.dims[1]); auto tail = src.dims[1] % 16; __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); @@ -1510,55 +1510,52 @@ void repackB_1x2(const tensor2D &Bi, bool transpose, tensor2D& Bo, bool is for(; n < N - n_tail; n += N_unit) { // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); - auto* src0 = reinterpret_cast(&Bi(n, 0)); int k; for(k = 0; k < Kbody; k += kStep) { // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) - functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + functional::transpose_epi32_16x16(dst, &Bi(n, k), Bi.stride); dst += 1024; - functional::transpose_epi32_16x16(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + functional::transpose_epi32_16x16(dst, &Bi(n + 16, k), Bi.stride); dst += 1024; } if (Ktails) { // Ktails part is loaded into A tile right-aligned, so B tile must also load // Ktails part to bottom-aligned, and fill upper padding with zero - functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); + functional::transpose_epi32_16xN_right_align(dst, &Bi(n, k), Bi.stride, (K - k) * sizeof(T)); dst += 1024; - functional::transpose_epi32_16xN_right_align(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); + functional::transpose_epi32_16xN_right_align(dst, &Bi(n + 16, k), Bi.stride, (K - k) * sizeof(T)); dst += 1024; } } // n_tail: [16, 32) if (N - n >= 16) { auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); - auto* src0 = reinterpret_cast(&Bi(n, 0)); int k; for(k = 0; k < Kbody; k += kStep) { // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) - functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + functional::transpose_epi32_16x16(dst, &Bi(n, k), Bi.stride); dst += 1024 * 2; } if (Ktails) { // Ktails part is loaded into A tile right-aligned, so B tile must also load // Ktails part to bottom-aligned, and fill upper padding with zero - functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T)); + functional::transpose_epi32_16xN_right_align(dst, &Bi(n, k), Bi.stride, (K - k) * sizeof(T)); } n += 16; } // n_tail: (0, 16) if (N - n > 0) { auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 1024 : 0); - auto* src0 = reinterpret_cast(&Bi(n, 0)); int k; for(k = 0; k < Kbody; k += kStep) { // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) - functional::transpose_epi32_Mx16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, N - n); + functional::transpose_epi32_Mx16(dst, &Bi(n, k), Bi.stride, N - n); dst += 1024 * 2; } if (Ktails) { // Ktails part is loaded into A tile right-aligned, so B tile must also load // Ktails part to bottom-aligned, and fill upper padding with zero - functional::transpose_epi32_MxN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T), N - n); + functional::transpose_epi32_MxN_right_align(dst, &Bi(n, k), Bi.stride, (K - k) * sizeof(T), N - n); } n = N; } @@ -1599,6 +1596,125 @@ void repackB_1x2(const tensor2D &Bi, bool transpose, tensor2D& Bo, bool is } } +inline void repackB_1x2(tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { + int K = Bi.dims[transpose ? 1 : 0]; + int N = Bi.dims[transpose ? 0 : 1]; + + // K_padded : round up to multiple of 32/64 + int kStep = 64 / sizeof(ov::bfloat16); + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); + + int n = 0; + int n_tail = N % N_unit; + if (transpose) { + tensor2D Btmp(16, 32); + for(; n < N - n_tail; n += N_unit) { + // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16x16(dst, &Btmp(0, 0), Btmp.stride); + dst += 1024; + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, 32, &Bi(n + 16, k), Bi.stride)); + functional::transpose_epi32_16x16(dst, &Btmp(0, 0), Btmp.stride); + dst += 1024; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(dst, &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + dst += 1024; + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, K - k, &Bi(n + 16, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(dst, &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + dst += 1024; + } + } + // n_tail: [16, 32) + if (N - n >= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16x16(dst, &Btmp(0, 0), Btmp.stride); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(dst, &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + } + n += 16; + } + // n_tail: (0, 16) + if (N - n > 0) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 1024 : 0); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::f32_to_bf16_tensor(Btmp, tensor2D(N - n, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_Mx16(dst, &Btmp(0, 0), Btmp.stride, N - n); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::f32_to_bf16_tensor(Btmp, tensor2D(N - n, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_MxN_right_align(dst, &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16), N - n); + } + n = N; + } + // second B tile is untouched, need to set to zero + if (n_tail > 0 && n_tail <= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for (int k = 0; k < K_padded; k += kStep) { + memset(dst + 1024, 0, 1024); + dst += 1024 * 2; + } + } + } else { + // pack & layout sequentially + int n = 0; + int n_tail = N % N_unit; + tensor2D Btmp(32, 32); + for(; n < N - n_tail; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + int src_rows = std::min(K - k, kStep); + functional::f32_to_bf16_tensor(Btmp, tensor2D(src_rows, 32, &Bi(k, n), Bi.stride)); + functional::kpack_tile_B0B1(dst, dst + 1024, &Btmp(0, 0), Btmp.stride, src_rows); + dst += 2048; + } + } + // n_tail: (0, 32) + if (N - n > 0) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + int src_rows = std::min(K - k, kStep); + functional::f32_to_bf16_tensor(Btmp, tensor2D(src_rows, N - n, &Bi(k, n), Bi.stride)); + functional::kpack_tile_B0B1_ntail(dst, dst + 1024, &Btmp(0, 0), Btmp.stride, src_rows, N - n); + dst += 2048; + } + n += 16; + } + } +} + template struct acc_type {}; template<> diff --git a/tests/src/test_utility_kernel_repack1x2_avx512.cpp b/tests/src/test_utility_kernel_repack1x2_avx512.cpp index b5176ed..fbee232 100644 --- a/tests/src/test_utility_kernel_repack1x2_avx512.cpp +++ b/tests/src/test_utility_kernel_repack1x2_avx512.cpp @@ -23,17 +23,17 @@ using ::testing::Values; using ::testing::ValuesIn; using RepackTestParamSet = std::tuple< - data_type_t // data type + std::pair // data type >; class RepackTest : public TestWithParam { public: static std::string getTestCaseName(const testing::TestParamInfo& obj) { - data_type_t types; + std::pair types; std::tie(types) = obj.param; std::ostringstream result; - result << dtype_to_str(types); + result << "IN_" << dtype_to_str(types.first) << "_OUT_" << dtype_to_str(types.second); return result.str(); } @@ -84,17 +84,23 @@ class RepackTest : public TestWithParam { } } - template + template void test() { auto testone = [] (int k, int n, std::string prefix) { - tensor2D A(k, n, true); + tensor2D A(k, n, true); fill_rnd(A); - tensor2D AT = A.Tr(true); - tensor2D A_out, AT_out, A_ref; + tensor2D AT = A.Tr(true); + tensor2D A_out, AT_out, A_ref; amx_kernel::repackB_1x2(A, false, A_out, false); amx_kernel::repackB_1x2(AT, true, AT_out, false); - gen_ref(AT, A_ref); + if constexpr (std::is_same_v) { + tensor2D AT_bf16(n, k, true); + amx_kernel::functional::f32_to_bf16_tensor(AT_bf16, AT); + gen_ref(AT_bf16, A_ref); + } else { + gen_ref(AT, A_ref); + } ASSERT_TRUE(A_out == A_ref) << " " << prefix << " without transform K: " << k << " N: " << n; ASSERT_TRUE(AT_out == A_ref) << " " << prefix << " with transform K: " << k << " N: " << n; }; @@ -127,19 +133,23 @@ class RepackTest : public TestWithParam { testone(64 + 16 + 3, 128 + 16 + 5, "alltail"); } - data_type_t _types; + std::pair _types; }; TEST_P(RepackTest, Func) { - if (_types == llmdnn_s8) { - test(); + if (_types.first == llmdnn_s8 && _types.second == llmdnn_s8) { + test(); + } else if (_types.first == llmdnn::llmdnn_bf16 && _types.second == llmdnn_bf16) { + test(); } else { - test(); + test(); } } -const std::vector types = { - llmdnn_s8, llmdnn_bf16 +const std::vector> types = { + {llmdnn_s8, llmdnn_s8}, + {llmdnn_bf16, llmdnn_bf16}, + {llmdnn_f32, llmdnn_bf16}, }; INSTANTIATE_TEST_SUITE_P(smoke_Repack, RepackTest, From 3f31227e9fba56c9e1ed50c370c8c30acb3a390b Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 10 Aug 2023 16:52:37 +0800 Subject: [PATCH 44/54] remove writeable bufferC (cherry picked from commit bef692412d36c445695136f4ba811c423e3cda96) --- src/mm_kernel_common_amx.hpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index 79efc47..e3845d7 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -1963,13 +1963,9 @@ struct Matmul { // but only different K(32 vs 64) in A,C & B tiles constexpr static int kStep = is_i8_mode ? 64 : 32; - // 2x2 C tiles buffer - // most usecase requires post-processing with AVX, thus buffC - // is used to transfer data to AVX register - tensor2D buffC; Matmul(bool constB = false, bool transposeB = false) : - constB(constB), transposeB(transposeB), buffC(32, 32) {} + constB(constB), transposeB(transposeB) {} // ppkernel is a callable which captures the runtime args // by itself, so no need to pass in any post-process related @@ -2086,6 +2082,12 @@ struct Matmul { bool skip_repack = false) { int M = matA.dims[0]; int K = matA.dims[1]; + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + alignas(64) TC buff[32 * 32]; + tensor2D buffC(32, 32, buff, 32 * sizeof(TC)); + if (K < kStep) { int B0, B1; if (transposeB) { From 76d5f0f89ba72ad4dc6bcc0db74d360aed48e26d Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Sat, 12 Aug 2023 00:12:11 +0800 Subject: [PATCH 45/54] use numa to alloc mem (cherry picked from commit 96b9a5917917a53f08cc15d4d1692ddba21ba05a) --- src/CMakeLists.txt | 1 + src/common/memory_alloc.cpp | 28 ++++++++++++++++++++++++++++ src/common/memory_alloc.hpp | 10 ++++++++++ src/common/tensor2d.hpp | 23 +++++------------------ src/mm_kernel_common_amx.hpp | 4 ---- 5 files changed, 44 insertions(+), 22 deletions(-) create mode 100644 src/common/memory_alloc.cpp create mode 100644 src/common/memory_alloc.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index deeabe4..dafbe68 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,6 +18,7 @@ target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17) if(CPU_EXTENSIONS_ENABLE_LOG) target_compile_definitions(${PROJECT_NAME} PRIVATE ENABLE_LOG) endif() +target_link_libraries(${PROJECT_NAME} PUBLIC numa) set(CMAKE_DST lib/cmake/${PROJECT_NAME}) # header files diff --git a/src/common/memory_alloc.cpp b/src/common/memory_alloc.cpp new file mode 100644 index 0000000..955d30f --- /dev/null +++ b/src/common/memory_alloc.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include "memory_alloc.hpp" + +void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa) { + if (hint_numa && numa_available() != -1) { + int cur_cpu = sched_getcpu(); + auto cur_numa_node = numa_node_of_cpu(cur_cpu); + return numa_alloc_onnode(size, cur_numa_node); + } else { + return aligned_alloc(aligned_size, size); + } +} + +void llmdnn_free(void* p, size_t size, bool hint_numa) { + if (hint_numa && numa_available() != -1) { + numa_free(p, size); + } else { + ::free(p); + } +} diff --git a/src/common/memory_alloc.hpp b/src/common/memory_alloc.hpp new file mode 100644 index 0000000..530ddc9 --- /dev/null +++ b/src/common/memory_alloc.hpp @@ -0,0 +1,10 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa = true); +void llmdnn_free(void* p, size_t size, bool hint_numa = true); diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index ec7c975..069c735 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -9,9 +9,7 @@ #include #include #include -#ifdef ENABLE_NUMA -#include "numa.h" -#endif +#include "memory_alloc.hpp" #include "log.hpp" #include "bf16.hpp" @@ -30,7 +28,7 @@ struct tensor2D { tensor2D() = default; tensor2D(const tensor2D&) = delete; ~tensor2D() { - if (own && data) ::free(data); + if (own && data) llmdnn_free(data, capacity); } operator bool() { @@ -104,22 +102,11 @@ struct tensor2D { if (capacity < need_capacity) { if (!is_const) need_capacity *= 2; - capacity = need_capacity; // align begin address to cache line is vital, so tile load can // use all bandwidth (L1D/L2 only deliver data in unit of 64-byte aligned cache-line) - -#ifdef ENABLE_NUMA - if (USE_NUMA) { - data = std::shared_ptr( - reinterpret_cast(numa_alloc_local(capacity)), - [need_capacity](void * p){ numa_free(p, need_capacity); }); - } else { -#else - { -#endif - if (data) ::free(data); - data = reinterpret_cast(aligned_alloc(64, capacity)); - } + if (data) llmdnn_free(data, capacity); + data = reinterpret_cast(llmdnn_alloc(64, need_capacity)); + capacity = need_capacity; if (is_const) memset(static_cast(data), 0, need_capacity); if (reinterpret_cast(data) % 64) diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index e3845d7..0f443cb 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -16,10 +16,6 @@ #include #endif -#ifdef ENABLE_NUMA -#include "numa.h" -#endif - using namespace llmdnn; namespace amx_kernel { From 2dd181a563ba5d4bc6b4242930d1e59e74402f57 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 17 Aug 2023 01:52:33 +0800 Subject: [PATCH 46/54] add falcon broadcast support before rotary --- CMakeLists.txt | 2 +- include/llm_emb_gpt.hpp | 9 +- src/emb_gpt_avx512.cpp | 57 ++++- tests/script/README.md | 1 + tests/script/ext/emb_gpt.cpp | 60 ++++- tests/script/ext/setup.py | 1 + tests/script/test_rotary_pastkv_falcon.py | 262 ++++++++++++++++++++++ 7 files changed, 384 insertions(+), 8 deletions(-) create mode 100644 tests/script/test_rotary_pastkv_falcon.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e69cd8..068837a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ project(root) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) option(CPU_EXTENSIONS_BUILD_TESTS "Build with tests" ON) -option(CPU_EXTENSIONS_ENABLE_LOG "Enable log" OFF) +option(CPU_EXTENSIONS_ENABLE_LOG "Enable log" ON) message(INFO "--------------------------------") message(STATUS "Build with tests: ${CPU_EXTENSIONS_BUILD_TESTS}") diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp index 3074cb1..f74e3e6 100644 --- a/include/llm_emb_gpt.hpp +++ b/include/llm_emb_gpt.hpp @@ -12,9 +12,12 @@ namespace llmdnn { -status_t emb_gpt(const tensor& q_src, // q shape: [batch, query_seq_len, head_num, head_size] - const tensor& k_src, // k shape: [batch, query_seq_len, head_num, head_size] - const tensor& v_src, // v shape: [batch, query_seq_len, head_num, head_size] +status_t emb_gpt(const tensor& q_src, // q shape: [batch, query_seq_len, head_num, head_size] or + // [batch, query_seq_len, num_kv_heads, head_num/num_kv_heads, head_size] + const tensor& k_src, // k shape: [batch, query_seq_len, head_num, head_size] or + // [batch, query_seq_len, num_kv_heads, 1, head_size] + const tensor& v_src, // v shape: [batch, query_seq_len, head_num, head_size] or + // [batch, query_seq_len, num_kv_heads, 1, head_size] const tensor& k_past, // k_past shape: [batch, num_heads, past_seq_len, head_size] const tensor& v_past, // v_past shape: [batch, num_heads, past_seq_len, head_size] const tensor& q_dst, // q_dst, shape: [batch, num_heads, query_seq_len, head_size] diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp index 2cd527d..1338494 100644 --- a/src/emb_gpt_avx512.cpp +++ b/src/emb_gpt_avx512.cpp @@ -86,6 +86,53 @@ static void rotary_emb_position2d(const tensor& q_src, }); } + +// q_src shape: [batch, q_seq_len, num_kv_heads, head_num/num_kv_heads, head_size] +// q_dst shape: [batch, head_hum, q_seq_len, head_size] +// kv_src shape: [batch, q_seq_len, num_kv_heads, 1, head_size] +// kv_past shape: [batch, head_hum, past_seq_len, head_size] +// kv_dst shape: [batch, head_hum, q_seq_len+past_seq_len, head_size] +// position2d_ids: [batch, 2, q_seq_len] +// cos/sin: [max_seq_len, rotary_dims] +static void rotary_emb_falcon(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin) { + auto batch = k_past.m_dims[0]; + auto head_num = k_past.m_dims[1]; + auto past_seq_len = k_past.m_dims[2]; + auto head_size = k_past.m_dims[3]; + auto query_seq_len = q_src.m_dims[1]; + auto rotary_ndim = cos.m_dims[3]; + auto num_kv_heads_in_group = q_src.m_dims[3]; + + parallel_for3d(batch, head_num, query_seq_len, [&](size_t b, size_t h, size_t s) { + auto kv_dst_s = s + past_seq_len; + auto cur_num_kv_heads = h / num_kv_heads_in_group; + auto cur_sub_head_num = h % num_kv_heads_in_group; + + // q, k rotary encoding + rotary_avx512(rotary_ndim, &cos.at({0, 0, s + past_seq_len}), &sin.at({0, 0, s + past_seq_len}), + &q_src.at({b, s, cur_num_kv_heads, cur_sub_head_num}), + &k_src.at({b, s, cur_num_kv_heads, 0}), + &q_dst.at({b, h, s}), + &k_dst.at({b, h, kv_dst_s})); + if (head_size > rotary_ndim) { + memcpy(&q_dst.at({b, h, s, rotary_ndim}), &q_src.at({b, s, h, rotary_ndim}), (head_size - rotary_ndim) * sizeof(ov::bfloat16)); + memcpy(&k_dst.at({b, h, kv_dst_s, rotary_ndim}), &k_src.at({b, s, h, rotary_ndim}), (head_size - rotary_ndim) * sizeof(ov::bfloat16)); + } + + // v concat + memcpy(&v_dst.at({b, h, kv_dst_s}), &v_src.at({b, s, cur_num_kv_heads, 0}), head_size * sizeof(ov::bfloat16)); + }); +} + status_t emb_gpt_avx512(const tensor& q_src, const tensor& k_src, const tensor& v_src, @@ -97,9 +144,10 @@ status_t emb_gpt_avx512(const tensor& q_src, const tensor& cos, const tensor& sin, const tensor& position2d_ids) { - if (q_src.m_rank != 4 || k_src.m_rank != 4 || v_src.m_rank != 4 || k_past.m_rank != 4 || v_past.m_rank != 4 || q_dst.m_rank != 4|| + if ((q_src.m_rank != 4 && q_src.m_rank != 5) || (k_src.m_rank != 4 && k_src.m_rank != 5) || (v_src.m_rank != 4 && v_src.m_rank != 5) || + k_past.m_rank != 4 || v_past.m_rank != 4 || q_dst.m_rank != 4 || k_dst.m_rank != 4 || v_dst.m_rank != 4 || cos.m_rank != 4 || sin.m_rank != 4) { - DEBUG_LOG << "emb_gpt_avx512: rank is not correct: should be 4\n"; + DEBUG_LOG << "emb_gpt_avx512: rank is not correct: should be 4/5\n"; return status_t::status_invalid_arguments; } if (position2d_ids) { @@ -131,7 +179,10 @@ status_t emb_gpt_avx512(const tensor& q_src, // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) // key pass part(past_key): key = torch.cat((key, key_pass), dim=-1) // value(pastKeys): value = torch.cat((past_value, value), dim=-2) - rotary_emb_position2d(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); + if (q_src.m_rank == 4) + rotary_emb_position2d(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); + else + rotary_emb_falcon(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin); } return status_t::status_ok; diff --git a/tests/script/README.md b/tests/script/README.md index b301c08..6ef1941 100644 --- a/tests/script/README.md +++ b/tests/script/README.md @@ -5,6 +5,7 @@ prepare python enviroment ``` python3 -m venv .env source .env/bin/activate +sudo apt install libnuma-dev pip3 install -r requirements.txt ``` diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp index 1576934..43acafc 100644 --- a/tests/script/ext/emb_gpt.cpp +++ b/tests/script/ext/emb_gpt.cpp @@ -56,7 +56,7 @@ void regclass_emb_gpt(pybind11::module m) { sin_.resize({sin.size(0), sin.size(1), sin.size(2), sin.size(3)}, sin.data_ptr()); if (position2d_ids.numel()) position2d_ids_.resize({batch, 2, query_seq_len}, position2d_ids.data_ptr()); - + llmdnn::emb_gpt(q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_); return std::make_tuple(q_dst, k_dst, v_dst); @@ -72,6 +72,64 @@ void regclass_emb_gpt(pybind11::module m) { R"( exec emb + :param num_heads: heads number. + :type num_heads: int + )"); + m.def("emb_gpt", [] ( + const torch::Tensor& qkv, + int num_kv_heads, + const torch::Tensor& k_past, + const torch::Tensor& v_past, + const torch::Tensor& cos, + const torch::Tensor& sin) { + // qkv: [batch, seq_len, (head_num + num_kv_heads * 2) * head_size] + // k_past: [batch, head_num, past_seq_len, head_size] + // q_dst: [batch, head_num, query_seq_len, head_size] + // k_dst: [batch, head_num, query_seq_len+past_seq_len, head_size] + // cos: [max_seq_len, rotary_dims] + AT_ASSERT(qkv.dim() == 3 && k_past.dim() == 4 && v_past.dim() == 4); + auto batch = qkv.size(0); + auto query_seq_len = qkv.size(1); + auto head_num = k_past.size(1); + auto head_size = k_past.size(3); + auto past_seq_len = k_past.size(2); + auto kv_seq_len = query_seq_len + past_seq_len; + AT_ASSERT(qkv.size(2) / head_size - 2 * num_kv_heads == head_num); + + torch::Tensor q_dst = qkv.new_empty({batch, head_num, query_seq_len, head_size}); + torch::Tensor k_dst = qkv.new_empty({batch, head_num, kv_seq_len, head_size}); + torch::Tensor v_dst = qkv.new_empty({batch, head_num, kv_seq_len, head_size}); + // q, k, v will be [batch, seq_len, num_kv_heads, head_num/num_kv_heads|1, head_size] + llmdnn::tensor q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_; + q_.resize({batch, query_seq_len, num_kv_heads, qkv.size(2) / head_size / num_kv_heads, head_size}, reinterpret_cast(qkv.data_ptr()) + head_size * 0); + q_.m_dims[3] = head_num / num_kv_heads; + k_.resize({batch, query_seq_len, num_kv_heads, qkv.size(2) / head_size / num_kv_heads, head_size}, reinterpret_cast(qkv.data_ptr()) + head_size * q_.m_dims[3]); + k_.m_dims[3] = 1; + v_.resize({batch, query_seq_len, num_kv_heads, qkv.size(2) / head_size / num_kv_heads, head_size}, reinterpret_cast(qkv.data_ptr()) + head_size * (q_.m_dims[3] + 1)); + v_.m_dims[3] = 1; + k_past_.resize({batch, head_num, past_seq_len, head_size}, reinterpret_cast(k_past.data_ptr())); + v_past_.resize({batch, head_num, past_seq_len, head_size}, reinterpret_cast(v_past.data_ptr())); + q_dst_.resize({batch, head_num, query_seq_len, head_size}, reinterpret_cast(q_dst.data_ptr())); + k_dst_.resize({batch, head_num, kv_seq_len, head_size}, reinterpret_cast(k_dst.data_ptr())); + v_dst_.resize({batch, head_num, kv_seq_len, head_size}, reinterpret_cast(v_dst.data_ptr())); + cos_.resize({cos.size(0), cos.size(1), cos.size(2), cos.size(3)}, cos.data_ptr()); + sin_.resize({sin.size(0), sin.size(1), sin.size(2), sin.size(3)}, sin.data_ptr()); + + llmdnn::emb_gpt(q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_); + + return std::make_tuple(q_dst, k_dst, v_dst); + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + }, + py::arg("qkv"), + py::arg("num_kv_heads"), + py::arg("k_past"), + py::arg("v_past"), + py::arg("cos"), + py::arg("sin"), + R"( + exec emb + :param num_heads: heads number. :type num_heads: int )"); diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py index 1087a1e..133ac56 100644 --- a/tests/script/ext/setup.py +++ b/tests/script/ext/setup.py @@ -42,6 +42,7 @@ library_dirs=[f'{sys.prefix}/lib', cpu_extensions_lib_dir], libraries=['cpu_extensions', + 'numa', 'stdc++']), ], cmdclass={'build_ext': cpp_extension.BuildExtension} diff --git a/tests/script/test_rotary_pastkv_falcon.py b/tests/script/test_rotary_pastkv_falcon.py new file mode 100644 index 0000000..b9f951e --- /dev/null +++ b/tests/script/test_rotary_pastkv_falcon.py @@ -0,0 +1,262 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +# copy from transformers/models/falcon/modeling_falcon.py +# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class FalconRotaryEmbedding(nn.Module): + """Implementation of RotaryEmbedding from GPT-NeoX. + This implementation is designed to operate on queries and keys that are compatible with `[batch_size, + n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format). + """ + + def __init__(self, head_dim: int, base=10000): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.head_dim = head_dim + self.seq_len_cached = -1 + self.cos_cached: torch.Tensor | None = None + self.sin_cached: torch.Tensor | None = None + + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: + total_length = seq_len + past_key_values_length + if total_length > self.seq_len_cached: + self.seq_len_cached = total_length + t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, :] + self.sin_cached = emb.sin()[None, :, :] + + # self.cos_cached = self.cos_cached.type(dtype) + # self.sin_cached = self.sin_cached.type(dtype) + + return ( + self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length], + self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length], + ) + + def forward(self, query, key, past_key_values_length=0): + batch, seq_len, head_dim = query.shape + cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) + return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) + +class FalconAttention(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads, num_kv_heads, new_decoder_architecture=True): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." + ) + + self.maybe_rotary = FalconRotaryEmbedding(self.head_dim) #if config.rotary else lambda q, k, t: (q, k) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = self.inv_norm_factor + if new_decoder_architecture: + qkv_out_dim = (num_kv_heads * 2 + num_attention_heads) * self.head_dim + elif config.multi_query: + qkv_out_dim = self.hidden_size + 2 * self.head_dim + else: + qkv_out_dim = 3 * self.hidden_size + self.new_decoder_architecture = new_decoder_architecture + self.num_kv_heads = num_kv_heads #if (self.new_decoder_architecture or not self.multi_query) else 1 + + def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + if self.new_decoder_architecture: + batch, seq_len, _ = fused_qkv.shape + qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim) + query = qkv[:, :, :, :-2] + key = qkv[:, :, :, [-2]] + value = qkv[:, :, :, [-1]] + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) + + query, key, value = [x.flatten(2, 3) for x in (query, key, value)] + return query, key, value + elif not self.multi_query: + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + else: + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) + return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] + + def forward( + self, + fused_qkv: torch.Tensor, # [batch_size, seq_length, 9216] + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape( + batch_size * num_kv_heads, + query_length, + self.head_dim, + ) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + + layer_past = [item.view(item.size(0) * item.size(1), item.size(2), item.size(3)) for item in layer_past] + past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, kv_length, head_dim] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, kv_length, _ = key_layer.shape + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) + key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + + return query_layer_, key_layer_, value_layer_ + + +class FalconAttentionExt: + def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, rotary_ndims, rotary_emb_base=10000): + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + #inv_freq = inv_freq.half() + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + # qkv: [batch, seq_len, ((num_heads + num_kv_heads * 2) * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # return: + # 0: (k, v): ([batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned], [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned]) + # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] + # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, num_kv_heads, k_past, v_past): + return ld.emb_gpt(qkv, num_kv_heads, k_past, v_past, self.cos_cached, self.sin_cached) + + +HEAD_NUM = 128 +SIZE_PER_HEAD = 64 +SIZE_PER_HEAD_ALIGN = 64 +NUM_KV_HEADS = 8 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.5 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + ref_net = FalconAttention(hidden_size=HIDDEN_SIZE, num_attention_heads=HEAD_NUM, num_kv_heads=NUM_KV_HEADS, new_decoder_architecture=True) + ref_net.maybe_rotary.cos_sin(0, MAX_SEQ_LEN) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_falcon(): + inputs = [ + # qkv: [batch, seq_len, (num_heads + 2 * num_kv_heads) * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + (np.random.random(size=[2, 200, (HEAD_NUM + 2 * NUM_KV_HEADS) * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32)), + (np.random.random(size=[2, 1, (HEAD_NUM + 2 * NUM_KV_HEADS) * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32)), + ] + ref_net = get_ref_model() + net_seq = FalconAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, SIZE_PER_HEAD, ROTARY_EMB_BASE) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + + query_ref, key_ref, value_ref = ref_net.forward(qkv, (layer_past_key, layer_past_value)) + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + + # no prealloc past kv + query, key, value = net_seq.forward(qkv, NUM_KV_HEADS, layer_past_key, layer_past_value) + # check query + if not torch.allclose(query_ref, query, rtol=0.001, atol=0.01): + print(f"error at sequence query index {i} ref:\n{query_ref} \ncur:\n {query} ") + assert(False) + # check key + if not torch.allclose(key_ref, key, rtol=0.001, atol=0.01): + print(f"error at sequence key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value, rtol=0.001, atol=0.01): + print(f"error at sequence value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_falcon() \ No newline at end of file From 3a7432c5b2f69d0fe910ab917acabbf1b428ea2d Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 17 Aug 2023 23:18:27 +0800 Subject: [PATCH 47/54] fix int8 compress --- include/llm_fc.hpp | 6 +- src/fc_kernel_amx.cpp | 25 +++-- src/fc_kernel_amx.hpp | 2 +- src/fc_kernel_api.cpp | 4 +- tests/src/test_fc_kernel_amx.cpp | 151 ++++++++++++++++--------------- 5 files changed, 101 insertions(+), 87 deletions(-) diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp index 5b95415..0cd08e0 100644 --- a/include/llm_fc.hpp +++ b/include/llm_fc.hpp @@ -55,8 +55,6 @@ struct fc_kernel; /// fc: (s8,s8,s8),dq,[bias],[gelu],q /// fc: (s8,s8,bf16),dq,[bias],[gelu] /// fc: (s8,s8,f32),dq,[bias],[gelu] -/// fc: (bf16,f32,bf16),[bias],[gelu] -/// fc: (bf16,f32,f32),[bias],[gelu] /// fc: (bf16,bf16,bf16),[bias],[gelu] /// fc: (bf16,bf16,f32),[bias],[gelu] /// fc: (bf16,s8,f32),dq,[bias],[gelu] @@ -64,7 +62,9 @@ struct fc_kernel; /// status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param); void fc_kernel_destroy(fc_kernel* mm); -void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); +// when fc_create_param.dt_b==bf16, dt_b is in [bf16, f32] +// when fc_create_param.dt_b==s8, dt_b is in [bf16, f32] +void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); void fc_kernel_execute(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index 460d98e..d325573 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -40,8 +40,6 @@ static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b { { llmdnn_s8, llmdnn_s8, llmdnn_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, { { llmdnn_s8, llmdnn_s8, llmdnn_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, { { llmdnn_s8, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, - { { llmdnn_bf16, llmdnn_f32, llmdnn_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, - { { llmdnn_bf16, llmdnn_f32, llmdnn_f32 }, { 0, BIAS | GELU | GELU_TANH } }, { { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, { { llmdnn_bf16, llmdnn_bf16, llmdnn_f32 }, { 0, BIAS | GELU | GELU_TANH } }, { { llmdnn_bf16, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, @@ -116,7 +114,7 @@ void fc_kernel_destroy_amx(fc_kernel* mm) { } } -void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { +void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { mm->stride_b = stride_b; size_t b_d0 = K, b_d1 = N; if (mm->b_is_transpose) { @@ -132,7 +130,7 @@ void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, size_t N, size_t K, s auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->u8xi8->internalB, true); } else if (mm->bf16xbf16) { - if (mm->dt_b == llmdnn_bf16) { + if (dt_b == llmdnn_bf16) { tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); @@ -142,13 +140,22 @@ void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, size_t N, size_t K, s amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); } } else { - tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); - auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); tensor2D internalTmpB; - amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); + if (dt_b == llmdnn_bf16) { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); + } else { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); + } + if (mm->bf16xi8->dequant_scale_B == 0) { + fc_kernel_bf16w8_get_q_dq_amx(internalTmpB.dims[0], internalTmpB.dims[1], internalTmpB.stride, internalTmpB.data, + &mm->bf16xi8->quant_scale_B, &mm->bf16xi8->dequant_scale_B); + } amx_kernel::functional::bf16_to_i8_tensor(mm->bf16xi8->internalBI8, internalTmpB, mm->bf16xi8->quant_scale_B); } - } void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, @@ -325,7 +332,7 @@ void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t strid } } else { tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); - tensor2D b(N, K, nullptr, mm->stride_b); + tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); if (mm->dt_c == llmdnn_bf16) { tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); diff --git a/src/fc_kernel_amx.hpp b/src/fc_kernel_amx.hpp index f6f2011..bbf0727 100644 --- a/src/fc_kernel_amx.hpp +++ b/src/fc_kernel_amx.hpp @@ -10,7 +10,7 @@ status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param); void fc_kernel_destroy_amx(fc_kernel* mm); -void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); +void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias); diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp index 94ffbba..3e9910a 100644 --- a/src/fc_kernel_api.cpp +++ b/src/fc_kernel_api.cpp @@ -36,8 +36,8 @@ void fc_kernel_destroy(fc_kernel* mm) { fc_kernel_destroy_ptr(mm); } -void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { - fc_kernel_pack_weight_ptr(mm, ptr_b, N, K, stride_b, n_start, n_end); +void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + fc_kernel_pack_weight_ptr(mm, ptr_b, dt_b, N, K, stride_b, n_start, n_end); } void fc_kernel_execute(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp index 3ee6be1..18a976d 100644 --- a/tests/src/test_fc_kernel_amx.cpp +++ b/tests/src/test_fc_kernel_amx.cpp @@ -22,7 +22,11 @@ using ::testing::Values; using ::testing::ValuesIn; using FCKernelTestShape = std::tuple; -using FCKernelTestDTPost = std::tuple; +using FCKernelTestDTPost = std::tuple; using FCKernelTestParamSet = std::tuple< FCKernelTestDTPost, // a, b, c data type, postops bool, // b needs transpose @@ -35,16 +39,17 @@ class FCKernelTest : public TestWithParam { FCKernelTestDTPost types; bool is_transpose; postops_types postops_type; - data_type_t dt_a, dt_b, dt_c; + data_type_t dt_a, dt_b, dt_c, dt_weight; FCKernelTestShape shape; int M, N, K; std::tie(types, is_transpose, shape) = obj.param; std::tie(M, N, K) = shape; - std::tie(dt_a, dt_b, dt_c, postops_type) = types; + std::tie(dt_a, dt_b, dt_c, dt_weight, postops_type) = types; std::ostringstream result; result << "A_" << dtype_to_str(dt_a) << "_B_" << dtype_to_str(dt_b) - << "_C_" << dtype_to_str(dt_c) << (is_transpose ? "_transpose" : "") + << "_C_" << dtype_to_str(dt_c) << "_WEIGHT_" << dtype_to_str(dt_weight) + << (is_transpose ? "_transpose" : "") << "_postops_" << postops_type << "_M_" << M << "_N_" << N << "_K_" << K; return result.str(); } @@ -57,7 +62,7 @@ class FCKernelTest : public TestWithParam { FCKernelTestDTPost types; std::tie(types, _is_transpose, shape) = GetParam(); std::tie(_M, _N, _K) = shape; - std::tie(_dt_a, _dt_b, _dt_c, _postops_type) = types; + std::tie(_dt_a, _dt_b, _dt_c, _dt_weight, _postops_type) = types; }; template @@ -95,7 +100,7 @@ class FCKernelTest : public TestWithParam { ptr_B = B.data; ldb = B.stride; } - fc_kernel_pack_weight(gemm.get(), ptr_B, _N, _K, ldb, 0, _N); + fc_kernel_pack_weight(gemm.get(), ptr_B, _dt_weight, _N, _K, ldb, 0, _N); fc_kernel_execute(gemm.get(), A.data, C.data, A.stride, C.stride, _M, _N, _K, 0, _N, dq.data, q.data, bias.data); C_Ref = 0; @@ -103,7 +108,7 @@ class FCKernelTest : public TestWithParam { float* ptr_q = nullptr; float* ptr_bias = nullptr; func_act act = func_act(); - if (_postops_type & DEQUANT) { + if ((_postops_type & DEQUANT) && _dt_a == llmdnn::llmdnn_s8) { ptr_dq = dq.data; } if (_postops_type & QUANT) { @@ -135,28 +140,24 @@ class FCKernelTest : public TestWithParam { int _M, _N, _K; bool _is_transpose; postops_types _postops_type; - data_type_t _dt_a, _dt_b, _dt_c; + data_type_t _dt_a, _dt_b, _dt_c, _dt_weight; }; TEST_P(FCKernelTest, Func) { - if (_dt_a == llmdnn_s8 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_s8) { + if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_s8) { do_test(); - } else if (_dt_a == llmdnn_s8 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_bf16) { + } else if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_bf16) { do_test(); - } else if (_dt_a == llmdnn_s8 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_f32) { + } else if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_f32) { do_test(); - } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_bf16 && _dt_c == llmdnn_bf16) { + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_bf16 && _dt_c == llmdnn_bf16) { do_test(); - } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_bf16 && _dt_c == llmdnn_f32) { + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_bf16 && _dt_c == llmdnn_f32) { do_test(); - } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_f32 && _dt_c == llmdnn_bf16) { + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_bf16) { do_test(); - } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_f32 && _dt_c == llmdnn_f32) { + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_f32) { do_test(); - } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_f32) { - do_test(); - } else if (_dt_a == llmdnn_bf16 && _dt_b == llmdnn_s8 && _dt_c == llmdnn_bf16) { - do_test(); } else { ASSERT_TRUE(false); } @@ -168,62 +169,68 @@ TEST_P(FCKernelTest, Func) { // (s8,s8,f32),dq,[bias],[gelu] // (bf16,bf16,bf16),[bias],[gelu] // (bf16,bf16,f32),[bias],[gelu] -// (bf16,f32,bf16),[bias],[gelu] -// (bf16,f32,f32),[bias],[gelu] // (bf16,s8,f32),dq,[bias],[gelu] // (bf16,s8,bf16),dq,[bias],[gelu] const std::vector types = { - { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_QUANT }, - { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_QUANT }, - { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_QUANT }, - { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_QUANT }, - { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_TANH_QUANT }, - { llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, - { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT }, - { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS }, - { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_GELU }, - { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS_GELU }, - { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_GELU_TANH }, - { llmdnn_s8, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS_GELU_TANH }, - { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT }, - { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS }, - { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_GELU }, - { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS_GELU }, - { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_GELU_TANH }, - { llmdnn_s8, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS_GELU_TANH }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, NONE }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU_TANH }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU_TANH }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, NONE }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU_TANH }, - { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU_TANH }, - { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, NONE }, - { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS }, - { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU }, - { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU }, - { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU_TANH }, - { llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU_TANH }, - { llmdnn_bf16, llmdnn_f32, llmdnn_f32, NONE }, - { llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS }, - { llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU }, - { llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU }, - { llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU_TANH }, - { llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU_TANH }, - // TODO: support weight compression - // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT }, - // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS }, - // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT_GELU }, - // { llmdnn_bf16, llmdnn_s8, llmdnn_f32, DEQUANT_BIAS_GELU }, - // { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, DEQUANT }, - // { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS }, - // { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, DEQUANT_GELU }, - // { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU_TANH }, + // weight compression + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_BIAS_GELU }, }; // M, N, K From 4a4e48c40a0895d5f3cb0a26cc6ffa4257166c95 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 18 Aug 2023 18:35:45 +0800 Subject: [PATCH 48/54] cache temp mem when K<32 --- src/common/tensor2d.hpp | 43 +++++++++++++++++++++++------------- src/mm_kernel_common_amx.hpp | 34 ++++++++++++++++------------ 2 files changed, 48 insertions(+), 29 deletions(-) diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index 069c735..a45416f 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -23,12 +23,13 @@ struct tensor2D { int stride = 0; bool force_compact = false; bool own = false; + bool use_numa_alloc = false; int padded_dim1 = 0; tensor2D() = default; tensor2D(const tensor2D&) = delete; ~tensor2D() { - if (own && data) llmdnn_free(data, capacity); + if (own && data) llmdnn_free(data, capacity, use_numa_alloc); } operator bool() { @@ -50,8 +51,8 @@ struct tensor2D { padded_dim1 = stride / sizeof(T); } - tensor2D Tr(bool _force_compact = false) { - tensor2D ret(dims[1], dims[0], _force_compact); + tensor2D Tr(bool _force_compact = false) { + tensor2D ret(dims[1], dims[0], _force_compact); for(int c0=0; c0 < dims[0]; ++c0) { for(int c1=0; c1 < dims[1]; ++c1) { ret(c1, c0) = (*this)(c0, c1); @@ -59,8 +60,8 @@ struct tensor2D { } return ret; } - tensor2D clone() { - tensor2D ret; + tensor2D clone() { + tensor2D ret; ret.resize(dims[0], dims[1], force_compact); if (ret.stride == stride) { memcpy(ret.data, data, dims[0] * stride); @@ -71,8 +72,8 @@ struct tensor2D { } return ret; } - tensor2D clone_with_padzero(int dim0, int dim1) { - tensor2D ret; + tensor2D clone_with_padzero(int dim0, int dim1) { + tensor2D ret; ret.resize(dim0, dim1, force_compact); assert(dim0 >= dims[0] && dim1 >= dims[1]); for(int i = 0; i < dims[0]; i++) { @@ -85,7 +86,18 @@ struct tensor2D { return ret; } - void resize(int d0, int d1, bool _force_compact = false, bool is_const=false) { + void copyto_with_padzero(tensor2D& dst, int dim0, int dim1) { + dst.resize(dim0, dim1, force_compact); + assert(dim0 >= dims[0] && dim1 >= dims[1]); + for(int i = 0; i < dims[0]; i++) { + memcpy(&dst(i, 0), &(*this)(i, 0), dims[1] * sizeof(T)); + memset(reinterpret_cast(&dst(i, 0) + dims[1]), 0, dst.stride - dims[1] * sizeof(T)); + } + if (dims[1] == dim1) { + memset(reinterpret_cast(dst.data + dims[0] * dst.padded_dim1), 0, (dim0 - dims[0]) * dst.stride); + } + } + void resize(int d0, int d1, bool _force_compact = false, bool is_const = false) { own = true; force_compact = _force_compact; dims[0] = d0; @@ -104,8 +116,9 @@ struct tensor2D { need_capacity *= 2; // align begin address to cache line is vital, so tile load can // use all bandwidth (L1D/L2 only deliver data in unit of 64-byte aligned cache-line) - if (data) llmdnn_free(data, capacity); - data = reinterpret_cast(llmdnn_alloc(64, need_capacity)); + if (data) llmdnn_free(data, capacity, use_numa_alloc); + use_numa_alloc = is_const; + data = reinterpret_cast(llmdnn_alloc(64, need_capacity, use_numa_alloc)); capacity = need_capacity; if (is_const) memset(static_cast(data), 0, need_capacity); @@ -144,13 +157,13 @@ struct tensor2D { (*this)[k] = v; } - tensor2D& operator=(const tensor2D & t2) = delete; + tensor2D& operator=(const tensor2D& t2) = delete; // move semantics - tensor2D(tensor2D && t2) { + tensor2D(tensor2D && t2) { dims[0] = t2.dims[0]; dims[1] = t2.dims[1]; - if (own && data) ::free(data); + if (own && data) llmdnn_free(data, capacity, use_numa_alloc); data = t2.data; own = t2.own; capacity = t2.capacity; @@ -161,10 +174,10 @@ struct tensor2D { t2.data = nullptr; } - tensor2D& operator=(tensor2D && t2) { + tensor2D& operator=(tensor2D && t2) { dims[0] = t2.dims[0]; dims[1] = t2.dims[1]; - if (own && data) ::free(data); + if (own && data) llmdnn_free(data, capacity, use_numa_alloc); own = t2.own; data = t2.data; capacity = t2.capacity; diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index 0f443cb..df16df9 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -1943,6 +1943,8 @@ struct Matmul { // B matrix is orgnized as tensor2D of shape axb where a=round_up_div(N, 32), b=round_up(K,32/64)*32 // so b is size of submatrix of Kx32 composed of two columns of B0/B1 tiles. tensor2D internalB; + tensor2D A_PaddedK; // pad to 32(bf16)/64(int8) buffer + tensor2D B_PaddedK; // pad to 32(bf16)/64(int8) buffer bool constB; bool transposeB; @@ -2084,6 +2086,8 @@ struct Matmul { alignas(64) TC buff[32 * 32]; tensor2D buffC(32, 32, buff, 32 * sizeof(TC)); + tensor2D* pA = &matA; + tensor2D* pB = &_matB; if (K < kStep) { int B0, B1; if (transposeB) { @@ -2093,11 +2097,13 @@ struct Matmul { B0 = kStep; B1 = _matB.dims[1]; } - matA = matA.clone_with_padzero(M, kStep); - _matB = _matB.clone_with_padzero(B0, B1); + matA.copyto_with_padzero(A_PaddedK, M, kStep); + pA = &A_PaddedK; + _matB.copyto_with_padzero(B_PaddedK, B0, B1); + pB = &B_PaddedK; K = kStep; } - auto matB = getSubMatB(_matB, n0, n1, transposeB); + auto matB = getSubMatB(*pB, n0, n1, transposeB); int N = matB.dims[transposeB ? 0 : 1]; assert(K == matB.dims[transposeB ? 1 : 0]); // Due to the fact that we load a full tile at tails of K dimension @@ -2125,12 +2131,12 @@ struct Matmul { auto * pB0 = reinterpret_cast(&internalB[0]); tileconfig_t tfg(1, 0, 8, 16, 64); switch((K + kStep - 1)/kStep) { - case 1: kernel_slimB<1>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 2: kernel_slimB<2>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 3: kernel_slimB<3>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 4: kernel_slimB<4>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 5: kernel_slimB<5>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; - case 6: kernel_slimB<6>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 1: kernel_slimB<1>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 2: kernel_slimB<2>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 3: kernel_slimB<3>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 4: kernel_slimB<4>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 5: kernel_slimB<5>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 6: kernel_slimB<6>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; default: assert(false); // impossible since (K <= 6*kStep) } @@ -2146,11 +2152,11 @@ struct Matmul { auto * pB0 = reinterpret_cast(&internalB[0]); auto * const pC0 = &buffC[0]; int k; - const auto strideA = matA.stride; + const auto strideA = (*pA).stride; loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { _tile_zero(0); _tile_zero(1); - int8_t * pA0 = reinterpret_cast(&matA[0]); + int8_t * pA0 = reinterpret_cast(&(*pA)[0]); for(k=0; k(&matA(m, 0)); - auto * pA1 = reinterpret_cast(&matA(m + 16, 0)); - auto strideA = matA.stride; + auto * pA0 = reinterpret_cast(&(*pA)(m, 0)); + auto * pA1 = reinterpret_cast(&(*pA)(m + 16, 0)); + auto strideA = (*pA).stride; auto * pB = reinterpret_cast(&internalB(n>>5, 0)); _tile_zero(0); _tile_zero(1); From c29681cb39f15124b312cd3ae2d4e5fc707d7b92 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Wed, 30 Aug 2023 02:39:20 +0800 Subject: [PATCH 49/54] add numa as the first level task partition basis --- include/llm_fc.hpp | 28 +++- src/common/memory_alloc.cpp | 68 +++++++- src/common/memory_alloc.hpp | 6 + src/common/tensor2d.hpp | 2 +- src/fc_amx.cpp | 204 +++++++++++++++++++++++ src/fc_amx.hpp | 17 ++ src/fc_api.cpp | 32 ++++ src/fc_kernel_amx.cpp | 71 +++++++- src/fc_kernel_amx.hpp | 4 +- src/fc_kernel_api.cpp | 12 +- tests/src/test_fc_amx.cpp | 267 +++++++++++++++++++++++++++++++ tests/src/test_fc_kernel_amx.cpp | 2 +- 12 files changed, 696 insertions(+), 17 deletions(-) create mode 100644 src/fc_amx.cpp create mode 100644 src/fc_amx.hpp create mode 100644 src/fc_api.cpp create mode 100644 tests/src/test_fc_amx.cpp diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp index 0cd08e0..6cf401b 100644 --- a/include/llm_fc.hpp +++ b/include/llm_fc.hpp @@ -5,6 +5,7 @@ #pragma once #include "llm_types.hpp" +#include "llm_tensor.hpp" namespace llmdnn { @@ -65,13 +66,32 @@ void fc_kernel_destroy(fc_kernel* mm); // when fc_create_param.dt_b==bf16, dt_b is in [bf16, f32] // when fc_create_param.dt_b==s8, dt_b is in [bf16, f32] void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); +void fc_kernel_pack_weight_to_dst(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); +// ptr_b may be null if using fc_kernel_pack_weight to pack into internal buffer +// if ptr_b is not null, its layout is [N/32, 32*rndup(K,32|64)] void fc_kernel_execute(fc_kernel* mm, - void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, + void* ptr_a, void* ptr_b, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq=nullptr, float* q=nullptr, float* bias=nullptr); -/// weight compression -/// compute weight min/max once, set q, dq for each fc_kernel instance -void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); +/// Generates a fc based on param +class fc { +public: + fc(); + ~fc(); + + bool init(const fc_create_param& param); + void pack_weight(const tensor& w); + status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias); + + struct impl { + virtual ~impl() {} + virtual bool init(const fc_create_param& param) = 0; + virtual void pack_weight(const tensor& w) = 0; + virtual status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) = 0; + }; +protected: + impl* _impl; +}; } // namespace llmdnn diff --git a/src/common/memory_alloc.cpp b/src/common/memory_alloc.cpp index 955d30f..4fc74e0 100644 --- a/src/common/memory_alloc.cpp +++ b/src/common/memory_alloc.cpp @@ -8,9 +8,26 @@ #include #include #include "memory_alloc.hpp" +#include "common/simple_parallel.hpp" + +static bool llmdnn_use_numa() { + if (numa_available() == -1) + return false; + + static bool init = false; + static bool use_numa = true; + if (!init) { + init = true; + auto p = std::getenv("LLMDNN_USE_NUMA"); + if (p) { + use_numa = p[0] != '0'; + } + } + return use_numa; +} void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa) { - if (hint_numa && numa_available() != -1) { + if (hint_numa && llmdnn_use_numa()) { int cur_cpu = sched_getcpu(); auto cur_numa_node = numa_node_of_cpu(cur_cpu); return numa_alloc_onnode(size, cur_numa_node); @@ -20,7 +37,54 @@ void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa) { } void llmdnn_free(void* p, size_t size, bool hint_numa) { - if (hint_numa && numa_available() != -1) { + if (hint_numa && llmdnn_use_numa()) { + numa_free(p, size); + } else { + ::free(p); + } +} + +int llmdnn_get_numa_id_for_cur_task() { + if (llmdnn_use_numa()) { + int cur_cpu = sched_getcpu(); + return numa_node_of_cpu(cur_cpu); + } else { + return 0; + } +} + +llm_vector llmdnn_get_numa_nodes() { + llm_vector numa_nodes; + if (llmdnn_use_numa()) { + auto thread_nums = llmdnn::get_total_threads(); + llm_vector numa_nodes_list; + numa_nodes_list.resize(thread_nums); + llmdnn::parallel_for(thread_nums, [&] (size_t id) { + int cur_cpu = sched_getcpu(); + numa_nodes_list[id] = numa_node_of_cpu(cur_cpu); + }); + for (auto numa_node : numa_nodes_list) { + if (std::find(numa_nodes.begin(), numa_nodes.end(), numa_node) == numa_nodes.end()) { + numa_nodes.push_back(numa_node); + } + } + std::sort(numa_nodes.begin(), numa_nodes.end()); + } else { + numa_nodes.push_back(0); + } + return numa_nodes; +} + +void* llmdnn_alloc_on(size_t aligned_size, size_t size, int numa_id) { + if (llmdnn_use_numa()) { + return numa_alloc_onnode(size, static_cast(numa_id)); + } else { + return aligned_alloc(aligned_size, size); + } +} + +void llmdnn_free_on(void* p, size_t size) { + if (llmdnn_use_numa()) { numa_free(p, size); } else { ::free(p); diff --git a/src/common/memory_alloc.hpp b/src/common/memory_alloc.hpp index 530ddc9..a704add 100644 --- a/src/common/memory_alloc.hpp +++ b/src/common/memory_alloc.hpp @@ -5,6 +5,12 @@ #pragma once #include +#include "compatible.hpp" void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa = true); void llmdnn_free(void* p, size_t size, bool hint_numa = true); + +llm_vector llmdnn_get_numa_nodes(); +void* llmdnn_alloc_on(size_t aligned_size, size_t size, int numa_id); +void llmdnn_free_on(void* p, size_t size); +int llmdnn_get_numa_id_for_cur_task(); \ No newline at end of file diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index a45416f..c365bad 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -98,7 +98,6 @@ struct tensor2D { } } void resize(int d0, int d1, bool _force_compact = false, bool is_const = false) { - own = true; force_compact = _force_compact; dims[0] = d0; dims[1] = d1; @@ -112,6 +111,7 @@ struct tensor2D { // resize method never shrink capacity, and extra T is added to put nan as test auto need_capacity = dims[0] * stride + 4096; if (capacity < need_capacity) { + own = true; if (!is_const) need_capacity *= 2; // align begin address to cache line is vital, so tile load can diff --git a/src/fc_amx.cpp b/src/fc_amx.cpp new file mode 100644 index 0000000..b413937 --- /dev/null +++ b/src/fc_amx.cpp @@ -0,0 +1,204 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "common/log.hpp" +#include "common/simple_parallel.hpp" +#include "common/tensor2d.hpp" +#include "common/utility.hpp" +#include "common/compatible.hpp" +#include "common/memory_alloc.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" +#include "mm_kernel_common_amx.hpp" +#include "softmax_kernel_avx512.hpp" +#include "transpose_kernel_avx512.hpp" +#include "llm_fc.hpp" +#include "fc_amx.hpp" + +namespace llmdnn { + +struct fc_impl_amx : public fc::impl { + fc_impl_amx() = default; + ~fc_impl_amx(); + + bool init(const fc_create_param& param) override; + void pack_weight(const tensor& w) override; + status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) override; + void build_thread_infos(const llm_vector& numa_nodes); + + fc_create_param _create_param; + llm_vector _kernel; // one kernel for each numa node + llm_vector _weights; // one weight for each numa node + llm_vector _weight_sizes; // one weight size for each numa node + size_t _thread_nums; // thread numbers + size_t _N_in_one_numa; // N on each numa node + llm_vector _thread_nums_in_one_numa; // thread numbers in one numa node + size_t _K_align; + struct work_info { + int numa_id; // numa node id, use to index in _numa_nodes + size_t thread_no_in_one_numa; // sequence no in one numa node + }; + llm_vector _work_infos; // map thread id to numa node id and thread no in one numa node +}; + +fc_impl_amx::~fc_impl_amx() { + for (size_t i = 0; i < _kernel.size(); i++) { + fc_kernel_destroy(_kernel[i]); + } + for (size_t i = 0; i < _weight_sizes.size(); i++) { + llmdnn_free_on(_weights[i], _weight_sizes[i]); + } +} + +bool fc_impl_amx::init(const fc_create_param& param) { + _create_param = param; + _thread_nums = get_total_threads(); + _kernel.resize(_thread_nums); + std::atomic_bool ret{true}; + parallel_for(_thread_nums, [&] (size_t id) { + if (fc_kernel_create(&_kernel[id], ¶m) != llmdnn::status_ok) { + ret = false; + } + }); + + return ret; +} + +void fc_impl_amx::build_thread_infos(const llm_vector& numa_nodes) { + _work_infos.resize(_thread_nums); + struct int_atomic { + std::atomic_int v{0}; + }; + llm_vector thread_id_in_one_numa(numa_nodes.size()); + // the real numa id may not be continuous, but we need a number to index _numa_nodes + parallel_for(_thread_nums, [&] (size_t id) { + auto cur_numa_id = llmdnn_get_numa_id_for_cur_task(); + for (int i = 0; i < static_cast(numa_nodes.size()); i++) { + if (numa_nodes[i] == cur_numa_id) { + _work_infos[id].numa_id = i; + _work_infos[id].thread_no_in_one_numa = thread_id_in_one_numa[i].v.fetch_add(1); + break; + } + } + }); + + // check: the index is stable in another loop + std::mutex m; + parallel_for(_thread_nums, [&] (size_t id) { + auto cur_numa_id = llmdnn_get_numa_id_for_cur_task(); + for (int i = 0; i < static_cast(numa_nodes.size()); i++) { + if (numa_nodes[i] == cur_numa_id) { + if (_work_infos[id].numa_id != i) { + std::lock_guard l(m); + DEBUG_LOG << "index test fail: cur numa index of thread no " << id << " is " << i << ", prev index " << _work_infos[id].numa_id << "\n"; + } + break; + } + } + }); + + // check: each numa should have same thread numbers + _thread_nums_in_one_numa.resize(numa_nodes.size()); + int actual_threads = thread_id_in_one_numa[0].v; + _thread_nums_in_one_numa[0] = thread_id_in_one_numa[0].v; + for (size_t i = 1; i < thread_id_in_one_numa.size(); i++) { + if (thread_id_in_one_numa[0].v != thread_id_in_one_numa[i].v) { + DEBUG_LOG << "numa test fail: thread number of numa " << i << " is " << thread_id_in_one_numa[i].v << ", not equal to numa 0 thread numbers: " << thread_id_in_one_numa[0].v << "\n"; + } + actual_threads += thread_id_in_one_numa[i].v; + _thread_nums_in_one_numa[i] = thread_id_in_one_numa[i].v; + } + // check: actual threads number should equal to _thread_nums + if (static_cast(_thread_nums) != actual_threads) { + DEBUG_LOG << "thread number test fail: actual threads number: " << actual_threads << ", not equal to _thread_nums " << _thread_nums << "\n"; + } +} + +void fc_impl_amx::pack_weight(const tensor& w) { + auto N = w.m_dims[_create_param.b_is_trans ? 0 : 1]; + auto K = w.m_dims[_create_param.b_is_trans ? 1 : 0]; + // will allocate memory on different numa nodes: + // 1, get numa nodes number, allocate memory on each numa node + // 2, get cores number, compute each cores area and pack each area simultaneously + // for omp, GOMP_CPU_AFFINITY=0-95 numactl -C0-95 will bind threads to cores + auto numa_nodes = llmdnn_get_numa_nodes(); + auto numa_nodes_nums = numa_nodes.size(); + auto N_blocks = rndup(N, 32) / 32; + // NOTE: assuming memory/thread is evenly distributed across mutiple numas. Need to support unbalanced numa? + _N_in_one_numa = (N_blocks + numa_nodes_nums - 1) / numa_nodes_nums * 32; + if (_create_param.dt_b == data_type_t::llmdnn_bf16) { + _K_align = rndup(K, 32); + } else { + _K_align = rndup(K, 64); + } + _weights.resize(numa_nodes_nums); + _weight_sizes.resize(numa_nodes_nums); + // allocate memory + for (size_t i = 0; i < numa_nodes_nums; i++) { + auto size = _K_align * _N_in_one_numa * get_precision_size(_create_param.dt_b); + _weights[i] = reinterpret_cast(llmdnn_alloc_on(64, size + 4096, numa_nodes[i])); + _weight_sizes[i] = size + 4096; + memset(_weights[i] + size, 0, 4096); + } + build_thread_infos(numa_nodes); + auto work_amount_in_one_numa = _N_in_one_numa / 32; + parallel_for(_thread_nums, [&] (size_t id) { + auto numa_id = _work_infos[id].numa_id; + auto thread_no_in_one_numa = _work_infos[id].thread_no_in_one_numa; + size_t start, end; + splitter(work_amount_in_one_numa, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); + size_t n0_in_one_numa = start * 32; + size_t n1_in_one_numa = std::min(end * 32, _N_in_one_numa); + if (n0_in_one_numa >= _N_in_one_numa) return; + auto n0 = n0_in_one_numa + _N_in_one_numa * numa_id; + auto n1 = n1_in_one_numa + _N_in_one_numa * numa_id; + n1 = std::min(n1, N); + if (n0 >= n1) return; + + auto dst = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); + fc_kernel_pack_weight_to_dst(_kernel[id], w.data(), dst, w.m_dtype, N, K, w.stride(0), n0, n1); + }); +} + +status_t fc_impl_amx::exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) { + if (input.m_rank != 2 || output.m_rank != 2 || bias.m_rank != 2) { + DEBUG_LOG << "input,output,bias rank should be 2.\n"; + return status_t::status_invalid_arguments; + } + + auto M = input.size(0); + auto N = output.size(1); + auto K = input.size(1); + auto work_amount_in_one_numa = _N_in_one_numa / 32; + parallel_for(_thread_nums, [&](size_t id) { + auto numa_id = _work_infos[id].numa_id; + auto thread_no_in_one_numa = _work_infos[id].thread_no_in_one_numa; + size_t start, end; + splitter(work_amount_in_one_numa, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); + size_t n0_in_one_numa = start * 32; + size_t n1_in_one_numa = std::min(end * 32, _N_in_one_numa); + if (n0_in_one_numa >= _N_in_one_numa) return; + auto n0 = n0_in_one_numa + _N_in_one_numa * numa_id; + auto n1 = n1_in_one_numa + _N_in_one_numa * numa_id; + n1 = std::min(n1, N); + if (n0 >= n1) return; + + auto weight = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); + fc_kernel_execute(_kernel[id], input.data(), weight, output.data(), input.stride(0), + output.stride(0), M, N, K, n0, n1, dq.data(), q.data(), bias.data()); + }); + + return status_t::status_ok; +} + +fc::impl* new_fc_impl_amx() { + return new fc_impl_amx(); +} + +} // namespace llmdnn diff --git a/src/fc_amx.hpp b/src/fc_amx.hpp new file mode 100644 index 0000000..53b23ad --- /dev/null +++ b/src/fc_amx.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_fc.hpp" + +namespace llmdnn { + +fc::impl* new_fc_impl_amx(); + +} // namespace llmdnn diff --git a/src/fc_api.cpp b/src/fc_api.cpp new file mode 100644 index 0000000..3abdf96 --- /dev/null +++ b/src/fc_api.cpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "fc_amx.hpp" + +namespace llmdnn { + +// interface +fc::fc(): _impl(new_fc_impl_amx()) { +} + +fc::~fc() { + delete _impl; +} + +bool fc::init(const fc_create_param& param) { + return _impl->init(param); +} + +void fc::pack_weight(const tensor& w) { + return _impl->pack_weight(w); +} + +status_t fc::exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) { + return _impl->exec(input, output, dq, q, bias); +} + +} // namespace llmdnn diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index d325573..685a36c 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -1,6 +1,8 @@ // Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include +#include #include #include #include @@ -158,7 +160,60 @@ void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, data_type_t dt_b, siz } } -void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, +void fc_kernel_pack_weight_to_dst_amx(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + mm->stride_b = stride_b; + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8) { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + // do not care about the real dimension, only ensure .capacity big enough + mm->i8xi8->internalB = tensor2D(1, 1, static_cast(dst_b), 1); + mm->i8xi8->internalB.capacity = INT_MAX; + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->i8xi8->internalB, true); + } else if (mm->u8xi8) { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + mm->u8xi8->internalB = tensor2D(1, 1, static_cast(dst_b), 1); + mm->u8xi8->internalB.capacity = INT_MAX; + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->u8xi8->internalB, true); + } else if (mm->bf16xbf16) { + mm->bf16xbf16->internalB = tensor2D(1, 1, static_cast(dst_b), 1); + mm->bf16xbf16->internalB.capacity = INT_MAX; + if (dt_b == llmdnn_bf16) { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + } else { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + } + } else { + tensor2D internalTmpB; + mm->bf16xi8->internalBI8 = tensor2D(1, 1, static_cast(dst_b), 1); + mm->bf16xi8->internalBI8.capacity = INT_MAX; + if (dt_b == llmdnn_bf16) { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); + } else { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); + } + if (mm->bf16xi8->dequant_scale_B == 0) { + fc_kernel_bf16w8_get_q_dq_amx(internalTmpB.dims[0], internalTmpB.dims[1], internalTmpB.stride, internalTmpB.data, + &mm->bf16xi8->quant_scale_B, &mm->bf16xi8->dequant_scale_B); + } + amx_kernel::functional::bf16_to_i8_tensor(mm->bf16xi8->internalBI8, internalTmpB, mm->bf16xi8->quant_scale_B); + } +} + +void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { size_t b_d0 = K, b_d1 = N; if (mm->b_is_transpose) { @@ -168,6 +223,10 @@ void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t strid if (mm->i8xi8) { tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); + if (ptr_b) { + auto K_padded = rndup(K, 64); + mm->i8xi8->internalB = tensor2D(N / 32, 32 * K_padded, static_cast(ptr_b), 32 * K_padded); + } if (mm->dt_c == llmdnn_s8) { tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); @@ -278,7 +337,10 @@ void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t strid } else if (mm->bf16xbf16) { tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); - + if (ptr_b) { + auto K_padded = rndup(K, 32); + mm->bf16xbf16->internalB = tensor2D(N / 32, 32 * K_padded, static_cast(ptr_b), 32 * K_padded * sizeof(bfloat16)); + } if (mm->dt_c == llmdnn_bf16) { tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!(mm->postops_type & BIAS)) { @@ -334,6 +396,11 @@ void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t strid tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); + if (ptr_b) { + auto K_padded = rndup(K, 64); + mm->bf16xi8->internalBI8 = tensor2D(N / 32, 32 * K_padded, static_cast(ptr_b), 32 * K_padded); + } + if (mm->dt_c == llmdnn_bf16) { tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); if (!(mm->postops_type & BIAS)) { diff --git a/src/fc_kernel_amx.hpp b/src/fc_kernel_amx.hpp index bbf0727..f293f0c 100644 --- a/src/fc_kernel_amx.hpp +++ b/src/fc_kernel_amx.hpp @@ -12,7 +12,9 @@ void fc_kernel_destroy_amx(fc_kernel* mm); void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); -void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, +void fc_kernel_pack_weight_to_dst_amx(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); + +void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t stride_a, size_t stride_c, size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias); void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp index 3e9910a..956b1a8 100644 --- a/src/fc_kernel_api.cpp +++ b/src/fc_kernel_api.cpp @@ -24,8 +24,8 @@ namespace llmdnn { static decltype(&fc_kernel_create) fc_kernel_create_ptr = fc_kernel_create_amx; static decltype(&fc_kernel_destroy) fc_kernel_destroy_ptr = fc_kernel_destroy_amx; static decltype(&fc_kernel_pack_weight) fc_kernel_pack_weight_ptr = fc_kernel_pack_weight_amx; +static decltype(&fc_kernel_pack_weight_to_dst) fc_kernel_pack_weight_to_dst_ptr = fc_kernel_pack_weight_to_dst_amx; static decltype(&fc_kernel_execute) fc_kernel_execute_ptr = fc_kernel_execute_amx; -static decltype(&fc_kernel_bf16w8_get_q_dq) fc_kernel_bf16w8_get_q_dq_ptr = fc_kernel_bf16w8_get_q_dq_amx; // interface status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { @@ -40,13 +40,13 @@ void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t fc_kernel_pack_weight_ptr(mm, ptr_b, dt_b, N, K, stride_b, n_start, n_end); } -void fc_kernel_execute(fc_kernel* mm, void* ptr_a, void* ptr_c, size_t stride_a, size_t stride_c, - size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { - fc_kernel_execute_ptr(mm, ptr_a, ptr_c, stride_a, stride_c, M, N, K, n_start, n_end, dq, q, bias); +void fc_kernel_pack_weight_to_dst(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + fc_kernel_pack_weight_to_dst_ptr(mm, src_b, dst_b, dt_b, N, K, stride_b, n_start, n_end); } -void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { - fc_kernel_bf16w8_get_q_dq_ptr(K, N, stride, ptr, q, dq); +void fc_kernel_execute(fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t stride_a, size_t stride_c, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { + fc_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, stride_a, stride_c, M, N, K, n_start, n_end, dq, q, bias); } } // namespace llmdnn diff --git a/tests/src/test_fc_amx.cpp b/tests/src/test_fc_amx.cpp new file mode 100644 index 0000000..abda8f4 --- /dev/null +++ b/tests/src/test_fc_amx.cpp @@ -0,0 +1,267 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_fc.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "llm_tensor.hpp" +#include "llm_types.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using FCTestShape = std::tuple; +using FCTestDTPost = std::tuple; +using FCTestParamSet = std::tuple< + FCTestDTPost, // a, b, c data type, postops + bool, // b needs transpose + FCTestShape // M, N, K + >; + +class FCTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + FCTestDTPost types; + bool is_transpose; + postops_types postops_type; + data_type_t dt_a, dt_b, dt_c, dt_weight; + FCTestShape shape; + int M, N, K; + std::tie(types, is_transpose, shape) = obj.param; + std::tie(M, N, K) = shape; + std::tie(dt_a, dt_b, dt_c, dt_weight, postops_type) = types; + + std::ostringstream result; + result << "A_" << dtype_to_str(dt_a) << "_B_" << dtype_to_str(dt_b) + << "_C_" << dtype_to_str(dt_c) << "_WEIGHT_" << dtype_to_str(dt_weight) + << (is_transpose ? "_transpose" : "") + << "_postops_" << postops_type << "_M_" << M << "_N_" << N << "_K_" << K; + return result.str(); + } + +protected: + virtual void SetUp() override { + initXTILE(); + + FCTestShape shape; + FCTestDTPost types; + std::tie(types, _is_transpose, shape) = GetParam(); + std::tie(_M, _N, _K) = shape; + std::tie(_dt_a, _dt_b, _dt_c, _dt_weight, _postops_type) = types; + }; + + template + void do_test() { + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + llmdnn::fc fc; + ASSERT_TRUE(fc.init(param)); + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D q(1, _N); + tensor2D bias(1, _N); + + fill_rnd(A); + fill_rnd(B); + dq = 2; + q = 2; + fill_rnd(bias); + bias = 1; + + tensor2D BT = B.Tr(true); + TB* ptr_B; + size_t ldb; + tensor weight; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + weight.resize({ static_cast(BT.dims[0]), static_cast(BT.dims[1]) }, static_cast(ptr_B)); + } else { + ptr_B = B.data; + ldb = B.stride; + weight.resize({ static_cast(B.dims[0]), static_cast(B.dims[1]) }, static_cast(ptr_B)); + } + fc.pack_weight(weight); + tensor input, output, bias_t, q_t, dq_t; + input.resize({ static_cast(A.dims[0]), static_cast(A.dims[1]) }, static_cast(A.data)); + output.resize({ static_cast(C.dims[0]), static_cast(C.dims[1]) }, static_cast(C.data)); + dq_t.resize({ static_cast(dq.dims[0]), static_cast(dq.dims[1]) }, dq.data); + q_t.resize({ static_cast(q.dims[0]), static_cast(q.dims[1]) }, q.data); + bias_t.resize({ static_cast(bias.dims[0]), static_cast(bias.dims[1]) }, bias.data); + ASSERT_TRUE(fc.exec(input, output, dq_t, q_t, bias_t) == llmdnn::status_ok); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + if ((_postops_type & DEQUANT) && _dt_a == llmdnn::llmdnn_s8) { + ptr_dq = dq.data; + } + if (_postops_type & QUANT) { + ptr_q = q.data; + } + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } + + matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + + int _M, _N, _K; + bool _is_transpose; + postops_types _postops_type; + data_type_t _dt_a, _dt_b, _dt_c, _dt_weight; +}; + +TEST_P(FCTest, Func) { + if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_s8) { + do_test(); + } else if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_f32) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_bf16 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_bf16 && _dt_c == llmdnn_f32) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_f32) { + do_test(); + } else { + ASSERT_TRUE(false); + } +} + +// supported: +// (s8,s8,s8),dq,[bias],[gelu],q +// (s8,s8,bf16),dq,[bias],[gelu] +// (s8,s8,f32),dq,[bias],[gelu] +// (bf16,bf16,bf16),[bias],[gelu] +// (bf16,bf16,f32),[bias],[gelu] +// (bf16,s8,f32),dq,[bias],[gelu] +// (bf16,s8,bf16),dq,[bias],[gelu] +const std::vector types = { + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU_TANH }, + // weight compression + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_BIAS_GELU }, +}; + +// M, N, K +const std::vector shapes = { + // normal + {256, 128, 448}, + // k tail + {256, 48, 449}, + // M tail == unroll 8 + {256 + 8, 48, 449}, + // M tail == unroll 8 + 2 + {256 + 10, 48, 449}, + // N tail + {256, 40, 448}, + // all tail + {256 + 9, 47, 449}, + // gemv, K <= 64(32)*6 + {256, 1, 80}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_FC, FCTest, + ::testing::Combine(ValuesIn(types), + Values(true, false), + ValuesIn(shapes)), + FCTest::getTestCaseName); diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp index 18a976d..58ff26a 100644 --- a/tests/src/test_fc_kernel_amx.cpp +++ b/tests/src/test_fc_kernel_amx.cpp @@ -101,7 +101,7 @@ class FCKernelTest : public TestWithParam { ldb = B.stride; } fc_kernel_pack_weight(gemm.get(), ptr_B, _dt_weight, _N, _K, ldb, 0, _N); - fc_kernel_execute(gemm.get(), A.data, C.data, A.stride, + fc_kernel_execute(gemm.get(), A.data, nullptr, C.data, A.stride, C.stride, _M, _N, _K, 0, _N, dq.data, q.data, bias.data); C_Ref = 0; float* ptr_dq = nullptr; From d30b5b529e43fea7af4e8b550fc4dba2fde314a1 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 31 Aug 2023 00:20:38 +0800 Subject: [PATCH 50/54] dlopen libnuma, remove compilation phase dependency --- src/CMakeLists.txt | 2 +- src/common/memory_alloc.cpp | 108 +++++++++++++++++++++++++++++------- src/fc_amx.cpp | 89 ++++++++++++++++++----------- tests/src/test_fc_amx.cpp | 2 +- 4 files changed, 146 insertions(+), 55 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dafbe68..29f4982 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,7 +18,7 @@ target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17) if(CPU_EXTENSIONS_ENABLE_LOG) target_compile_definitions(${PROJECT_NAME} PRIVATE ENABLE_LOG) endif() -target_link_libraries(${PROJECT_NAME} PUBLIC numa) +target_link_libraries(${PROJECT_NAME} PUBLIC dl) set(CMAKE_DST lib/cmake/${PROJECT_NAME}) # header files diff --git a/src/common/memory_alloc.cpp b/src/common/memory_alloc.cpp index 4fc74e0..89327fa 100644 --- a/src/common/memory_alloc.cpp +++ b/src/common/memory_alloc.cpp @@ -6,31 +6,99 @@ #include #include #include -#include +#include + #include "memory_alloc.hpp" #include "common/simple_parallel.hpp" -static bool llmdnn_use_numa() { - if (numa_available() == -1) - return false; - - static bool init = false; - static bool use_numa = true; - if (!init) { - init = true; - auto p = std::getenv("LLMDNN_USE_NUMA"); - if (p) { - use_numa = p[0] != '0'; +struct numa_funcs { + numa_funcs() { + _numa_handle = dlopen(libnuma_path, RTLD_NOW); + if (_numa_handle) { + _numa_available = reinterpret_cast(dlsym(_numa_handle, "numa_available")); + _numa_node_of_cpu = reinterpret_cast(dlsym(_numa_handle, "numa_node_of_cpu")); + _numa_alloc_onnode = reinterpret_cast(dlsym(_numa_handle, "numa_alloc_onnode")); + _numa_free = reinterpret_cast(dlsym(_numa_handle, "numa_free")); + } + } + + ~numa_funcs() { + if (_numa_handle) { + dlclose(_numa_handle); + } + } + + static numa_funcs& get() { + static numa_funcs funcs; + return funcs; + } + + int numa_available() { + if (_numa_available) { + return _numa_available(); + } else { + return -1; } } - return use_numa; + + int numa_node_of_cpu(int cpu) { + if (_numa_node_of_cpu) { + return _numa_node_of_cpu(cpu); + } else { + return 0; + } + } + + void *numa_alloc_onnode(size_t size, int node) { + if (_numa_alloc_onnode) { + return _numa_alloc_onnode(size, node); + } else { + return aligned_alloc(64, size); + } + } + + void numa_free(void *mem, size_t size) { + if (_numa_free) { + _numa_free(mem, size); + } else { + ::free(mem); + } + } + +private: + constexpr static const char* libnuma_path = "libnuma.so.1"; + void* _numa_handle = nullptr; + int (*_numa_available)(void) = nullptr; + int (*_numa_node_of_cpu)(int cpu) = nullptr; + void *(*_numa_alloc_onnode)(size_t size, int node) = nullptr; + void (*_numa_free)(void *mem, size_t size) = nullptr; +}; + +static bool llmdnn_use_numa() { + struct init_numa_flag { + init_numa_flag() { + auto p = std::getenv("LLMDNN_USE_NUMA"); + if (p) { + use_numa = p[0] != '0'; + } + if (use_numa) { + use_numa = numa_funcs::get().numa_available() != -1; + } + } + + bool use_numa = true; + }; + + static init_numa_flag flag; + + return flag.use_numa; } void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa) { if (hint_numa && llmdnn_use_numa()) { int cur_cpu = sched_getcpu(); - auto cur_numa_node = numa_node_of_cpu(cur_cpu); - return numa_alloc_onnode(size, cur_numa_node); + auto cur_numa_node = numa_funcs::get().numa_node_of_cpu(cur_cpu); + return numa_funcs::get().numa_alloc_onnode(size, cur_numa_node); } else { return aligned_alloc(aligned_size, size); } @@ -38,7 +106,7 @@ void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa) { void llmdnn_free(void* p, size_t size, bool hint_numa) { if (hint_numa && llmdnn_use_numa()) { - numa_free(p, size); + numa_funcs::get().numa_free(p, size); } else { ::free(p); } @@ -47,7 +115,7 @@ void llmdnn_free(void* p, size_t size, bool hint_numa) { int llmdnn_get_numa_id_for_cur_task() { if (llmdnn_use_numa()) { int cur_cpu = sched_getcpu(); - return numa_node_of_cpu(cur_cpu); + return numa_funcs::get().numa_node_of_cpu(cur_cpu); } else { return 0; } @@ -61,7 +129,7 @@ llm_vector llmdnn_get_numa_nodes() { numa_nodes_list.resize(thread_nums); llmdnn::parallel_for(thread_nums, [&] (size_t id) { int cur_cpu = sched_getcpu(); - numa_nodes_list[id] = numa_node_of_cpu(cur_cpu); + numa_nodes_list[id] = numa_funcs::get().numa_node_of_cpu(cur_cpu); }); for (auto numa_node : numa_nodes_list) { if (std::find(numa_nodes.begin(), numa_nodes.end(), numa_node) == numa_nodes.end()) { @@ -77,7 +145,7 @@ llm_vector llmdnn_get_numa_nodes() { void* llmdnn_alloc_on(size_t aligned_size, size_t size, int numa_id) { if (llmdnn_use_numa()) { - return numa_alloc_onnode(size, static_cast(numa_id)); + return numa_funcs::get().numa_alloc_onnode(size, static_cast(numa_id)); } else { return aligned_alloc(aligned_size, size); } @@ -85,7 +153,7 @@ void* llmdnn_alloc_on(size_t aligned_size, size_t size, int numa_id) { void llmdnn_free_on(void* p, size_t size) { if (llmdnn_use_numa()) { - numa_free(p, size); + numa_funcs::get().numa_free(p, size); } else { ::free(p); } diff --git a/src/fc_amx.cpp b/src/fc_amx.cpp index b413937..2e5a07b 100644 --- a/src/fc_amx.cpp +++ b/src/fc_amx.cpp @@ -30,26 +30,28 @@ struct fc_impl_amx : public fc::impl { bool init(const fc_create_param& param) override; void pack_weight(const tensor& w) override; status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) override; - void build_thread_infos(const llm_vector& numa_nodes); + void associate_thread_numa(const llm_vector& numa_nodes); fc_create_param _create_param; - llm_vector _kernel; // one kernel for each numa node - llm_vector _weights; // one weight for each numa node - llm_vector _weight_sizes; // one weight size for each numa node - size_t _thread_nums; // thread numbers - size_t _N_in_one_numa; // N on each numa node - llm_vector _thread_nums_in_one_numa; // thread numbers in one numa node + llm_vector _kernel; // one kernel for each numa node + llm_vector _weights; // one weight for each numa node + llm_vector _weight_sizes; // one weight size for each numa node + llm_vector _numa_nodes; // numa nodes + size_t _thread_nums; // thread numbers + size_t _N_in_one_numa; // N on each numa node + llm_vector _thread_nums_in_one_numa; // thread numbers in one numa node size_t _K_align; struct work_info { - int numa_id; // numa node id, use to index in _numa_nodes - size_t thread_no_in_one_numa; // sequence no in one numa node + int numa_id = 0; // numa node id, use to index in _weights + size_t thread_no_in_one_numa = 0; // sequence no in one numa node }; - llm_vector _work_infos; // map thread id to numa node id and thread no in one numa node + llm_vector _thread_infos; // map thread id to numa node id and thread no in one numa node }; fc_impl_amx::~fc_impl_amx() { for (size_t i = 0; i < _kernel.size(); i++) { - fc_kernel_destroy(_kernel[i]); + if (_kernel[i]) + fc_kernel_destroy(_kernel[i]); } for (size_t i = 0; i < _weight_sizes.size(); i++) { llmdnn_free_on(_weights[i], _weight_sizes[i]); @@ -59,19 +61,24 @@ fc_impl_amx::~fc_impl_amx() { bool fc_impl_amx::init(const fc_create_param& param) { _create_param = param; _thread_nums = get_total_threads(); - _kernel.resize(_thread_nums); - std::atomic_bool ret{true}; - parallel_for(_thread_nums, [&] (size_t id) { - if (fc_kernel_create(&_kernel[id], ¶m) != llmdnn::status_ok) { + _kernel.resize(_thread_nums, nullptr); + bool ret = true; + for (size_t i = 0; i < _thread_nums; i++) { + if (fc_kernel_create(&_kernel[i], ¶m) != llmdnn::status_ok) { ret = false; + break; } - }); + } + if (ret) { + _numa_nodes = llmdnn_get_numa_nodes(); + associate_thread_numa(_numa_nodes); + } return ret; } -void fc_impl_amx::build_thread_infos(const llm_vector& numa_nodes) { - _work_infos.resize(_thread_nums); +void fc_impl_amx::associate_thread_numa(const llm_vector& numa_nodes) { + _thread_infos.resize(_thread_nums); struct int_atomic { std::atomic_int v{0}; }; @@ -81,8 +88,8 @@ void fc_impl_amx::build_thread_infos(const llm_vector& numa_nodes) { auto cur_numa_id = llmdnn_get_numa_id_for_cur_task(); for (int i = 0; i < static_cast(numa_nodes.size()); i++) { if (numa_nodes[i] == cur_numa_id) { - _work_infos[id].numa_id = i; - _work_infos[id].thread_no_in_one_numa = thread_id_in_one_numa[i].v.fetch_add(1); + _thread_infos[id].numa_id = i; + _thread_infos[id].thread_no_in_one_numa = thread_id_in_one_numa[i].v.fetch_add(1); break; } } @@ -94,9 +101,9 @@ void fc_impl_amx::build_thread_infos(const llm_vector& numa_nodes) { auto cur_numa_id = llmdnn_get_numa_id_for_cur_task(); for (int i = 0; i < static_cast(numa_nodes.size()); i++) { if (numa_nodes[i] == cur_numa_id) { - if (_work_infos[id].numa_id != i) { + if (_thread_infos[id].numa_id != i) { std::lock_guard l(m); - DEBUG_LOG << "index test fail: cur numa index of thread no " << id << " is " << i << ", prev index " << _work_infos[id].numa_id << "\n"; + DEBUG_LOG << "index test warning: cur numa index of thread no " << id << " is " << i << ", prev index " << _thread_infos[id].numa_id << "\n"; } break; } @@ -107,16 +114,35 @@ void fc_impl_amx::build_thread_infos(const llm_vector& numa_nodes) { _thread_nums_in_one_numa.resize(numa_nodes.size()); int actual_threads = thread_id_in_one_numa[0].v; _thread_nums_in_one_numa[0] = thread_id_in_one_numa[0].v; + bool zero_threads_in_one_numa = _thread_nums_in_one_numa[0] == 0; for (size_t i = 1; i < thread_id_in_one_numa.size(); i++) { if (thread_id_in_one_numa[0].v != thread_id_in_one_numa[i].v) { - DEBUG_LOG << "numa test fail: thread number of numa " << i << " is " << thread_id_in_one_numa[i].v << ", not equal to numa 0 thread numbers: " << thread_id_in_one_numa[0].v << "\n"; + DEBUG_LOG << "numa test warning: thread number of numa " << i << " is " << thread_id_in_one_numa[i].v << ", not equal to numa 0 thread numbers: " << thread_id_in_one_numa[0].v << "\n"; } actual_threads += thread_id_in_one_numa[i].v; _thread_nums_in_one_numa[i] = thread_id_in_one_numa[i].v; + zero_threads_in_one_numa |= _thread_nums_in_one_numa[i] == 0; + } + if (zero_threads_in_one_numa) { + // no threads in one numa, the result will be wrong + DEBUG_LOG << "zero threads warning: there is no threads in some numa. Will assign threads statically.\n"; } + // check: actual threads number should equal to _thread_nums if (static_cast(_thread_nums) != actual_threads) { - DEBUG_LOG << "thread number test fail: actual threads number: " << actual_threads << ", not equal to _thread_nums " << _thread_nums << "\n"; + DEBUG_LOG << "thread number test warning: actual threads number: " << actual_threads << ", not equal to _thread_nums " << _thread_nums << "\n"; + } + + // fix thread numbers in one numa to get correct result regardless of performance + if (zero_threads_in_one_numa || static_cast(_thread_nums) != actual_threads) { + auto thread_num_in_one_numa = (_thread_nums + numa_nodes.size() - 1) / numa_nodes.size(); + for (size_t i = 0; i < numa_nodes.size(); i++) { + _thread_nums_in_one_numa[i] = std::min(thread_num_in_one_numa, _thread_nums - i * thread_num_in_one_numa); + } + for (int i = 0; i < static_cast(_thread_infos.size()); i++) { + _thread_infos[i].numa_id = i / thread_num_in_one_numa; + _thread_infos[i].thread_no_in_one_numa = i % thread_num_in_one_numa; + } } } @@ -126,9 +152,7 @@ void fc_impl_amx::pack_weight(const tensor& w) { // will allocate memory on different numa nodes: // 1, get numa nodes number, allocate memory on each numa node // 2, get cores number, compute each cores area and pack each area simultaneously - // for omp, GOMP_CPU_AFFINITY=0-95 numactl -C0-95 will bind threads to cores - auto numa_nodes = llmdnn_get_numa_nodes(); - auto numa_nodes_nums = numa_nodes.size(); + auto numa_nodes_nums = _numa_nodes.size(); auto N_blocks = rndup(N, 32) / 32; // NOTE: assuming memory/thread is evenly distributed across mutiple numas. Need to support unbalanced numa? _N_in_one_numa = (N_blocks + numa_nodes_nums - 1) / numa_nodes_nums * 32; @@ -142,15 +166,14 @@ void fc_impl_amx::pack_weight(const tensor& w) { // allocate memory for (size_t i = 0; i < numa_nodes_nums; i++) { auto size = _K_align * _N_in_one_numa * get_precision_size(_create_param.dt_b); - _weights[i] = reinterpret_cast(llmdnn_alloc_on(64, size + 4096, numa_nodes[i])); + _weights[i] = reinterpret_cast(llmdnn_alloc_on(64, size + 4096, _numa_nodes[i])); _weight_sizes[i] = size + 4096; memset(_weights[i] + size, 0, 4096); } - build_thread_infos(numa_nodes); auto work_amount_in_one_numa = _N_in_one_numa / 32; parallel_for(_thread_nums, [&] (size_t id) { - auto numa_id = _work_infos[id].numa_id; - auto thread_no_in_one_numa = _work_infos[id].thread_no_in_one_numa; + auto numa_id = _thread_infos[id].numa_id; + auto thread_no_in_one_numa = _thread_infos[id].thread_no_in_one_numa; size_t start, end; splitter(work_amount_in_one_numa, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); size_t n0_in_one_numa = start * 32; @@ -177,8 +200,8 @@ status_t fc_impl_amx::exec(const tensor& input, const tensor& output, const tens auto K = input.size(1); auto work_amount_in_one_numa = _N_in_one_numa / 32; parallel_for(_thread_nums, [&](size_t id) { - auto numa_id = _work_infos[id].numa_id; - auto thread_no_in_one_numa = _work_infos[id].thread_no_in_one_numa; + auto numa_id = _thread_infos[id].numa_id; + auto thread_no_in_one_numa = _thread_infos[id].thread_no_in_one_numa; size_t start, end; splitter(work_amount_in_one_numa, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); size_t n0_in_one_numa = start * 32; diff --git a/tests/src/test_fc_amx.cpp b/tests/src/test_fc_amx.cpp index abda8f4..07f67e3 100644 --- a/tests/src/test_fc_amx.cpp +++ b/tests/src/test_fc_amx.cpp @@ -253,7 +253,7 @@ const std::vector shapes = { // M tail == unroll 8 + 2 {256 + 10, 48, 449}, // N tail - {256, 40, 448}, + {256, 95, 448}, // all tail {256 + 9, 47, 449}, // gemv, K <= 64(32)*6 From 0d7dac43dcfd13d6b92f78546493920b157d0961 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 31 Aug 2023 02:13:35 +0800 Subject: [PATCH 51/54] avoid usage of vector --- src/common/tensor.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/common/tensor.cpp b/src/common/tensor.cpp index 95c37fe..0117972 100644 --- a/src/common/tensor.cpp +++ b/src/common/tensor.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include "common/log.hpp" @@ -100,7 +99,7 @@ tensor tensor::reshape(const std::initializer_list& target_shape) const // only valid for dense memory tensor new_tensor_view; assert(is_dense()); - new_tensor_view.resize(std::vector(target_shape), m_ptr, m_element_size, m_dtype); + new_tensor_view.resize(target_shape.begin(), target_shape.size(), m_ptr, m_element_size, m_dtype); return new_tensor_view; } From 581a3cf822e95dc3e98108de640923d66c3444f3 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 15 Sep 2023 23:28:31 +0800 Subject: [PATCH 52/54] support perchannel u8 compress weight --- include/llm_fc.hpp | 11 +- src/common/tensor2d.hpp | 2 +- src/fc_amx.cpp | 31 +- src/fc_kernel_amx.cpp | 58 +-- src/mm_kernel_common_amx.hpp | 467 ++++++++++++++---- tests/src/test_common.hpp | 2 - tests/src/test_fc_amx.cpp | 132 ++++- tests/src/test_fc_kernel_amx.cpp | 127 +++-- .../test_utility_kernel_repack1x2_avx512.cpp | 56 +++ 9 files changed, 681 insertions(+), 205 deletions(-) diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp index 6cf401b..c29371c 100644 --- a/include/llm_fc.hpp +++ b/include/llm_fc.hpp @@ -43,8 +43,9 @@ struct fc_create_param { bool b_is_trans; postops_types postops_type; // for weight compression - float q; - float dq; + float* scale; + float* zp; + int scale_zp_size; }; struct fc_kernel; @@ -58,13 +59,13 @@ struct fc_kernel; /// fc: (s8,s8,f32),dq,[bias],[gelu] /// fc: (bf16,bf16,bf16),[bias],[gelu] /// fc: (bf16,bf16,f32),[bias],[gelu] -/// fc: (bf16,s8,f32),dq,[bias],[gelu] -/// fc: (bf16,s8,bf16),dq,[bias],[gelu] +/// fc: (bf16,u8,f32),dq,[bias],[gelu] +/// fc: (bf16,u8,bf16),dq,[bias],[gelu] /// status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param); void fc_kernel_destroy(fc_kernel* mm); // when fc_create_param.dt_b==bf16, dt_b is in [bf16, f32] -// when fc_create_param.dt_b==s8, dt_b is in [bf16, f32] +// when fc_create_param.dt_b==u8, dt_b is in [bf16, f32] void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); void fc_kernel_pack_weight_to_dst(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); // ptr_b may be null if using fc_kernel_pack_weight to pack into internal buffer diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp index c365bad..8e1756d 100644 --- a/src/common/tensor2d.hpp +++ b/src/common/tensor2d.hpp @@ -13,7 +13,7 @@ #include "log.hpp" #include "bf16.hpp" -#define rndup(x, n) (((x + n - 1)/n)*n) +#define rndup(x, n) ((((x) + (n) - 1) / (n)) * (n)) template struct tensor2D { diff --git a/src/fc_amx.cpp b/src/fc_amx.cpp index 2e5a07b..ff83294 100644 --- a/src/fc_amx.cpp +++ b/src/fc_amx.cpp @@ -31,6 +31,7 @@ struct fc_impl_amx : public fc::impl { void pack_weight(const tensor& w) override; status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) override; void associate_thread_numa(const llm_vector& numa_nodes); + void init_weight_compress_param(); fc_create_param _create_param; llm_vector _kernel; // one kernel for each numa node @@ -46,6 +47,8 @@ struct fc_impl_amx : public fc::impl { size_t thread_no_in_one_numa = 0; // sequence no in one numa node }; llm_vector _thread_infos; // map thread id to numa node id and thread no in one numa node + tensor2D _descale; + tensor2D _zp; }; fc_impl_amx::~fc_impl_amx() { @@ -58,13 +61,37 @@ fc_impl_amx::~fc_impl_amx() { } } +void fc_impl_amx::init_weight_compress_param() { + fc_create_param& param = _create_param; + if (param.scale) { + auto size = rndup(param.scale_zp_size, 64 / sizeof(float)); + _descale.resize(1, size, false, false); + memcpy(_descale.data, param.scale, param.scale_zp_size * sizeof(float)); + memset(_descale.data + param.scale_zp_size, 0, (size - param.scale_zp_size) * sizeof(float)); + auto zp_size = rndup(param.scale_zp_size * 2, 64 / sizeof(float)); + _zp.resize(1, zp_size, false, false); + if (param.zp) { + for (int i = 0; i < param.scale_zp_size; i++) { + _zp(0, 2 * i) = param.zp[i]; + _zp(0, 2 * i + 1) = param.zp[i]; + } + memset(_zp.data + param.scale_zp_size * 2, 0, (zp_size - param.scale_zp_size * 2) * sizeof(float)); + } else { + memset(_zp.data, 0, zp_size * sizeof(float)); + } + param.scale = _descale.data; + param.zp = _zp.data; + } +} + bool fc_impl_amx::init(const fc_create_param& param) { _create_param = param; _thread_nums = get_total_threads(); _kernel.resize(_thread_nums, nullptr); + init_weight_compress_param(); bool ret = true; for (size_t i = 0; i < _thread_nums; i++) { - if (fc_kernel_create(&_kernel[i], ¶m) != llmdnn::status_ok) { + if (fc_kernel_create(&_kernel[i], &_create_param) != llmdnn::status_ok) { ret = false; break; } @@ -156,7 +183,7 @@ void fc_impl_amx::pack_weight(const tensor& w) { auto N_blocks = rndup(N, 32) / 32; // NOTE: assuming memory/thread is evenly distributed across mutiple numas. Need to support unbalanced numa? _N_in_one_numa = (N_blocks + numa_nodes_nums - 1) / numa_nodes_nums * 32; - if (_create_param.dt_b == data_type_t::llmdnn_bf16) { + if (_create_param.dt_a == data_type_t::llmdnn_bf16) { _K_align = rndup(K, 32); } else { _K_align = rndup(K, 64); diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp index 685a36c..0292a4e 100644 --- a/src/fc_kernel_amx.cpp +++ b/src/fc_kernel_amx.cpp @@ -23,7 +23,7 @@ namespace llmdnn { using ov::bfloat16; struct fc_kernel { std::unique_ptr> bf16xbf16; - std::unique_ptr> bf16xi8; + std::unique_ptr> bf16xi8; std::unique_ptr> i8xi8; std::unique_ptr> u8xi8; @@ -44,8 +44,8 @@ static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b { { llmdnn_s8, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, { { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, { { llmdnn_bf16, llmdnn_bf16, llmdnn_f32 }, { 0, BIAS | GELU | GELU_TANH } }, - { { llmdnn_bf16, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, - { { llmdnn_bf16, llmdnn_s8, llmdnn_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_u8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_u8, llmdnn_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, }; auto it = supported_postops.find(std::make_tuple(dt_a, dt_b, dt_c)); @@ -88,10 +88,10 @@ status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { m->u8xi8 = std::make_unique>(true, param->b_is_trans); } else if (param->dt_a == llmdnn_bf16 && (param->dt_b == llmdnn_bf16 || param->dt_b == llmdnn_f32)) { m->bf16xbf16 = std::make_unique>(true, param->b_is_trans); - } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_s8) { - m->bf16xi8 = std::make_unique>(true, param->b_is_trans); - m->bf16xi8->quant_scale_B = param->q; - m->bf16xi8->dequant_scale_B = param->dq; + } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_u8) { + m->bf16xi8 = std::make_unique>(true, param->b_is_trans); + m->bf16xi8->dequant_scale_B = param->scale; + m->bf16xi8->zp = param->zp; } else { DEBUG_LOG << "fc_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; goto ERR; @@ -142,21 +142,10 @@ void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, data_type_t dt_b, siz amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); } } else { - tensor2D internalTmpB; - if (dt_b == llmdnn_bf16) { - tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); - auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); - amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); - } else { - tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); - auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); - amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); - } - if (mm->bf16xi8->dequant_scale_B == 0) { - fc_kernel_bf16w8_get_q_dq_amx(internalTmpB.dims[0], internalTmpB.dims[1], internalTmpB.stride, internalTmpB.data, - &mm->bf16xi8->quant_scale_B, &mm->bf16xi8->dequant_scale_B); - } - amx_kernel::functional::bf16_to_i8_tensor(mm->bf16xi8->internalBI8, internalTmpB, mm->bf16xi8->quant_scale_B); + assert(dt_b == llmdnn_u8); + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2_compressed(matB, mm->b_is_transpose, mm->bf16xi8->internalBI8, true); } } @@ -193,23 +182,12 @@ void fc_kernel_pack_weight_to_dst_amx(fc_kernel* mm, void* src_b, void* dst_b, d amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); } } else { - tensor2D internalTmpB; - mm->bf16xi8->internalBI8 = tensor2D(1, 1, static_cast(dst_b), 1); + assert(dt_b == llmdnn_u8); + mm->bf16xi8->internalBI8 = tensor2D(1, 1, static_cast(dst_b), 1); mm->bf16xi8->internalBI8.capacity = INT_MAX; - if (dt_b == llmdnn_bf16) { - tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); - auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); - amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); - } else { - tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); - auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); - amx_kernel::repackB_1x2(matB, mm->b_is_transpose, internalTmpB, true); - } - if (mm->bf16xi8->dequant_scale_B == 0) { - fc_kernel_bf16w8_get_q_dq_amx(internalTmpB.dims[0], internalTmpB.dims[1], internalTmpB.stride, internalTmpB.data, - &mm->bf16xi8->quant_scale_B, &mm->bf16xi8->dequant_scale_B); - } - amx_kernel::functional::bf16_to_i8_tensor(mm->bf16xi8->internalBI8, internalTmpB, mm->bf16xi8->quant_scale_B); + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2_compressed(matB, mm->b_is_transpose, mm->bf16xi8->internalBI8, true); } } @@ -397,8 +375,8 @@ void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); if (ptr_b) { - auto K_padded = rndup(K, 64); - mm->bf16xi8->internalBI8 = tensor2D(N / 32, 32 * K_padded, static_cast(ptr_b), 32 * K_padded); + auto K_padded = rndup(K, 32); + mm->bf16xi8->internalBI8 = tensor2D(N / 32, 32 * K_padded, static_cast(ptr_b), 32 * K_padded); } if (mm->dt_c == llmdnn_bf16) { diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index df16df9..5cecc98 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -1202,6 +1202,56 @@ namespace functional { } } + template + void i8_to_bf16_Kx32(int8_t *&src, ov::bfloat16 *dst, float* zp) + { + auto zp0 = _mm512_loadu_ps(zp); + auto zp1 = _mm512_loadu_ps(zp + 16); + for (int k = 0; k < K; k++) + { + auto a = _mm_load_si128((__m128i *)src); // 16 int8 + auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 + auto a_512 = _mm512_cvtepu8_epi32(a); // 16 int32 + auto b_512 = _mm512_cvtepu8_epi32(b); // 16 int32 + auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps + auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps + a_f = _mm512_sub_ps(a_f, zp0); + b_f = _mm512_sub_ps(b_f, zp1); + auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 + _mm512_store_epi32(dst, (__m512i)reg_out); // + src += 32; // 32 int8_t dequantized into 32 bf16 + dst += 32; + } + } + + // K tail, because right align need to fill zero for padding, real valid k is from invalid_k_num to the end + template + void i8_to_bf16_Kx32_tail(int8_t *&src, ov::bfloat16 *dst, float* zp, int k_start, int invalid_k_num) + { + auto zp0 = _mm512_loadu_ps(zp); + auto zp1 = _mm512_loadu_ps(zp + 16); + auto zero = _mm512_setzero_epi32(); + for (int k = 0; k < K; k++) { + auto k_cur = k + k_start; + if (k_cur < invalid_k_num) { + _mm512_store_epi32(dst, zero); + } else { + auto a = _mm_load_si128((__m128i *)src); // 16 int8 + auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 + auto a_512 = _mm512_cvtepu8_epi32(a); // 16 int32 + auto b_512 = _mm512_cvtepu8_epi32(b); // 16 int32 + auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps + auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps + a_f = _mm512_sub_ps(a_f, zp0); + b_f = _mm512_sub_ps(b_f, zp1); + auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 + _mm512_store_epi32(dst, (__m512i)reg_out); + } + src += 32; + dst += 32; + } + } + inline void bf16_to_i8_tensor(tensor2D& dst, tensor2D& src, float quant_scale) { dst.resize(src.dims[0], src.dims[1]); auto scale = _mm512_set1_ps(quant_scale); @@ -1249,6 +1299,49 @@ namespace functional { } } } + + inline void u8_to_u16_tensor(tensor2D& dst, const tensor2D& src) { + dst.resize(src.dims[0], src.dims[1]); + auto tail = src.dims[1] % 32; + __mmask32 x_mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - tail)); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + int i; + for(i = 0; i < src.dims[1] / 32 * 32; i += 32) { + auto x = _mm256_loadu_epi8(p_src + i); + auto y = _mm512_cvtepu8_epi16(x); + _mm512_storeu_epi16(p_dst + i, y); + } + // handle tails + if (tail) { + auto x = _mm256_maskz_loadu_epi8(x_mask, p_src + i); + auto y = _mm512_cvtepu8_epi16(x); + _mm512_mask_storeu_epi16(p_dst + i, x_mask, y); + } + } + } + + inline void u16_to_u8_tensor(tensor2D&& dst, const tensor2D& src) { + auto tail = src.dims[1] % 32; + __mmask32 x_mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - tail)); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + int i; + for(i = 0; i < src.dims[1] / 32 * 32; i += 32) { + auto x = _mm512_loadu_epi16(p_src + i); + auto y = _mm512_cvtusepi16_epi8(x); + _mm256_storeu_epi8(p_dst + i, y); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_epi16(x_mask, p_src + i); + auto y = _mm512_cvtusepi16_epi8(x); + _mm256_mask_storeu_epi8(p_dst + i, x_mask, y); + } + } + } }; // 2x2 tiles post process kernels @@ -1711,6 +1804,135 @@ inline void repackB_1x2(tensor2D &Bi, bool transpose, tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { + int K = Bi.dims[transpose ? 1 : 0]; + int N = Bi.dims[transpose ? 0 : 1]; + + // K_padded : round up to multiple of 32/64 + int kStep = 32; + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); + + int n = 0; + int n_tail = N % N_unit; + if (transpose) { + tensor2D Btmp(16, 32), transTmp(16, 32); + for(; n < N - n_tail; n += N_unit) { + // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::u8_to_u16_tensor(Btmp, tensor2D(16, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16x16(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 512; + functional::u8_to_u16_tensor(Btmp, tensor2D(16, 32, &Bi(n + 16, k), Bi.stride)); + functional::transpose_epi32_16x16(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 512; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::u8_to_u16_tensor(Btmp, tensor2D(16, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 512; + functional::u8_to_u16_tensor(Btmp, tensor2D(16, K - k, &Bi(n + 16, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 512; + } + } + // n_tail: [16, 32) + if (N - n >= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::u8_to_u16_tensor(Btmp, tensor2D(16, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16x16(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 1024 * 1; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::u8_to_u16_tensor(Btmp, tensor2D(16, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + } + n += 16; + } + // n_tail: (0, 16) + if (N - n > 0) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 512 : 0); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::u8_to_u16_tensor(Btmp, tensor2D(N - n, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_Mx16(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, N - n); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 1024 * 1; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::u8_to_u16_tensor(Btmp, tensor2D(N - n, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_MxN_right_align(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16), N - n); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + } + n = N; + } + // second B tile is untouched, need to set to zero + if (n_tail > 0 && n_tail <= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for (int k = 0; k < K_padded; k += kStep) { + memset(dst + 512, 0, 512); + dst += 1024 * 1; + } + } + } else { + // pack & layout sequentially + int n = 0; + int n_tail = N % N_unit; + tensor2D Btmp(32, 32), transTmp(32, 32); + for(; n < N - n_tail; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + int src_rows = std::min(K - k, kStep); + functional::u8_to_u16_tensor(Btmp, tensor2D(src_rows, 32, &Bi(k, n), Bi.stride)); + functional::kpack_tile_B0B1(&transTmp(0, 0), &transTmp(16, 0), reinterpret_cast(&Btmp(0, 0)), Btmp.stride, src_rows); + functional::u16_to_u8_tensor(tensor2D(32, 32, dst, 32), transTmp); + dst += 1024; + } + } + // n_tail: (0, 32) + if (N - n > 0) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + int src_rows = std::min(K - k, kStep); + functional::u8_to_u16_tensor(Btmp, tensor2D(src_rows, N - n, &Bi(k, n), Bi.stride)); + functional::kpack_tile_B0B1_ntail(&transTmp(0, 0), &transTmp(16, 0), reinterpret_cast(&Btmp(0, 0)), Btmp.stride, src_rows, N - n); + functional::u16_to_u8_tensor(tensor2D(32, 32, dst, 32), transTmp); + dst += 1024; + } + n += 16; + } + } +} + template struct acc_type {}; template<> @@ -2228,7 +2450,7 @@ struct Matmul { (ppkernel)(buffC, m, n + n0, valid_m, valid_n); }; - if (M <= 32 && M >16) { + if (M <= 32 && M > 16) { // 2x2 tile, C:0/1/2/3 A:4/5 B:6/7 no blocking along M dimension tileconfig_t tfg(1, 0, {16,16,M-16,M-16,16,M-16,16,16}, 64); loop2D_no_bM<32>(M, N, kernel_2x2); @@ -2251,8 +2473,8 @@ struct Matmul { // specialization: // TA is ov::bfloat16 and TB is int8_t, decompressed on the fly into ov::bfloat16 by simply convert template<> -struct Matmul { - tensor2D internalBI8; +struct Matmul { + tensor2D internalBI8; // wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly. tensor2D weiBuff; @@ -2270,8 +2492,8 @@ struct Matmul { Matmul(bool constB = false, bool transposeB = false) : constB(constB), transposeB(transposeB), buffC(32, 32) {} - float quant_scale_B; - float dequant_scale_B; + float* dequant_scale_B; + float* zp; template void operator()(tensor2D & matA, @@ -2292,31 +2514,13 @@ struct Matmul { int Kbody = K - Ktails; int Kbackoff = (kStep - Ktails); - // for non-constB, internalB is updated every time - // for constB, internalB is updated once - if (!constB || (internalBI8.capacity == 0)) { - // this dynamic quantization of weight matrix using minmax - // is time-consuming, should be used only for constB - if (!constB) { - DEBUG_LOG << "\t WANING: dynamic quantization of weight matrix for non-constB is time-consuming " << std::endl; - } - // float min, max; - // functional::get_min_max(_matB, min, max); - // max = std::max(std::abs(max), std::abs(min)); - // quant_scale_B = 127 / max; - // dequant_scale_B = max / 127; - - tensor2D internalTmpB; - repackB_1x2(matB, transposeB, internalTmpB, constB); - functional::bf16_to_i8_tensor(internalBI8, internalTmpB, quant_scale_B); - } - ppkernel.set_deq_scale(dequant_scale_B); + auto zp_start = zp + n0 * 2; if (M <= 16) { // C:0/1 A:2 B:3/4 // dequantize scale is moved into ppkernel - constexpr int prefetch_ahead = 64*1024; + //constexpr int prefetch_ahead = 64*1024; tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); auto * pBint = reinterpret_cast(&internalBI8[0]); auto & B2buff = weiBuff; @@ -2324,55 +2528,92 @@ struct Matmul { auto * const pB = &B2buff[0]; auto * pBsrc = pB + (32*32) * 0; auto * pBdst = pB + (32*32) * 1; - functional::i8_to_bf16_Kx32<32>(pBint, pBsrc); - auto * const pC0 = &buffC[0]; const auto strideA = matA.stride; - loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { - // C:Mx32 = A:Mx32 x B:32x32 - _tile_zero(0); - _tile_zero(1); - auto * pA0 = &matA[0]; - for(int k=0; k(pBint, pBdst); - _tile_loadd(3, pBsrc, 64); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); - _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 - - prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); - _tile_loadd(4, pBsrc + 16*32, 64); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); - _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 - std::swap(pBsrc, pBdst); - } - if (Ktails) { - _tile_loadd(2, pA0 - Kbackoff, strideA); // backoff to prevent access beyond the end of A - prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); - - functional::i8_to_bf16_Kx32<8>(pBint, pBdst); - _tile_loadd(3, pBsrc, 64); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); - _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 - - prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); - _tile_loadd(4, pBsrc + 16*32, 64); - functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); - _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 - std::swap(pBsrc, pBdst); - } - //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint); - _tile_stored(0, pC0, buffC.stride); - _tile_stored(1, pC0 + 16, buffC.stride); - //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint + 2048); - //int valid_n = std::min(N - n, 32); - (ppkernel)(buffC, 0, n + n0, M, valid_n); - }); + if (Ktails) { + // with tails, will decompress current block's weights: if ahead we need to know the (tails - 1) which is in the + // tight loop, then special handle the decompress process - skip subtract zeropoint step + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + // C:Mx32 = A:Mx32 x B:32x32 + _tile_zero(0); + _tile_zero(1); + auto * pA0 = &matA[0]; + auto cur_zp = zp_start + n * 2; + for(int k=0; k(pBint, pBsrc, cur_zp); + _tile_loadd(3, pBsrc, 64); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + functional::i8_to_bf16_Kx32<16>(pBint, pBsrc + 16 * 32, cur_zp + 32); + + // prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + _tile_loadd(4, pBsrc + 16*32, 64); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + // tails + { + _tile_loadd(2, pA0 - Kbackoff, strideA); // backoff to prevent access beyond the end of A + //prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + + functional::i8_to_bf16_Kx32_tail<16>(pBint, pBsrc, cur_zp, 0, Kbackoff / 2); + _tile_loadd(3, pBsrc, 64); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + //prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + functional::i8_to_bf16_Kx32_tail<16>(pBint, pBsrc + 16*32, cur_zp + 32, 0, Kbackoff / 2); + _tile_loadd(4, pBsrc + 16*32, 64); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint); + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint + 2048); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + } else { + // no tails, will decompress next block's weights ahead + functional::i8_to_bf16_Kx32<16>(pBint, pBsrc, zp_start); + functional::i8_to_bf16_Kx32<16>(pBint, pBsrc + 16 * 32, zp_start + 32); + + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + // C:Mx32 = A:Mx32 x B:32x32 + _tile_zero(0); + _tile_zero(1); + auto * pA0 = &matA[0]; + auto cur_zp = zp_start + n * 2; + for(int k=0; k(pBint, pBdst, cur_zp); + _tile_loadd(3, pBsrc, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32, cur_zp); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + //prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32, cur_zp + 32); + _tile_loadd(4, pBsrc + 16*32, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32, cur_zp + 32); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint); + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint + 2048); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + } return; } @@ -2384,48 +2625,74 @@ struct Matmul { auto * pA0 = &matA(m, 0); auto * pA1 = &matA(m + 16, 0); auto * pBint = reinterpret_cast(&internalBI8(n>>5, 0)); - functional::i8_to_bf16_Kx32<32>(pBint, pBb); - _tile_zero(0); _tile_zero(1); _tile_zero(2); _tile_zero(3); - int k; - for (k = 0; k < Kbody; k += kStep) { - functional::i8_to_bf16_Kx32<16>(pBint, pBa); + auto cur_zp = zp_start + n * 2; + if (Ktails) { + int k; + for (k = 0; k < Kbody; k += kStep) { + functional::i8_to_bf16_Kx32<16>(pBint, pBa, cur_zp); - _tile_loadd(4, pA0 + k, strideA); - _tile_loadd(6, pBb, 64); - _tile_dpbf16ps(0, 4, 6); + _tile_loadd(4, pA0 + k, strideA); + _tile_loadd(6, pBa, 64); + _tile_dpbf16ps(0, 4, 6); - _tile_loadd(5, pA1 + k, strideA); - _tile_dpbf16ps(2, 5, 6); + _tile_loadd(5, pA1 + k, strideA); + _tile_dpbf16ps(2, 5, 6); - functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32, cur_zp + 32); - _tile_loadd(7, pBb + 16*32, 64); - _tile_dpbf16ps(1, 4, 7); - _tile_dpbf16ps(3, 5, 7); + _tile_loadd(7, pBa + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); - std::swap(pBa, pBb); - } - if (Ktails) { - functional::i8_to_bf16_Kx32<16>(pBint, pBa); + std::swap(pBa, pBb); + } + // tails + { + functional::i8_to_bf16_Kx32_tail<16>(pBint, pBa, cur_zp, 0, Kbackoff / 2); + + _tile_loadd(4, pA0 + k - Kbackoff, strideA); + _tile_loadd(6, pBa, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k - Kbackoff, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32_tail<16>(pBint, pBa + 16*32, cur_zp + 32, 0, Kbackoff / 2); + + _tile_loadd(7, pBa + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); - _tile_loadd(4, pA0 + k - Kbackoff, strideA); - _tile_loadd(6, pBb, 64); - _tile_dpbf16ps(0, 4, 6); + std::swap(pBa, pBb); + } + } else { + functional::i8_to_bf16_Kx32<16>(pBint, pBb, cur_zp); + functional::i8_to_bf16_Kx32<16>(pBint, pBb + 16 * 32, cur_zp + 32); + + for (int k = 0; k < Kbody; k += kStep) { + // weights are ahead of A, if reach the last K block, need to change to next N block + cur_zp += (k == K - kStep) * 64; + functional::i8_to_bf16_Kx32<16>(pBint, pBa, cur_zp); - _tile_loadd(5, pA1 + k - Kbackoff, strideA); - _tile_dpbf16ps(2, 5, 6); + _tile_loadd(4, pA0 + k, strideA); + _tile_loadd(6, pBb, 64); + _tile_dpbf16ps(0, 4, 6); - functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); + _tile_loadd(5, pA1 + k, strideA); + _tile_dpbf16ps(2, 5, 6); - _tile_loadd(7, pBb + 16*32, 64); - _tile_dpbf16ps(1, 4, 7); - _tile_dpbf16ps(3, 5, 7); + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32, cur_zp + 32); - std::swap(pBa, pBb); + _tile_loadd(7, pBb + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } } _tile_stored(0, &buffC(0,0), buffC.stride); _tile_stored(1, &buffC(0,16), buffC.stride); diff --git a/tests/src/test_common.hpp b/tests/src/test_common.hpp index ac44735..22dbd83 100644 --- a/tests/src/test_common.hpp +++ b/tests/src/test_common.hpp @@ -21,8 +21,6 @@ #endif #include -#define rndup(x, n) (((x + n - 1)/n)*n) - std::string dtype_to_str(llmdnn::data_type_t type); using func_act = std::function; diff --git a/tests/src/test_fc_amx.cpp b/tests/src/test_fc_amx.cpp index 07f67e3..5a4b1d3 100644 --- a/tests/src/test_fc_amx.cpp +++ b/tests/src/test_fc_amx.cpp @@ -146,6 +146,89 @@ class FCTest : public TestWithParam { ASSERT_TRUE(compare(C, C_Ref, thresh)); } + template + void do_test_wc() { + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + // bf16 needs divisible by 2 + if (_K % 2 == 1) _K += 1; + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D zp(1, _N); + tensor2D bias(1, _N); + + fill_rnd(A); + for (int i = 0; i < _N; i++) { + // make all weight 1: w - zp == 1 + zp.data[i] = static_cast(i % 255) - 1; + dq.data[i] = (i % 3) * 0.5; + for (int j = 0; j < _K; j++) { + B(j, i) = i % 255; + } + } + fill_rnd(bias); + param.scale = dq.data; + param.zp = zp.data; + param.scale_zp_size = _N; + llmdnn::fc fc; + ASSERT_TRUE(fc.init(param)); + + tensor2D BT = B.Tr(true); + uint8_t* ptr_B; + size_t ldb; + tensor weight; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + weight.resize({ static_cast(BT.dims[0]), static_cast(BT.dims[1]) }, static_cast(ptr_B)); + } else { + ptr_B = B.data; + ldb = B.stride; + weight.resize({ static_cast(B.dims[0]), static_cast(B.dims[1]) }, static_cast(ptr_B)); + } + fc.pack_weight(weight); + tensor input, output, bias_t, q_t, dq_t; + input.resize({ static_cast(A.dims[0]), static_cast(A.dims[1]) }, static_cast(A.data)); + output.resize({ static_cast(C.dims[0]), static_cast(C.dims[1]) }, static_cast(C.data)); + dq_t.resize({ static_cast(dq.dims[0]), static_cast(dq.dims[1]) }, dq.data); + bias_t.resize({ static_cast(bias.dims[0]), static_cast(bias.dims[1]) }, bias.data); + ASSERT_TRUE(fc.exec(input, output, dq_t, q_t, bias_t) == llmdnn::status_ok); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + ptr_dq = dq.data; + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } + + B = 1; + matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + int _M, _N, _K; bool _is_transpose; postops_types _postops_type; @@ -167,6 +250,10 @@ TEST_P(FCTest, Func) { do_test(); } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_f32) { do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_u8 && _dt_c == llmdnn_bf16) { + do_test_wc(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_u8 && _dt_c == llmdnn_f32) { + do_test_wc(); } else { ASSERT_TRUE(false); } @@ -178,8 +265,8 @@ TEST_P(FCTest, Func) { // (s8,s8,f32),dq,[bias],[gelu] // (bf16,bf16,bf16),[bias],[gelu] // (bf16,bf16,f32),[bias],[gelu] -// (bf16,s8,f32),dq,[bias],[gelu] -// (bf16,s8,bf16),dq,[bias],[gelu] +// (bf16,u8,f32),dq,[bias],[gelu] +// (bf16,u8,bf16),dq,[bias],[gelu] const std::vector types = { { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_QUANT }, { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_QUANT }, @@ -224,38 +311,29 @@ const std::vector types = { { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU_TANH }, { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU_TANH }, // weight compression - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_BIAS }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_BIAS_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_BIAS }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_BIAS_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_BIAS }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_BIAS_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_BIAS }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_BIAS_GELU }, }; // M, N, K const std::vector shapes = { // normal {256, 128, 448}, - // k tail - {256, 48, 449}, - // M tail == unroll 8 - {256 + 8, 48, 449}, - // M tail == unroll 8 + 2 - {256 + 10, 48, 449}, - // N tail - {256, 95, 448}, + // M < 16 + {15, 129, 447}, + {15, 129, 448}, + // M in (16, 32] + {31, 129, 447}, + {31, 129, 448}, // all tail - {256 + 9, 47, 449}, + {256 + 9, 129, 449}, + {256 + 9, 129, 448}, // gemv, K <= 64(32)*6 {256, 1, 80}, }; diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp index 58ff26a..ad23884 100644 --- a/tests/src/test_fc_kernel_amx.cpp +++ b/tests/src/test_fc_kernel_amx.cpp @@ -137,6 +137,82 @@ class FCKernelTest : public TestWithParam { ASSERT_TRUE(compare(C, C_Ref, thresh)); } + template + void do_weight_compress_test() { + fc_kernel* fc; + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + // bf16 needs divisible by 2 + if (_K % 2 == 1) _K += 1; + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D zp(1, _N * 2); + tensor2D bias(1, _N); + + fill_rnd(A); + dq = 999999; + for (int i = 0; i < _N; i++) { + zp.data[i * 2 + 0] = static_cast(i % 255); + zp.data[i * 2 + 1] = static_cast(i % 255); + for (int j = 0; j < _K; j++) { + B(j, i) = i % 255; + } + } + bias = 0; + param.scale = dq.data; + param.zp = zp.data; + param.scale_zp_size = _N; + ASSERT_TRUE(fc_kernel_create(&fc, ¶m) == llmdnn::status_ok); + auto gemm = std::shared_ptr(fc, [](fc_kernel* p) { fc_kernel_destroy(p); }); + + tensor2D BT = B.Tr(true); + uint8_t* ptr_B; + size_t ldb; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + } else { + ptr_B = B.data; + ldb = B.stride; + } + fc_kernel_pack_weight(gemm.get(), ptr_B, _dt_weight, _N, _K, ldb, 0, _N); + fc_kernel_execute(gemm.get(), A.data, nullptr, C.data, A.stride, + C.stride, _M, _N, _K, 0, _N, dq.data, nullptr, bias.data); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + ptr_dq = dq.data; + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } + + //matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + int _M, _N, _K; bool _is_transpose; postops_types _postops_type; @@ -158,6 +234,10 @@ TEST_P(FCKernelTest, Func) { do_test(); } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_f32) { do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_u8 && _dt_c == llmdnn_bf16) { + do_weight_compress_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_u8 && _dt_c == llmdnn_f32) { + do_weight_compress_test(); } else { ASSERT_TRUE(false); } @@ -169,8 +249,8 @@ TEST_P(FCKernelTest, Func) { // (s8,s8,f32),dq,[bias],[gelu] // (bf16,bf16,bf16),[bias],[gelu] // (bf16,bf16,f32),[bias],[gelu] -// (bf16,s8,f32),dq,[bias],[gelu] -// (bf16,s8,bf16),dq,[bias],[gelu] +// (bf16,u8,f32),dq,[bias],[gelu] +// (bf16,u8,bf16),dq,[bias],[gelu] const std::vector types = { { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_QUANT }, { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_QUANT }, @@ -215,38 +295,29 @@ const std::vector types = { { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU_TANH }, { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU_TANH }, // weight compression - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_BIAS }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_bf16, DEQUANT_BIAS_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_BIAS }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_bf16, DEQUANT_BIAS_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_BIAS }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_f32, llmdnn_f32, DEQUANT_BIAS_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_BIAS }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_GELU }, - { llmdnn_bf16, llmdnn_s8, llmdnn_bf16, llmdnn_f32, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_BIAS_GELU }, }; // M, N, K const std::vector shapes = { // normal - {256, 48, 448}, - // k tail - {256, 48, 449}, - // M tail == unroll 8 - {256 + 8, 48, 449}, - // M tail == unroll 8 + 2 - {256 + 10, 48, 449}, - // N tail - {256, 40, 448}, + {256, 128, 448}, + // M < 16 + {15, 129, 447}, + {15, 129, 448}, + // M in (16, 32] + {31, 129, 447}, + {31, 129, 448}, // all tail - {256 + 9, 47, 449}, + {256 + 9, 129, 449}, + {256 + 9, 129, 448}, // gemv, K <= 64(32)*6 {256, 1, 80}, }; diff --git a/tests/src/test_utility_kernel_repack1x2_avx512.cpp b/tests/src/test_utility_kernel_repack1x2_avx512.cpp index fbee232..6d48cfd 100644 --- a/tests/src/test_utility_kernel_repack1x2_avx512.cpp +++ b/tests/src/test_utility_kernel_repack1x2_avx512.cpp @@ -133,12 +133,67 @@ class RepackTest : public TestWithParam { testone(64 + 16 + 3, 128 + 16 + 5, "alltail"); } + void test_wc() { + auto testone = [] (int k, int n, std::string prefix) { + tensor2D A(k, n, true); + + // get ref result + tensor2D A_ref; + tensor2D A_bf16(k, n, true), A_ref_bf16; + for (int i = 0; i < k * n; i++) { + A.data[i] = i % 23; + A_bf16.data[i] = ov::bfloat16(i % 23); + } + amx_kernel::repackB_1x2(A_bf16, false, A_ref_bf16, true); + A_ref.resize(A_ref_bf16.dims[0], A_ref_bf16.dims[1], true); + for (int i = 0; i < A_ref_bf16.dims[0] * A_ref_bf16.dims[1]; i++) { + A_ref.data[i] = static_cast(float(A_ref_bf16.data[i])); + } + + tensor2D AT = A.Tr(true); + tensor2D A_out, AT_out; + amx_kernel::repackB_1x2_compressed(A, false, A_out, true); + amx_kernel::repackB_1x2_compressed(AT, true, AT_out, true); + ASSERT_TRUE(A_out == A_ref) << " " << prefix << " without transform K: " << k << " N: " << n; + ASSERT_TRUE(AT_out == A_ref) << " " << prefix << " with transform K: " << k << " N: " << n; + }; + // n tail: transpose case needs from 1 to 31, without transpose needs one + int k = 32; + int n; + for (n = 1; n < 32; n++) { + testone(k, n, "ntail"); + } + for (n = 32 + 1; n < 32 + 32; n++) { + testone(k, n, "ntail"); + } + // k tail: transpose case needs 1, without transpose needs from 1 to 31 + n = 32; + for (k = 1; k < 32; k++) { + testone(k, n, "ktail"); + } + for (k = 32 + 1; k < 32 + 32; k++) { + testone(k, n, "ktail"); + } + // k, n normal + testone(32, 32, "normal"); + testone(64, 128, "normal"); + // k, n tail + testone(64, 128 + 5, "ntail"); + testone(64 + 3, 128, "ktail"); + testone(64 + 3, 128 + 5, "alltail"); + testone(64, 128 + 16 + 5, "ntail"); + testone(64 + 16 + 3, 128, "ktail"); + testone(64 + 16 + 3, 128 + 16 + 5, "alltail"); + } + std::pair _types; }; TEST_P(RepackTest, Func) { if (_types.first == llmdnn_s8 && _types.second == llmdnn_s8) { test(); + } else if (_types.first == llmdnn_u8 && _types.second == llmdnn_u8) { + test_wc(); } else if (_types.first == llmdnn::llmdnn_bf16 && _types.second == llmdnn_bf16) { test(); } else { @@ -147,6 +202,7 @@ TEST_P(RepackTest, Func) { } const std::vector> types = { + {llmdnn_u8, llmdnn_u8}, // compress weight {llmdnn_s8, llmdnn_s8}, {llmdnn_bf16, llmdnn_bf16}, {llmdnn_f32, llmdnn_bf16}, From 4f3ba528b80883942a291df36dc90f96c4dda74a Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Tue, 19 Sep 2023 00:08:06 +0800 Subject: [PATCH 53/54] opt for multi query --- src/mha_gpt_amx.cpp | 12 +++++++----- src/mm_kernel_common_amx.hpp | 22 ++++++++++------------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp index e47c7bc..e3f27fb 100644 --- a/src/mha_gpt_amx.cpp +++ b/src/mha_gpt_amx.cpp @@ -116,6 +116,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& auto head_size = q.m_dims[3]; auto key_seq_len = k.m_dims[2]; bool is_bloom = k.m_strides[3] > k.m_strides[2]; + auto h_group_num = k.m_dims[1]; + size_t h_each_group_len = head_num / h_group_num; uint8_t* out = output.data(); @@ -130,8 +132,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& if (use_gemv) { parallel_for2d(batch, head_num, [&](size_t thread_id, size_t i0, size_t i1) { auto q_sub = &q.at({i0, i1}); - auto k_sub = &k.at({i0, i1}); - auto v_sub = &v.at({i0, i1}); + auto k_sub = &k.at({i0, i1 / h_each_group_len}); + auto v_sub = &v.at({i0, i1 / h_each_group_len}); auto mat0_out = reinterpret_cast(_buffer_mat0_out + thread_id * _buffer_mat0_out_size); auto mat1_out = reinterpret_cast(_buffer_mat1_out + thread_id * _buffer_mat1_out_size); @@ -177,8 +179,8 @@ void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& // k: [batch, head_num, key_seq_len, head_size] // v: [batch, head_num, value_seq_len, head_size] auto q_sub = &q.at({i0, i1, seq_start}); - auto k_sub = &k.at({i0, i1}); - auto v_sub = &v.at({i0, i1}); + auto k_sub = &k.at({i0, i1 / h_each_group_len}); + auto v_sub = &v.at({i0, i1 / h_each_group_len}); auto mat0_out = reinterpret_cast(_buffer_mat0_out + thread_id * _buffer_mat0_out_size); auto mat1_out = reinterpret_cast(_buffer_mat1_out + thread_id * _buffer_mat1_out_size); @@ -279,7 +281,7 @@ status_t mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& auto key_seq_len = k.m_dims[2]; if (!(batch == k.m_dims[0] && batch == v.m_dims[0] && - head_num == k.m_dims[1] && head_num == v.m_dims[1] && + k.m_dims[1] == v.m_dims[1] && key_seq_len == v.m_dims[2] && head_size == k.m_dims[3] && head_size == v.m_dims[3])) { DEBUG_LOG << "dim of q,k,v is error.\n"; diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp index 5cecc98..4d97802 100644 --- a/src/mm_kernel_common_amx.hpp +++ b/src/mm_kernel_common_amx.hpp @@ -2476,21 +2476,13 @@ template<> struct Matmul { tensor2D internalBI8; - // wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly. - tensor2D weiBuff; - bool constB; bool transposeB; constexpr static int kStep = 32; - // 2x2 C tiles buffer - // most usecase requires post-processing with AVX, thus buffC - // is used to transfer data to AVX register - tensor2D buffC; - Matmul(bool constB = false, bool transposeB = false) : - constB(constB), transposeB(transposeB), buffC(32, 32) {} + constB(constB), transposeB(transposeB) {} float* dequant_scale_B; float* zp; @@ -2500,6 +2492,14 @@ struct Matmul { tensor2D & _matB, int n0, int n1, PP ppkernel) { + alignas(64) float buff[32 * 32]; + // wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly. + alignas(64) ov::bfloat16 weiBuff[32 * 2 * 32]; + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + tensor2D buffC(32, 32, buff, 32 * sizeof(float)); + auto matB = getSubMatB(_matB, n0, n1, transposeB); int M = matA.dims[0]; int K = matA.dims[1]; @@ -2523,9 +2523,7 @@ struct Matmul { //constexpr int prefetch_ahead = 64*1024; tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); auto * pBint = reinterpret_cast(&internalBI8[0]); - auto & B2buff = weiBuff; - B2buff.resize(32*2, 32); - auto * const pB = &B2buff[0]; + auto * const pB = weiBuff; auto * pBsrc = pB + (32*32) * 0; auto * pBdst = pB + (32*32) * 1; auto * const pC0 = &buffC[0]; From aa2c57b60c9fc827878bfa029e1024f96edb2809 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 22 Sep 2023 15:13:22 +0800 Subject: [PATCH 54/54] fc M dimension splitting --- src/fc_amx.cpp | 113 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 96 insertions(+), 17 deletions(-) diff --git a/src/fc_amx.cpp b/src/fc_amx.cpp index ff83294..8255851 100644 --- a/src/fc_amx.cpp +++ b/src/fc_amx.cpp @@ -31,6 +31,7 @@ struct fc_impl_amx : public fc::impl { void pack_weight(const tensor& w) override; status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) override; void associate_thread_numa(const llm_vector& numa_nodes); + void init_m_block(); void init_weight_compress_param(); fc_create_param _create_param; @@ -40,6 +41,8 @@ struct fc_impl_amx : public fc::impl { llm_vector _numa_nodes; // numa nodes size_t _thread_nums; // thread numbers size_t _N_in_one_numa; // N on each numa node + size_t _m_block_num_idea; // idea m block number for best thread balance + size_t _n_block_num_idea; // idea n block number for best thread balance llm_vector _thread_nums_in_one_numa; // thread numbers in one numa node size_t _K_align; struct work_info { @@ -104,6 +107,27 @@ bool fc_impl_amx::init(const fc_create_param& param) { return ret; } +void fc_impl_amx::init_m_block() { + size_t div = 2; + auto work_amount = _N_in_one_numa / 32; + size_t work = work_amount; + // worse case: M block number is _thread_num + size_t threads = std::max(1ul, _thread_nums / _numa_nodes.size()); + while (div <= work_amount) { + // if work and threads can be divided by div, M block number can be dived by div + if (work % div == 0 && threads % div == 0) { + threads /= div; + work /= div; + } else { + div++; + } + if (work < div) + break; + } + _m_block_num_idea = threads; + _n_block_num_idea = _thread_nums / _numa_nodes.size() / threads; +} + void fc_impl_amx::associate_thread_numa(const llm_vector& numa_nodes) { _thread_infos.resize(_thread_nums); struct int_atomic { @@ -214,6 +238,7 @@ void fc_impl_amx::pack_weight(const tensor& w) { auto dst = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); fc_kernel_pack_weight_to_dst(_kernel[id], w.data(), dst, w.m_dtype, N, K, w.stride(0), n0, n1); }); + init_m_block(); } status_t fc_impl_amx::exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) { @@ -225,24 +250,78 @@ status_t fc_impl_amx::exec(const tensor& input, const tensor& output, const tens auto M = input.size(0); auto N = output.size(1); auto K = input.size(1); - auto work_amount_in_one_numa = _N_in_one_numa / 32; - parallel_for(_thread_nums, [&](size_t id) { - auto numa_id = _thread_infos[id].numa_id; - auto thread_no_in_one_numa = _thread_infos[id].thread_no_in_one_numa; - size_t start, end; - splitter(work_amount_in_one_numa, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); - size_t n0_in_one_numa = start * 32; - size_t n1_in_one_numa = std::min(end * 32, _N_in_one_numa); - if (n0_in_one_numa >= _N_in_one_numa) return; - auto n0 = n0_in_one_numa + _N_in_one_numa * numa_id; - auto n1 = n1_in_one_numa + _N_in_one_numa * numa_id; - n1 = std::min(n1, N); - if (n0 >= n1) return; + auto work_amount_n_in_one_numa = _N_in_one_numa / 32; + if (M < 32) { + parallel_for(_thread_nums, [&](size_t id) { + auto numa_id = _thread_infos[id].numa_id; + auto thread_no_in_one_numa = _thread_infos[id].thread_no_in_one_numa; + size_t start, end; + splitter(work_amount_n_in_one_numa, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); + size_t n0_in_one_numa = start * 32; + size_t n1_in_one_numa = std::min(end * 32, _N_in_one_numa); + if (n0_in_one_numa >= _N_in_one_numa) return; + auto n0 = n0_in_one_numa + _N_in_one_numa * numa_id; + auto n1 = n1_in_one_numa + _N_in_one_numa * numa_id; + n1 = std::min(n1, N); + if (n0 >= n1) return; - auto weight = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); - fc_kernel_execute(_kernel[id], input.data(), weight, output.data(), input.stride(0), - output.stride(0), M, N, K, n0, n1, dq.data(), q.data(), bias.data()); - }); + auto weight = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); + fc_kernel_execute(_kernel[id], input.data(), weight, output.data(), input.stride(0), + output.stride(0), M, N, K, n0, n1, dq.data(), q.data(), bias.data()); + }); + } else { + // row number of each block + auto m_row = rndup(M, _m_block_num_idea) / _m_block_num_idea; + auto work_amount_n_block = _n_block_num_idea; + // at least 32 rows + if (m_row < 32) { + m_row = 32; + work_amount_n_block = work_amount_n_in_one_numa; + } + auto work_amount_m = rndup(M, m_row) / m_row; + auto work_amount = work_amount_n_block * work_amount_m; + auto n_block_in_one_numa = work_amount_n_in_one_numa / work_amount_n_block; + parallel_for(_thread_nums, [&](size_t id) { + auto numa_id = _thread_infos[id].numa_id; + auto thread_no_in_one_numa = _thread_infos[id].thread_no_in_one_numa; + size_t start, end; + splitter(work_amount, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); + size_t m_start{0}, n_start{0}; + if (M > N) + parallel_it_init(start, m_start, work_amount_m, n_start, work_amount_n_block); + else + parallel_it_init(start, n_start, work_amount_n_block, m_start, work_amount_m); + for (auto work = start; work < end; work++) { + size_t n0_in_one_numa = n_start * n_block_in_one_numa * 32; + size_t n1_in_one_numa = n0_in_one_numa + n_block_in_one_numa * 32; //std::min(n_end * 32, _N_in_one_numa); + //if (n0_in_one_numa >= _N_in_one_numa) return; + auto n0 = n0_in_one_numa + _N_in_one_numa * numa_id; + auto n1 = n1_in_one_numa + _N_in_one_numa * numa_id; + n1 = std::min(n1, N); + if (n0 >= n1) continue; + + size_t m0 = m_start * m_row; + size_t m1 = std::min(m0 + m_row, M); + size_t m = m1 - m0; + + auto weight = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); + fc_kernel_execute(_kernel[id], + input.data() + m0 * input.stride(0), + weight, + output.data() + m0 * output.stride(0), + input.stride(0), + output.stride(0), + m, N, K, n0, n1, + dq.data(), + q.data(), + bias.data()); + if (M > N) + parallel_it_step(m_start, work_amount_m, n_start, work_amount_n_block); + else + parallel_it_step(n_start, work_amount_n_block, m_start, work_amount_m); + } + }); + } return status_t::status_ok; }