diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5e92026 --- /dev/null +++ b/.gitignore @@ -0,0 +1,56 @@ +# Compiled Object files +**/.DS_Store +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +**/cmake-build-debug +**/CMakeCache.txt +**/cmake_install.cmake +**/install_manifest.txt +**/CMakeFiles/ +**/CTestTestfile.cmake +**/Makefile +**/*.cbp +**/CMakeScripts +**/compile_commands.json + + +## Local + +build/**/* +**/build/**/* +out/* +lib/* +bin/* +test/test_runner +.vs +.cache +__pycache__ +dist +*.egg-info \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..7d2217c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,65 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) + +project(root) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +option(CPU_EXTENSIONS_BUILD_TESTS "Build with tests" ON) + +message(INFO "--------------------------------") +message(STATUS "Build with tests: ${CPU_EXTENSIONS_BUILD_TESTS}") +message(INFO "--------------------------------") + +set(CMAKE_CXX_STANDARD 17) +if(MSVC) + # TODO: validate + if(MSVC_VERSION VERSION_LESS 1928) + message(FATAL_ERROR "Insufficient msvc compiler version, current ${MSVC_VERSION}, minimum 1928.") + endif() + # Force to always compile with W4 + if(CMAKE_CXX_FLAGS MATCHES "/W[0-4]") + string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") + endif() +elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.2") + message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 11.2.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids -flax-vector-conversions") +elseif(OV_COMPILER_IS_CLANG) + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "12") + message(FATAL_ERROR "Insufficient clang compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 12.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids -flax-vector-conversions") +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "2023.0") + message(FATAL_ERROR "Insufficient intel compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 2023.0.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") +endif() + +if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY) + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +endif() +add_subdirectory(src) +if (CPU_EXTENSIONS_BUILD_TESTS) + add_subdirectory(tests) +endif() + +# Get the latest commit hash +execute_process( + COMMAND git rev-parse HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + OUTPUT_VARIABLE GIT_HASH + OUTPUT_STRIP_TRAILING_WHITESPACE + ) +file(WRITE ${CMAKE_BINARY_DIR}/git-state.txt ${GIT_HASH}) +install(FILES + ${CMAKE_BINARY_DIR}/git-state.txt + DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/README.md b/README.md new file mode 100644 index 0000000..0129645 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# About CPU_Extensions +CPU_Extensions is a compute library containing processor optimized kernels code. + +# Unit tests for CPU_Extensions +## Tests for kernels +Tests for kernels are written in gtest under tests\src, use ./cpu_extensions_tests to run it. + +## Tests for complex features +Some features have many steps and the reference could not be easily written using gtest. For these features can use python to generate the reference. The directory tests\script contains these test, please refer [test in python](./tests/script/README.md). \ No newline at end of file diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp new file mode 100644 index 0000000..dfc4947 --- /dev/null +++ b/include/llm_emb_gpt.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" + +namespace llmdnn { + +class emb_gpt { +public: + struct create_param { + size_t num_heads; + size_t head_size; + size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv + // supported (qkv, dst): (bf16, bf16) + data_type_t qkv_precision; + data_type_t dst_precision; + size_t rotary_dims; + bool use_position2d; // chatglm true, other false + }; + struct exec_param { + size_t batch; + size_t query_seq_len; + size_t past_seq_len; + uint8_t* q; // shape: [batch, query_seq_len, hidden size], inner stride is ldq + uint8_t* k; // shape: [batch, query_seq_len, hidden size], inner stride is ldk + uint8_t* v; // shape: [batch, query_seq_len, hidden size], inner stride is ldv + size_t ldq; // inner stride of q + size_t ldk; // inner stride of k + size_t ldv; // inner stride of v + uint8_t* query_dst; // rotary embbeding dst + uint8_t** layer_past_key_src; // past key src + uint8_t** layer_past_value_src; // past value src + uint8_t** layer_past_key_dst; // past key dst, if layer_past_key_src!=layer_past_key_dst, will copy layer_past_key_src to layer_past_key_dst + uint8_t** layer_past_value_dst; // past value dst, if layer_past_value!=layer_past_value_dst, will copy layer_past_value to layer_past_value_dst + float* cos; // cos lookup table, shape: [max_seq_len, rotary_dims] + float* sin; // sin lookup table, shape: [max_seq_len, rotary_dims] + int* position2d_ids; // shape: [batch, 2, query_seq_len] + size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer + }; + + emb_gpt(); + bool create(const create_param& param); + void exec(const exec_param& param); + + struct impl { + virtual bool create(const create_param& param) = 0; + virtual void exec(const exec_param& param) = 0; + }; +protected: + std::shared_ptr _impl; +}; + +} diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp new file mode 100644 index 0000000..1272c1b --- /dev/null +++ b/include/llm_fc.hpp @@ -0,0 +1,74 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "llm_types.hpp" + +namespace llmdnn { + +typedef enum { + NONE = 0, + DEQUANT = 1 << 0, + BIAS = 1 << 1, + GELU_ERF = 1 << 2, + GELU_TANH = 1 << 3, + QUANT = 1 << 4, + GELU = GELU_ERF, // default is ERF + + BIAS_GELU = BIAS | GELU, + DEQUANT_BIAS_GELU = DEQUANT | BIAS_GELU, + DEQUANT_BIAS_GELU_QUANT = DEQUANT_BIAS_GELU | QUANT, + DEQUANT_BIAS_QUANT = DEQUANT | BIAS | QUANT, + DEQUANT_GELU_QUANT = DEQUANT | GELU | QUANT, + DEQUANT_QUANT = DEQUANT | QUANT, + + DEQUANT_GELU = DEQUANT | GELU, + DEQUANT_BIAS = DEQUANT | BIAS, + + BIAS_GELU_TANH = BIAS | GELU_TANH, + DEQUANT_BIAS_GELU_TANH = DEQUANT | BIAS_GELU_TANH, + DEQUANT_BIAS_GELU_TANH_QUANT = DEQUANT_BIAS_GELU_TANH | QUANT, + DEQUANT_GELU_TANH_QUANT = DEQUANT | GELU_TANH | QUANT, + + DEQUANT_GELU_TANH = DEQUANT | GELU_TANH, +} postops_types; + +struct fc_create_param { + data_type_t dt_a; + data_type_t dt_b; + data_type_t dt_c; + bool b_is_trans; + postops_types postops_type; + // for weight compression + float q; + float dq; +}; + +struct fc_kernel; + +/// Generates a mm kernel based on param +/// +/// @param mm Output kernel +/// @param param kernel parameters, supported: +/// fc: (s8,s8,s8),dq,[bias],[gelu],q +/// fc: (s8,s8,bf16),dq,[bias],[gelu] +/// fc: (s8,s8,f32),dq,[bias],[gelu] +/// fc: (bf16,bf16,bf16),[bias],[gelu] +/// fc: (bf16,bf16,f32),[bias],[gelu] +/// fc: (bf16,s8,f32),dq,[bias],[gelu] +/// fc: (bf16,s8,bf16),dq,[bias],[gelu] +/// +bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param); +void fc_kernel_destroy(const fc_kernel* mm); +void fc_kernel_execute(const fc_kernel* mm, + void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, + float* dq=nullptr, float* q=nullptr, float* bias=nullptr); + +/// weight compression +/// compute weight min/max once, set q, dq for each fc_kernel instance +void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); + +} diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp new file mode 100644 index 0000000..befa348 --- /dev/null +++ b/include/llm_mha_gpt.hpp @@ -0,0 +1,90 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" + +namespace llmdnn { + +// pattern is: +// query:[batch, num_heads, query_seq_len, head_size] key:[batch, num_heads, key_seq_len, head_size] +// \ | +// \ Transpose0: [batch, num_heads, head_size, key_seq_len] +// \ / +// \ / +// \ / +// MatMul0: [batch, num_heads, query_seq_len, key_seq_len] +// | +// | norm_factor(const): [1] +// | / +// Multiply: [batch, num_heads, query_seq_len, key_seq_len] +// | +// | causal_mask: [1, 1, query_seq_len, key_seq_len] +// | / +// Select(only for 1x300): [batch, num_heads, query_seq_len, key_seq_len] +// | +// | attention_mask:[batch, 1, 1, key_seq_len] +// | / +// Add: [batch, num_heads, query_seq_len, key_seq_len] +// | +// SoftMax: [batch, num_heads, query_seq_len, key_seq_len] +// | +// \ value:[batch, num_heads, key_seq_len, head_size] +// \ / +// MatMul1: [batch, num_heads, query_seq_len, head_size] +// | +// Transpose1(only for 1x300): [batch, query_seq_len, num_heads * head_size] +class mha_gpt { +public: + struct create_param { + size_t num_heads; + size_t head_size; + size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv + size_t max_seq_len; // max seq length for computing the size of matmul tmp result + float normal_factor; + // supported (qkv, dst): (bf16, bf16), (s8, s8) + data_type_t qkv_precision; + data_type_t dst_precision; + }; + struct exec_param { + size_t batch; + size_t query_seq_len; + size_t key_seq_len; + bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. + uint8_t* q; // q buffer, compact, shape: [batch, num_heads, query_seq_len, head_size] + uint8_t** k; // k buffer, k[N] stands different batch which may be discreted + // k[0] shape: [batch, num_heads, key_seq_len, head_size] + uint8_t** v; // v buffer, v[N] stands different batch which may be discreted + // v[0] shape: [batch, num_heads, value_seq_len, head_size] + float* attention_mask; // attention mask, attention_mask[0] shape: + // [batch, 1, 1, key_seq_len], when is_causal_in_attention is false + // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true + uint8_t* attn_output; // output, compact, shape: [batch, query_seq_len, num_heads * head_size] + size_t head_stride_in_kv; // kv stride for next head; kv may be preallocated a big buffer + // expected quant schema: + // q,k,v use per tensor quant, attn_output may use per tensor/channel quant + float q_dequant; + float k_dequant; + float v_dequant; + float qk_quant; + std::vector qkv_quant; // size==1 per tensor, size==head_size per channel + }; + + mha_gpt(); + bool create(const create_param& param); + void exec(const exec_param& param); + + struct impl { + virtual bool create(const create_param& param) = 0; + virtual void exec(const exec_param& param) = 0; + }; +protected: + std::shared_ptr _impl; +}; + +} diff --git a/include/llm_mm.hpp b/include/llm_mm.hpp new file mode 100644 index 0000000..bb187f7 --- /dev/null +++ b/include/llm_mm.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "llm_types.hpp" + +namespace llmdnn { + +struct mm_create_param { + data_type_t dt_a; + data_type_t dt_b; + bool b_is_gemv; // true if matrix b is vector. Shape: a[M,K], b[K,1], c[M,1] + bool b_is_trans; +}; + +struct mm_kernel; + +/// Generates a mm kernel based on param +/// +/// @param mm Output kernel +/// @param param kernel parameters, supported: +/// matmul: (u8/s8,s8,f32) +/// gemv: (s8,s8,f32) +/// matmul: (bf16,bf16,f32) +/// gemv: (bf16,bf16,f32) +/// +bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param); +void mm_kernel_destroy(const mm_kernel* mm); + +void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K); + +} diff --git a/include/llm_types.hpp b/include/llm_types.hpp new file mode 100644 index 0000000..3ace6d6 --- /dev/null +++ b/include/llm_types.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace llmdnn { + +// from oneDNN +/// Data type specification +typedef enum { + /// Undefined data type, used for empty memory descriptors. + dnnl_data_type_undef = 0, + /// 16-bit/half-precision floating point. + dnnl_f16 = 1, + /// non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point. + dnnl_bf16 = 2, + /// 32-bit/single-precision floating point. + dnnl_f32 = 3, + /// 32-bit signed integer. + dnnl_s32 = 4, + /// 8-bit signed integer. + dnnl_s8 = 5, + /// 8-bit unsigned integer. + dnnl_u8 = 6, + /// 64-bit/double-precision floating point. + dnnl_f64 = 7, + + /// Parameter to allow internal only data_types without undefined behavior. + /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. + dnnl_data_type_max = 0x7fff, +} data_type_t; + +} \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..93c0e79 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,44 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) +project(cpu_extensions) + +file(GLOB_RECURSE ${PROJECT_NAME}_SOURCE_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) + +add_library(${PROJECT_NAME} STATIC ${${PROJECT_NAME}_SOURCE_FILES}) +set_target_properties(${PROJECT_NAME} PROPERTIES + POSITION_INDEPENDENT_CODE ON) +target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + PUBLIC $ + $/${CMAKE_INSTALL_INCLUDEDIR}>) + +set(CMAKE_DST lib/cmake/${PROJECT_NAME}) +# header files +include(GNUInstallDirs) +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../include/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +# library file +install(TARGETS ${PROJECT_NAME} + EXPORT ${PROJECT_NAME}Targets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include) + +# config file +install(EXPORT ${PROJECT_NAME}Targets + FILE ${PROJECT_NAME}Config.cmake + NAMESPACE ${PROJECT_NAME}:: + DESTINATION ${CMAKE_DST}) + +# version file +include(CMakePackageConfigHelpers) +write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake + VERSION 1.0.0 + COMPATIBILITY AnyNewerVersion) +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake + DESTINATION ${CMAKE_DST}) diff --git a/src/common/bf16.hpp b/src/common/bf16.hpp new file mode 100644 index 0000000..35f42cc --- /dev/null +++ b/src/common/bf16.hpp @@ -0,0 +1,249 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#define ROUND_MODE_TO_NEAREST_EVEN + +#define OPENVINO_API + +namespace ov { +class OPENVINO_API bfloat16 { +public: + constexpr bfloat16() : m_value{0} {} + bfloat16(float value) : m_value { +#if defined ROUND_MODE_TO_NEAREST + round_to_nearest(value) +#elif defined ROUND_MODE_TO_NEAREST_EVEN + round_to_nearest_even(value) +#elif defined ROUND_MODE_TRUNCATE + truncate(value) +#else +# error "ROUNDING_MODE must be one of ROUND_MODE_TO_NEAREST, ROUND_MODE_TO_NEAREST_EVEN, or ROUND_MODE_TRUNCATE" +#endif + } + {} + + template + explicit bfloat16(I value) : m_value{bfloat16{static_cast(value)}.m_value} {} + + std::string to_string() const; + size_t size() const; + template + bool operator==(const T& other) const; + template + bool operator!=(const T& other) const { + return !(*this == other); + } + template + bool operator<(const T& other) const; + template + bool operator<=(const T& other) const; + template + bool operator>(const T& other) const; + template + bool operator>=(const T& other) const; + template + bfloat16 operator+(const T& other) const; + template + bfloat16 operator+=(const T& other); + template + bfloat16 operator-(const T& other) const; + template + bfloat16 operator-=(const T& other); + template + bfloat16 operator*(const T& other) const; + template + bfloat16 operator*=(const T& other); + template + bfloat16 operator/(const T& other) const; + template + bfloat16 operator/=(const T& other); + operator float() const { + uint32_t tmp = 0; + uint32_t* ptmp = &tmp; + *ptmp = (static_cast(m_value) << 16); + const float* f = reinterpret_cast(ptmp); + return *f; + } + + static std::vector to_float_vector(const std::vector&); + static std::vector from_float_vector(const std::vector&); + static constexpr bfloat16 from_bits(uint16_t bits) { + return bfloat16(bits, true); + } + uint16_t to_bits() const; + friend std::ostream& operator<<(std::ostream& out, const bfloat16& obj) { + out << static_cast(obj); + return out; + } + +#define cu32(x) (F32(x).i) + + static uint16_t round_to_nearest_even(float x) { + return static_cast((cu32(x) + ((cu32(x) & 0x00010000) >> 1)) >> 16); + } + + static uint16_t round_to_nearest(float x) { + return static_cast((cu32(x) + 0x8000) >> 16); + } + + static uint16_t truncate(float x) { + return static_cast((cu32(x)) >> 16); + } + +private: + constexpr bfloat16(uint16_t x, bool) : m_value{x} {} + union F32 { + F32(float val) : f{val} {} + F32(uint32_t val) : i{val} {} + float f; + uint32_t i; + }; + + uint16_t m_value; +}; + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4756) +#endif +template +bool bfloat16::operator==(const T& other) const { +#if defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + return (static_cast(*this) == static_cast(other)); +#if defined(__GNUC__) +# pragma GCC diagnostic pop +#endif +} + +template +bool bfloat16::operator<(const T& other) const { + return (static_cast(*this) < static_cast(other)); +} + +template +bool bfloat16::operator<=(const T& other) const { + return (static_cast(*this) <= static_cast(other)); +} + +template +bool bfloat16::operator>(const T& other) const { + return (static_cast(*this) > static_cast(other)); +} + +template +bool bfloat16::operator>=(const T& other) const { + return (static_cast(*this) >= static_cast(other)); +} + +template +bfloat16 bfloat16::operator+(const T& other) const { + return {static_cast(*this) + static_cast(other)}; +} + +template +bfloat16 bfloat16::operator+=(const T& other) { + return *this = *this + other; +} + +template +bfloat16 bfloat16::operator-(const T& other) const { + return {static_cast(*this) - static_cast(other)}; +} + +template +bfloat16 bfloat16::operator-=(const T& other) { + return *this = *this - other; +} + +template +bfloat16 bfloat16::operator*(const T& other) const { + return {static_cast(*this) * static_cast(other)}; +} + +template +bfloat16 bfloat16::operator*=(const T& other) { + return *this = *this * other; +} + +template +bfloat16 bfloat16::operator/(const T& other) const { + return {static_cast(*this) / static_cast(other)}; +} + +template +bfloat16 bfloat16::operator/=(const T& other) { + return *this = *this / other; +} +#if defined(_MSC_VER) +# pragma warning(pop) +#endif +} // namespace ov + +namespace std { +template <> +class numeric_limits { +public: + static constexpr bool is_specialized = true; + static constexpr ov::bfloat16 min() noexcept { + return ov::bfloat16::from_bits(0x007F); + } + static constexpr ov::bfloat16 max() noexcept { + return ov::bfloat16::from_bits(0x7F7F); + } + static constexpr ov::bfloat16 lowest() noexcept { + return ov::bfloat16::from_bits(0xFF7F); + } + static constexpr int digits = 7; + static constexpr int digits10 = 2; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr int radix = 2; + static constexpr ov::bfloat16 epsilon() noexcept { + return ov::bfloat16::from_bits(0x3C00); + } + static constexpr ov::bfloat16 round_error() noexcept { + return ov::bfloat16::from_bits(0x3F00); + } + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = denorm_absent; + static constexpr bool has_denorm_loss = false; + static constexpr ov::bfloat16 infinity() noexcept { + return ov::bfloat16::from_bits(0x7F80); + } + static constexpr ov::bfloat16 quiet_NaN() noexcept { + return ov::bfloat16::from_bits(0x7FC0); + } + static constexpr ov::bfloat16 signaling_NaN() noexcept { + return ov::bfloat16::from_bits(0x7FC0); + } + static constexpr ov::bfloat16 denorm_min() noexcept { + return ov::bfloat16::from_bits(0); + } + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = false; + static constexpr bool is_modulo = false; + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; +} // namespace std diff --git a/src/common/simple_parallel.hpp b/src/common/simple_parallel.hpp new file mode 100644 index 0000000..b0548f4 --- /dev/null +++ b/src/common/simple_parallel.hpp @@ -0,0 +1,182 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace ov { +namespace cpu { + +size_t getTotalThreads(); +void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function& fn); + +// copy from openvino/core/parallel.hpp +template +inline T parallel_it_init(T start) { + return start; +} +template +inline T parallel_it_init(T start, Q& x, const R& X, Args&&... tuple) { + start = parallel_it_init(start, static_cast(tuple)...); + x = start % X; + return start / X; +} + +inline bool parallel_it_step() { + return true; +} +template +inline bool parallel_it_step(Q& x, const R& X, Args&&... tuple) { + if (parallel_it_step(static_cast(tuple)...)) { + if (++x - X == 0) { + x = 0; + return true; + } + } + return false; +} + +template +inline void splitter(const T& n, const Q& team, const Q& tid, T& n_start, T& n_end) { + if (team <= 1 || n == 0) { + n_start = 0; + n_end = n; + } else { + T n1 = (n + (T)team - 1) / (T)team; + T n2 = n1 - 1; + T T1 = n - n2 * (T)team; + n_end = (T)tid < T1 ? n1 : n2; + n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; + } + + n_end += n_start; +} + +namespace helpers { +template +struct NumOfLambdaArgs : public NumOfLambdaArgs {}; + +template +struct NumOfLambdaArgs { + constexpr static int value = sizeof...(Args); +}; + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(g_id, iwork, arg...); +} + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(g_id, arg...); +} + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(arg...); +} +} // namespace helpers + +template +void for_1d(const int& ithr, const int& nthr, const T0& D0, const F& func) { + T0 d0{0}, end{0}; + splitter(D0, nthr, ithr, d0, end); + for (; d0 < end; ++d0) + helpers::call_with_args(func, ithr, d0, d0); +} + +template +void parallel_for(const T0& D0, const F& func) { + auto work_amount = static_cast(D0); + int nthr = static_cast(getTotalThreads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_1d(0, 1, D0, func); + } else { + TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + for_1d(static_cast(ithr), nthr, D0, func); + }); + } +} + +template +void for_2d(const int& ithr, const int& nthr, const T0& D0, const T1& D1, const F& func) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) + return; + size_t start{0}, end{0}; + splitter(work_amount, nthr, ithr, start, end); + + T0 d0{0}; + T1 d1{0}; + parallel_it_init(start, d0, D0, d1, D1); + for (size_t iwork = start; iwork < end; ++iwork) { + helpers::call_with_args(func, ithr, iwork, d0, d1); + parallel_it_step(d0, D0, d1, D1); + } +} + +template +void parallel_for2d(const T0& D0, const T1& D1, const F& func) { + auto work_amount = static_cast(D0 * D1); + int nthr = static_cast(getTotalThreads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_2d(0, 1, D0, D1, func); + } else { + TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + for_2d(static_cast(ithr), nthr, D0, D1, func); + }); + } +} + +template +void for_3d(const int& ithr, const int& nthr, const T0& D0, const T1& D1, const T2& D2, const F& func) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) + return; + size_t start{0}, end{0}; + splitter(work_amount, nthr, ithr, start, end); + + T0 d0{0}; + T1 d1{0}; + T2 d2{0}; + parallel_it_init(start, d0, D0, d1, D1, d2, D2); + for (size_t iwork = start; iwork < end; ++iwork) { + helpers::call_with_args(func, ithr, iwork, d0, d1, d2); + parallel_it_step(d0, D0, d1, D1, d2, D2); + } +} + +template +void parallel_for3d(const T0& D0, const T1& D1, const T2& D2, const F& func) { + auto work_amount = static_cast(D0 * D1 * D2); + int nthr = static_cast(getTotalThreads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_3d(0, 1, D0, D1, D2, func); + } else { + TrySimpleParallelFor(static_cast(nthr), [&](size_t ithr) { + for_3d(static_cast(ithr), nthr, D0, D1, D2, func); + }); + } +} + +}; // namespace cpu +}; // namespace ov \ No newline at end of file diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp new file mode 100644 index 0000000..ce83671 --- /dev/null +++ b/src/common/tensor2d.hpp @@ -0,0 +1,190 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#ifdef ENABLE_NUMA +#include "numa.h" +#endif +#include "bf16.hpp" + +#define rndup(x, n) (((x + n - 1)/n)*n) + +template +struct tensor2D { + int dims[2] = {0}; + T* data = nullptr; + int64_t capacity = 0; + int stride = 0; + bool force_compact = false; + bool own; + int padded_dim1 = 0; + + tensor2D() = default; + tensor2D(const tensor2D&) = delete; + ~tensor2D() { + if (own && data) ::free(data); + } + + operator bool() { + return dims[0] * dims[1] > 0; + } + + tensor2D(int d0, int d1, bool _force_compact = false) { + capacity = 0; + resize(d0, d1, _force_compact); + } + + tensor2D(int d0, int d1, T * ext, int _stride) { + capacity = 1; + data = ext; + own = false; + dims[0] = d0; + dims[1] = d1; + stride = _stride; + padded_dim1 = stride / sizeof(T); + } + + tensor2D Tr(bool _force_compact = false) { + tensor2D ret(dims[1], dims[0], _force_compact); + for(int c0=0; c0 < dims[0]; ++c0) { + for(int c1=0; c1 < dims[1]; ++c1) { + ret(c1, c0) = (*this)(c0, c1); + } + } + return ret; + } + tensor2D clone() { + tensor2D ret; + ret.resize(dims[0], dims[1], force_compact); + if (ret.stride == stride) { + memcpy(ret.data, data, dims[0] * stride); + }else{ + for(int i=0;i clone_with_padzero(int dim0, int dim1) { + tensor2D ret; + ret.resize(dim0, dim1, force_compact); + assert(dim0 >= dims[0] && dim1 >= dims[1]); + for(int i = 0; i < dims[0]; i++) { + memcpy(&ret(i, 0), &(*this)(i, 0), dims[1] * sizeof(T)); + memset(reinterpret_cast(&ret(i, 0) + dims[1]), 0, ret.stride - dims[1] * sizeof(T)); + } + if (dims[1] == dim1) { + memset(reinterpret_cast(ret.data + dims[0] * ret.padded_dim1), 0, (dim0 - dims[0]) * ret.stride); + } + + return ret; + } + void resize(int d0, int d1, bool _force_compact = false, bool is_const=false) { + own = true; + force_compact = _force_compact; + dims[0] = d0; + dims[1] = d1; + stride = d1 * sizeof(T); + if ((stride % 64) && (!force_compact)) { + auto stride_fix = rndup(stride, 64); + stride = stride_fix; + } + padded_dim1 = stride / sizeof(T); + + // resize method never shrink capacity, and extra T is added to put nan as test + auto need_capacity = dims[0] * stride + 4096; + if (capacity < need_capacity) { + if (!is_const) + need_capacity *= 2; + capacity = need_capacity; + // align begin address to cache line is vital, so tile load can + // use all bandwidth (L1D/L2 only deliver data in unit of 64-byte aligned cache-line) + +#ifdef ENABLE_NUMA + if (USE_NUMA) { + data = std::shared_ptr( + reinterpret_cast(numa_alloc_local(capacity)), + [need_capacity](void * p){ numa_free(p, need_capacity); }); + } else { +#else + { +#endif + if (data) ::free(data); + data = reinterpret_cast(aligned_alloc(64, capacity)); + } + if (is_const) + memset(static_cast(data), 0, need_capacity); + if (reinterpret_cast(data) % 64) + std::cout << "WARNING: resize(), data is not cache-line aligned!" << std::endl; + } + // put a NaN at the end to test over-read + // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format + // #define INF 0xff80 + // #define NAN1 (INF + 1) + // if (sizeof(T) == 2) { + // *reinterpret_cast(data.get() + dims[0] * padded_dim1) = NAN1; + // } + } + + T & operator[](int i) { + return data[i]; + } + + const T & operator[](int i) const { + return data[i]; + } + + //https://stackoverflow.com/questions/1936399/c-array-operator-with-multiple-arguments + T & operator()(int i0, int i1) { + return (*this)[i0 * padded_dim1 + i1]; + } + + const T & operator()(int i0, int i1) const { + return (*this)[i0 * padded_dim1 + i1]; + } + + + void operator=(const T & v) { + for(int k = 0; k < dims[0] * padded_dim1; k++) + (*this)[k] = v; + } + + tensor2D& operator=(const tensor2D & t2) = delete; + + // move semantics + tensor2D(tensor2D && t2) { + dims[0] = t2.dims[0]; + dims[1] = t2.dims[1]; + if (own && data) ::free(data); + data = t2.data; + own = t2.own; + capacity = t2.capacity; + stride = t2.stride; + padded_dim1 = t2.padded_dim1; + force_compact = t2.force_compact; + t2.capacity = 0; + t2.data = nullptr; + } + + tensor2D& operator=(tensor2D && t2) { + dims[0] = t2.dims[0]; + dims[1] = t2.dims[1]; + if (own && data) ::free(data); + own = t2.own; + data = t2.data; + capacity = t2.capacity; + stride = t2.stride; + padded_dim1 = t2.padded_dim1; + force_compact = t2.force_compact; + t2.capacity = 0; + t2.data = nullptr; + return *this; + } +}; diff --git a/src/common/tensor2d_helper.hpp b/src/common/tensor2d_helper.hpp new file mode 100644 index 0000000..417b2a6 --- /dev/null +++ b/src/common/tensor2d_helper.hpp @@ -0,0 +1,200 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" +#include "tensor2d.hpp" +#include "bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#endif +#include + +// https://stackoverflow.com/questions/570669/checking-if-a-double-or-float-is-nan-in-c/57770634#57770634 +static inline uint32_t load_ieee754_rep(float a) { + uint32_t r; + static_assert(sizeof r == sizeof a, "Unexpected sizes."); + std::memcpy(&r, &a, sizeof a); // Generates movd instruction. + return r; +} +constexpr uint32_t inf_float_shl1 = UINT32_C(0xff000000); +// The shift left removes the sign bit. The exponent moves into the topmost bits, +// so that plain unsigned comparison is enough. +static inline bool isnan2(float a) { return load_ieee754_rep(a) << 1 > inf_float_shl1; } +static inline bool isinf2(float a) { return load_ieee754_rep(a) << 1 == inf_float_shl1; } +static inline bool isfinite2(float a) { return load_ieee754_rep(a) << 1 < inf_float_shl1; } + +template +void fill_rnd(tensor2D& t) { + auto * p = t.data; + int i = 0; + int total = t.dims[0] * t.padded_dim1; + // +1 -1 for integer types + // 0.5 -0.5 for float point + float scale = std::is_integral::value ? 2:1; + for(i = 0; i + 8 <= total; i+=8) { + // lower mantissa can help to avoid small errors in accuracy comparison + auto num = rand() & 0xFF; + p[i] = scale*((num & 1) - 0.5f); num>>=1; + p[i+1] = scale*((num & 1) - 0.5f); num>>=1; + p[i+2] = scale*((num & 1) - 0.5f); num>>=1; + p[i+3] = scale*((num & 1) - 0.5f); num>>=1; + p[i+4] = scale*((num & 1) - 0.5f); num>>=1; + p[i+5] = scale*((num & 1) - 0.5f); num>>=1; + p[i+6] = scale*((num & 1) - 0.5f); num>>=1; + p[i+7] = scale*((num & 1) - 0.5f); num>>=1; + } + for(; i +bool operator==(const tensor2D& lhs, const tensor2D& rhs) { + if (lhs.dims[0] != rhs.dims[0] || lhs.dims[1] != rhs.dims[1]) + return false; + for(int i0 = 0; i0 < lhs.dims[0]; i0++) + for(int i1 = 0; i1 < lhs.dims[1]; i1++) { + // with -ffast-math, std::isnan, std::isinf, x != x always return false + // so we need special logic to test nan here + if (std::is_same::value || + std::is_same::value) { + float f0 = lhs(i0,i1); + float f1 = rhs(i0,i1); + if (isnan2(f1) || isnan2(f0)) { + std::cout << " nan is found: f0=" << f0 << ", f1=" << f1 << std::endl; + return false; + } + if (std::abs(f0 - f1) <= 0.01) + continue; + } + + if (lhs(i0,i1) == rhs(i0,i1)) + continue; + std::cout << " operator== failed at (" << i0 << ", " << i1 << ") value " + << lhs(i0,i1) << "!=" << rhs(i0,i1) << std::endl; + return false; + } + return true; +} + +template +bool is_normal(const tensor2D& t) { + for (int i0 = 0; i0 < t.dims[0]; i0++) + for (int i1 = 0; i1 < t.dims[1]; i1++) { + float f0 = t(i0,i1); + if (isnan2(f0)) { + std::cout << " found nan at (" << i0 << "," << i1 << ")" << std::endl; + return false; + } + if (isinf2(f0)) { + std::cout << " found inf at (" << i0 << "," << i1 << ")" << std::endl; + return false; + } + } + return true; +} + +template +bool compare(const tensor2D& lhs, const tensor2D& rhs, float tolerance) { + float max_abs_diff = 0; + float max_rel_diff = 0; + if (lhs.dims[0] != rhs.dims[0] || lhs.dims[1] != rhs.dims[1]) + return false; + for (int i0 = 0; i0 < lhs.dims[0]; i0++) + for (int i1 = 0; i1 < lhs.dims[1]; i1++) { + float f0 = lhs(i0, i1); + float f1 = rhs(i0, i1); + auto diff = std::fabs(f0 - f1); + auto rel_diff = diff / std::fabs(f0); + max_abs_diff = std::max(max_abs_diff, diff); + if (std::fabs(lhs(i0,i1) > 0) && diff > 0) + max_rel_diff = std::max(max_rel_diff, rel_diff); + } + std::cout << "max_abs_diff=" << max_abs_diff << " max_rel_diff=" << max_rel_diff << "\n"; + return tolerance > max_abs_diff; +} + +template +std::ostream& operator<<(std::ostream& out, const tensor2D& obj) { + int i0; + auto showline = [&](int i) { + out << "[" << i << "," << 0 << "]: "; + int i1; + for(i1=0; i1 +inline void show(const T * data, int rows, int cols) { + std::ostream& out = std::cout; + out << "==============\n"; + for(int i0=0; i0 < rows; i0++) { + out << "[" << i0 << "," << 0 << "]: "; + for(int i1=0; i1 +inline void vshow(__m512i v) { + T values[512/8/sizeof(T)]; + _mm512_storeu_si512(values, v); + show(values, 1, 512/8/sizeof(T)); +} + +template +inline void vshow(__m512 v) { + T values[512/8/sizeof(T)]; + _mm512_storeu_ps(values, v); + show(values, 1, 512/8/sizeof(T)); +} + +template +inline void tshow() { + if (std::is_same::value) { + ov::bfloat16 data[16*32]; + _tile_stored(tile, data, 64); + show(data, 16, 32); + } + if (std::is_same::value) { + float data[16*16]; + _tile_stored(tile, data, 64); + show(data, 16, 16); + } + if (std::is_same::value) { + int8_t data[16*64]; + _tile_stored(tile, data, 64); + show(data, 16, 64); + } + if (std::is_same::value) { + uint8_t data[16*64]; + _tile_stored(tile, data, 64); + show(data, 16, 64); + } +} diff --git a/src/common/utility.hpp b/src/common/utility.hpp new file mode 100644 index 0000000..4d81a19 --- /dev/null +++ b/src/common/utility.hpp @@ -0,0 +1,66 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" + +#ifndef OV_DECL_ALIGNED +# ifdef __GNUC__ +# define OV_DECL_ALIGNED(x) __attribute__ ((aligned (x))) +# elif defined _MSC_VER +# define OV_DECL_ALIGNED(x) __declspec(align(x)) +# else +# define OV_DECL_ALIGNED(x) +# endif +#endif // OV_DECL_ALIGNED + +namespace llmdnn { + +inline size_t get_precision_size(data_type_t type) { + switch(type) { + case dnnl_f16: + case dnnl_bf16: + return 2; + case dnnl_f32: + case dnnl_s32: + return 4; + case dnnl_s8: + case dnnl_u8: + return 1; + case dnnl_f64: + return 8; + default: + assert(false && "unknown data type"); + return 0; + } +} + +inline data_type_t get_dt_from_str(const std::string& name) { + static std::pair name2type[] = { + { "f16", dnnl_f16 }, + { "bf16", dnnl_bf16 }, + { "f32", dnnl_f32 }, + { "s32", dnnl_s32 }, + { "i32", dnnl_s32 }, + { "s8", dnnl_s8 }, + { "i8", dnnl_s8 }, + { "u8", dnnl_u8 }, + { "f64", dnnl_f64 }, + }; + for (size_t i = 0; i < sizeof(name2type) / sizeof(name2type[0]); i++) { + if (name == name2type[i].first) + return name2type[i].second; + } + + return dnnl_data_type_undef; +} + +} \ No newline at end of file diff --git a/src/emb_gpt_api.cpp b/src/emb_gpt_api.cpp new file mode 100644 index 0000000..b8a7b84 --- /dev/null +++ b/src/emb_gpt_api.cpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "emb_gpt_avx512.hpp" + +namespace llmdnn { + +// interface +emb_gpt::emb_gpt(): _impl(new_impl_avx512()) { +} + +bool emb_gpt::create(const create_param& param) { + return _impl->create(param); +} + +void emb_gpt::exec(const exec_param& param) { + _impl->exec(param); +} + +} \ No newline at end of file diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp new file mode 100644 index 0000000..b290da2 --- /dev/null +++ b/src/emb_gpt_avx512.cpp @@ -0,0 +1,224 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include "common/simple_parallel.hpp" +#include "common/utility.hpp" +#include "utility_kernel_avx512.hpp" +#include "transpose_kernel_avx512.hpp" +#include "llm_emb_gpt.hpp" +#include "emb_gpt_avx512.hpp" +#include "rotary_kernel_avx512.hpp" + +using namespace ov::cpu; + +namespace llmdnn { + +struct emb_gpt_impl_avx512 : public emb_gpt::impl { + bool create(const emb_gpt::create_param& param) override; + void exec(const emb_gpt::exec_param& param) override; + + void memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, + size_t batch, size_t past_seq_len, size_t head_stride_in_kv); + void applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv, float* cos, float* sin); + void applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv, float* cos, float* sin); + + emb_gpt::create_param _create_param; + size_t _head_num = 32; + size_t _size_per_head = 80; + size_t _hidden_size = 32 * 80; + size_t _rotary_dim = 20; + // aligned to cache line + size_t _size_per_head_aligned = 80; + int64_t _input_type_size = 1; + int64_t _output_type_size = 1; + bool _use_position2d = false; +}; + +bool emb_gpt_impl_avx512::create(const emb_gpt::create_param& param) { + if (param.qkv_precision != dnnl_bf16) { + std::cout << "input precision must be bf16 or int8.\n"; + return false; + } + // TODO: support s8 + // if (param.dst_precision != dnnl_bf16 && param.dst_precision != dnnl_s8) { + // std::cout << "dst precision must be bf16 or int8.\n"; + // return false; + // } + _create_param = param; + + _head_num = param.num_heads; + _size_per_head = param.head_size; + _size_per_head_aligned = param.head_size_aligned; + _hidden_size = param.head_size * param.num_heads; + _rotary_dim = param.rotary_dims; + _input_type_size = sizeof(ov::bfloat16); + _output_type_size = sizeof(ov::bfloat16); + if (param.dst_precision == dnnl_s8) + _output_type_size = sizeof(int8_t); + + _use_position2d = param.use_position2d; + + return true; +} + +void emb_gpt_impl_avx512::memcpyPastKV(uint8_t** pastk_src, uint8_t** pastv_src, uint8_t** pastk_dst, uint8_t** pastv_dst, + size_t batch, size_t past_seq_len, size_t head_stride_in_kv) { + parallel_for3d(batch, _head_num, past_seq_len, [&](size_t b, size_t h, size_t s) { + auto k_dst_batch = pastk_dst[b]; + auto v_dst_batch = pastv_dst[b]; + auto k_src_batch = pastk_src[b]; + auto v_src_batch = pastv_src[b]; + auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto k_src_seq = k_src_batch + s * _size_per_head_aligned * _output_type_size; + auto v_src_seq = v_src_batch + s * _size_per_head_aligned * _output_type_size; + auto* k_src_f = k_src_seq + h * past_seq_len * _size_per_head_aligned * _output_type_size; + auto* k_dst_f = k_dst_seq + h * head_stride_in_kv * _output_type_size; + auto* v_src_f = v_src_seq + h * past_seq_len * _size_per_head_aligned * _output_type_size; + auto* v_dst_f = v_dst_seq + h * head_stride_in_kv * _output_type_size; + + memcpy(k_dst_f, k_src_f, _output_type_size * _size_per_head); + memcpy(v_dst_f, v_src_f, _output_type_size * _size_per_head); + }); +} + +// q_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] +// q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] +// kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] +// kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] +void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpy(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + size_t batch, size_t q_seq_len, size_t past_seq_len, size_t head_stride_in_kv, float* cos, float* sin) { + auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; + auto* cos_cached = cos + past_seq_len * _rotary_dim; + auto* sin_cached = sin + past_seq_len * _rotary_dim; + parallel_for3d(batch, _head_num, q_seq_len, [&](size_t b, size_t h, size_t s) { + // q, k rotary encoding + auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; + auto k_dst_batch = k_dst[b] + key_offset; + auto v_dst_batch = v_dst[b] + key_offset; + auto q_src_batch = q_src + b * _head_num * ldq * q_seq_len * _input_type_size; + auto k_src_batch = k_src + b * _head_num * ldk * q_seq_len * _input_type_size; + auto v_src_batch = v_src + b * _head_num * ldv * q_seq_len * _input_type_size; + auto q_dst_seq = q_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto q_src_seq = q_src_batch + s * _head_num * ldq * _input_type_size; + auto k_src_seq = k_src_batch + s * _head_num * ldk * _input_type_size; + auto v_src_seq = v_src_batch + s * _head_num * ldv * _input_type_size; + auto* q_src_f = reinterpret_cast(q_src_seq + h * ldq * _input_type_size); + auto* k_src_f = reinterpret_cast(k_src_seq + h * ldk * _input_type_size); + auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); + auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); + rotary_avx512(_rotary_dim, cos_cached + s * _rotary_dim, sin_cached + s * _rotary_dim, q_src_f, k_src_f, q_dst_f, k_dst_f); + + // q, k concat + memcpy(reinterpret_cast(q_dst_f) + _rotary_dim * _output_type_size, reinterpret_cast(q_src_f) + _rotary_dim * _input_type_size, _output_type_size * (_size_per_head - _rotary_dim)); + memcpy(reinterpret_cast(k_dst_f) + _rotary_dim * _output_type_size, reinterpret_cast(k_src_f) + _rotary_dim * _input_type_size, _output_type_size * (_size_per_head - _rotary_dim)); + // v concat + memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, + static_cast(v_src_seq) + h * ldv * _input_type_size, + _size_per_head * _output_type_size); + }); +} + +// q_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] +// q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] +// kv_src shape: [batch, q_seq_len, num_attention_heads, 3 * head_size] +// kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] +// position2d_ids: [batch, 2, q_seq_len] +void emb_gpt_impl_avx512::applyRotaryPosEmbMemcpyWithPosition2d(uint8_t* q_src, uint8_t* k_src, uint8_t* v_src, size_t ldq, size_t ldk, size_t ldv, uint8_t* q_dst, uint8_t** k_dst, uint8_t** v_dst, + size_t batch, size_t q_seq_len, size_t past_seq_len, int* position2d_ids, size_t head_stride_in_kv, float* cos, float* sin) { + auto key_offset = _output_type_size * past_seq_len * _size_per_head_aligned; + auto* cos_cached = cos; + auto* sin_cached = sin; + parallel_for3d(batch, _head_num, q_seq_len, [&](size_t b, size_t h, size_t s) { + // q, k rotary encoding + auto q_dst_batch = q_dst + b * _head_num * q_seq_len * _size_per_head_aligned * _output_type_size; + auto k_dst_batch = k_dst[b] + key_offset; + auto v_dst_batch = v_dst[b] + key_offset; + auto pos_batch = position2d_ids + b * 2 * q_seq_len; + auto block_batch = pos_batch + q_seq_len; + auto q_src_batch = q_src + b * _head_num * ldq * q_seq_len * _input_type_size; + auto k_src_batch = k_src + b * _head_num * ldk * q_seq_len * _input_type_size; + auto v_src_batch = v_src + b * _head_num * ldv * q_seq_len * _input_type_size; + auto q_dst_seq = q_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto k_dst_seq = k_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto v_dst_seq = v_dst_batch + s * _size_per_head_aligned * _output_type_size; + auto q_src_seq = q_src_batch + s * _head_num * ldq * _input_type_size; + auto k_src_seq = k_src_batch + s * _head_num * ldk * _input_type_size; + auto v_src_seq = v_src_batch + s * _head_num * ldv * _input_type_size; + auto* q_src_f = reinterpret_cast(q_src_seq + h * ldq * _input_type_size); + auto* k_src_f = reinterpret_cast(k_src_seq + h * ldk * _input_type_size); + auto* q_dst_f = reinterpret_cast(q_dst_seq + h * q_seq_len * _size_per_head_aligned * _output_type_size); + auto* k_dst_f = reinterpret_cast(k_dst_seq + h * head_stride_in_kv * _output_type_size); + rotary_avx512(_rotary_dim, cos_cached + pos_batch[s] * _rotary_dim, sin_cached + pos_batch[s] * _rotary_dim, q_src_f, k_src_f, q_dst_f, k_dst_f); + rotary_avx512(_rotary_dim, cos_cached + block_batch[s] * _rotary_dim, sin_cached + block_batch[s] * _rotary_dim, + q_src_f + _rotary_dim, + k_src_f + _rotary_dim, + q_dst_f + _rotary_dim, + k_dst_f + _rotary_dim); + + // v concat + memcpy(static_cast(v_dst_seq) + h * head_stride_in_kv * _output_type_size, + static_cast(v_src_seq) + h * ldv* _input_type_size, + _size_per_head * _output_type_size); + }); +} + +void emb_gpt_impl_avx512::exec(const emb_gpt::exec_param& param) { + // [batch, seq_len, (num_heads * 3 * head_size)] + // --> [batch, seq_len, num_heads, 3 * head_size] + auto query = param.q; + auto key = param.k; + auto value = param.v; + auto query_dst = param.query_dst; + auto key_dst = param.layer_past_key_dst; + auto value_dst = param.layer_past_value_dst; + auto batch = param.batch; + auto query_seq_len = param.query_seq_len; + auto past_seq_len = param.past_seq_len; + auto head_stride_in_kv = param.head_stride_in_kv; + + // past kv src != dst, copy src to dst first + if (param.layer_past_key_src && param.layer_past_key_src[0] != param.layer_past_key_dst[0] && past_seq_len) + memcpyPastKV(param.layer_past_key_src, param.layer_past_value_src, param.layer_past_key_dst, param.layer_past_value_dst, batch, past_seq_len, head_stride_in_kv); + + // transpose + rotary embbeding: + // transpose: [batch, seq_len, num_attention_heads, 3 * head_size] --> + // 3 [batch, num_attention_heads, seq_len, head_size] + // rotary embbeding: part of key will write to past_key, part of query will write to tempory buffer + if (_create_param.dst_precision == dnnl_s8) { + // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) + // key pass part(past_key): key = torch.cat((key, key_pass), dim=-1) + // value(pastKeys): value = torch.cat((past_value, value), dim=-2) + // applyRotaryPosEmbMemcpyQuant(query, key, queryTranspose.get(), current_k_bufs, _output_type_size * new_seq_offset * _size_per_head_aligned, + // _cos_cached.get(), _sin_cached.get(), batch, seq_len, new_seq_offset, value, current_v_bufs); + assert(false); + } else { + // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) + // key pass part(past_key): key = torch.cat((key, key_pass), dim=-1) + // value(pastKeys): value = torch.cat((past_value, value), dim=-2) + // q_dst shape: [batch, num_attention_heads, q_seq_len, head_size_aligned] + // kv_dst shape: [batch, num_attention_heads, q_seq_len+past_seq_len, head_size_aligned] + if (_use_position2d) { + applyRotaryPosEmbMemcpyWithPosition2d(query, key, value, param.ldq, param.ldk, param.ldv, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, + param.position2d_ids, head_stride_in_kv, param.cos, param.sin); + } else { + applyRotaryPosEmbMemcpy(query, key, value, param.ldq, param.ldk, param.ldv, query_dst, key_dst, value_dst, batch, query_seq_len, past_seq_len, head_stride_in_kv, + param.cos, param.sin); + } + } +} + +std::shared_ptr new_impl_avx512() { + return std::make_shared(); +} + +} \ No newline at end of file diff --git a/src/emb_gpt_avx512.hpp b/src/emb_gpt_avx512.hpp new file mode 100644 index 0000000..62e3bea --- /dev/null +++ b/src/emb_gpt_avx512.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_emb_gpt.hpp" + +namespace llmdnn { + +std::shared_ptr new_impl_avx512(); + +} diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp new file mode 100644 index 0000000..9b91e89 --- /dev/null +++ b/src/fc_kernel_amx.cpp @@ -0,0 +1,358 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_fc.hpp" +#include "mm_kernel_common_amx.hpp" +#include "utility_kernel_avx512.hpp" +#include "fc_kernel_amx.hpp" + +namespace llmdnn { + +using ov::bfloat16; +struct fc_kernel { + std::shared_ptr> bf16xbf16; + std::shared_ptr> bf16xi8; + std::shared_ptr> i8xi8; + std::shared_ptr> u8xi8; + + data_type_t dt_a; + data_type_t dt_b; + data_type_t dt_c; + postops_types postops_type; + bool b_is_transpose; +}; + +using supported_key = std::tuple; +using supported_value = std::pair; +static std::map supported_postops = { + { { dnnl_s8, dnnl_s8, dnnl_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_s8, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_s8, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_bf16, dnnl_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_bf16, dnnl_f32 }, { 0, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_s8, dnnl_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { dnnl_bf16, dnnl_s8, dnnl_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, +}; + +static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b, data_type_t dt_c) { + auto it = supported_postops.find(std::make_tuple(dt_a, dt_b, dt_c)); + if (it == supported_postops.end()) { + return false; + } + + size_t must_have; + size_t opt_have; + must_have = (*it).second.first; + opt_have = (*it).second.second; + + if ((value & must_have) != must_have) + return false; + // value must in must_have and opt_have + if ((value & ~(must_have | opt_have)) != 0) + return false; + + return true; +} + +// interface +bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { + fc_kernel* m = nullptr; + if (param == nullptr || mm == nullptr) { + std::cout << "fc_kernel_create: invalid input parameter.\n"; + goto ERR; + } + + if (!check_valid_postops(static_cast(param->postops_type), param->dt_a, param->dt_b, param->dt_c)) { + std::cout << "fc_kernel_create: unsupported data type, a: " << param->dt_a <<", b: " << param->dt_b << ", c: " << param->dt_c << + ", postops type: " << param->postops_type << ".\n"; + goto ERR; + } + + m = new fc_kernel; + if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + m->i8xi8 = std::make_shared>(true, param->b_is_trans); + } else if (param->dt_a == dnnl_u8 && param->dt_b == dnnl_s8) { + m->u8xi8 = std::make_shared>(true, param->b_is_trans); + } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + m->bf16xbf16 = std::make_shared>(true, param->b_is_trans); + } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_s8) { + m->bf16xi8 = std::make_shared>(true, param->b_is_trans); + m->bf16xi8->quant_scale_B = param->q; + m->bf16xi8->dequant_scale_B = param->dq; + } else { + std::cout << "fc_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + + m->dt_a = param->dt_a; + m->dt_b = param->dt_b; + m->dt_c = param->dt_c; + m->b_is_transpose = param->b_is_trans; + m->postops_type = param->postops_type; + + *mm = m; + return true; +ERR: + delete m; + return false; +} + +void fc_kernel_destroy_amx(const fc_kernel* mm) { + if (mm) { + delete mm; + } +} + +void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + + if (mm->dt_c == dnnl_s8) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == dnnl_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!bias) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == dnnl_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!bias) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } + } else if (mm->u8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->u8xi8)(a, b, n_start, n_end, pp); + } else if (mm->bf16xbf16) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + + if (mm->dt_c == dnnl_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == dnnl_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } + } + } else { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(N, K, reinterpret_cast(ptr_b), ldb); + + if (mm->dt_c == dnnl_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == dnnl_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } + } + } +} + +void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { + float min, max; + tensor2D B(K, N, reinterpret_cast(ptr), stride); + amx_kernel::functional::get_min_max(B, min, max); + max = std::max(std::abs(max), std::abs(min)); + *q = 127 / max; + *dq = max / 127; +} + +} \ No newline at end of file diff --git a/src/fc_kernel_amx.hpp b/src/fc_kernel_amx.hpp new file mode 100644 index 0000000..93bf47f --- /dev/null +++ b/src/fc_kernel_amx.hpp @@ -0,0 +1,18 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "llm_fc.hpp" + +namespace llmdnn { + +bool fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param); + +void fc_kernel_destroy_amx(const fc_kernel* mm); + +void fc_kernel_execute_amx(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias); + +void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); + +} \ No newline at end of file diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp new file mode 100644 index 0000000..d7b3e9c --- /dev/null +++ b/src/fc_kernel_api.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_fc.hpp" +#include "fc_kernel_amx.hpp" +#include "mm_kernel_common_amx.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + +static decltype(&fc_kernel_create) fc_kernel_create_ptr = fc_kernel_create_amx; +static decltype(&fc_kernel_destroy) fc_kernel_destroy_ptr = fc_kernel_destroy_amx; +static decltype(&fc_kernel_execute) fc_kernel_execute_ptr = fc_kernel_execute_amx; +static decltype(&fc_kernel_bf16w8_get_q_dq) fc_kernel_bf16w8_get_q_dq_ptr = fc_kernel_bf16w8_get_q_dq_amx; + +// interface +bool fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { + return fc_kernel_create_ptr(mm, param); +} + +void fc_kernel_destroy(const fc_kernel* mm) { + fc_kernel_destroy_ptr(mm); +} + +void fc_kernel_execute(const fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { + fc_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, lda, ldb, ldc, M, N, K, n_start, n_end, dq, q, bias); +} + +void fc_kernel_bf16w8_get_q_dq(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { + fc_kernel_bf16w8_get_q_dq_ptr(K, N, stride, ptr, q, dq); +} + +} \ No newline at end of file diff --git a/src/gelu_kernel_avx512.hpp b/src/gelu_kernel_avx512.hpp new file mode 100644 index 0000000..c416ad5 --- /dev/null +++ b/src/gelu_kernel_avx512.hpp @@ -0,0 +1,386 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "common/utility.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + + // gelu_erf_minimax_approx_compute_vector_fwd in oneDNN + // x*0.5*(1+erf(x/sqrt(2))) = x*0.5*(1 + x*Polynomial(x^2)) + inline __m512 gelu_erf_minmax_approx_avx512(__m512 & x) { + auto x2 = _mm512_mul_ps(x, x); // x^2 + + auto x_positive = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(x), _mm512_set1_epi32(0x7FFFFFFF))); // clear sign mask + auto x_half = _mm512_mul_ps(x, _mm512_set1_ps(0.5f)); + + auto poly = _mm512_castsi512_ps(_mm512_set1_epi32(0x1f1c83fd)); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa3198977))); // poly * x^2 + xxx + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x268a7927))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa998c963))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x2c67ddb2))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xaf013b2c))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x315d4a4f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb3969b11))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x35a776e9))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb79b0914))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3970b255))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbb1b7399))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3ca3621f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbe082bc7))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3f4c4228))); + + // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2) + poly = _mm512_fmadd_ps(poly, x, _mm512_set1_ps(1.0f)); + // x*0.5*(1 + x*Polynomial(x^2)) + poly = _mm512_mul_ps(poly, x_half); + + // combine: + // zone_id + // 1 -inf; -saturation_lbound : 0.0f + // 2 -saturation_lbound; -linear_ubound : x*0.5*(1 + x*Polynomial(x^2)) + // 3 -linear_ubound, linear_ubound : x*0.5 + // 4 linear_ubound : saturation_lbound : x*0.5*(1 + x*Polynomial(x^2)) + // 5 saturation_lbound: +inf : x + constexpr int neg_saturation_lbound = 0xc0a00000; + constexpr int linear_ubound = 0x33800000; + constexpr int saturation_lbound = 0x40a00000; + + auto mask_x_not_zone1 = _mm512_cmpnlt_ps_mask(x, _mm512_castsi512_ps(_mm512_set1_epi32(neg_saturation_lbound))); + x = _mm512_maskz_mov_ps(mask_x_not_zone1, x); + + auto mask_x_in_zone5 = _mm512_cmpnle_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(saturation_lbound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone5, x); + + auto mask_x_in_zone3 = _mm512_cmple_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(linear_ubound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone3, x_half); + return poly; + } + + // gelu_tanh_compute_vector_fwd in oneDNN + inline __m512 gelu_tanh_avx512(__m512& x) { + // compute G(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x * x) + auto x2 = _mm512_mul_ps(x, x); + auto y = _mm512_fmadd_ps(x2, (__m512)_mm512_set1_epi32(0x3d372713), _mm512_set1_ps(1.0f)); + y = _mm512_mul_ps(y, x); + y = _mm512_mul_ps(y, (__m512)_mm512_set1_epi32(0x3f4c422a)); + + // compute tanh(G(x)) + // We split the positive domain in 33 intervals: + // a) [0; linear_ubound]: in this interval tanh(x) = x + // b) [linear_ubound; 0x1.8p-12]: This interval spans part of a + // half binade + // c) [0x1.8p-12; 0x1.0p-11], ..., [0x1.8p2; 0x1.0p3]: + // one interval for each half binade, there are 29 of those + // d) [0x1.0p3; saturation_ubound]: + // This interval spans part of a half binade + // e) [0x1.205966p3; saturation_ubound]: in this interval, tanh(x) = 1 + // For b-d, we need 31 polynomials and will do a table lookup for those. + // To simplify the logic, we will also put a) in the table. + + // The polynomials are of degree 6, so we need to gather 7 coefficients. + // - sse4.1: we do it the naive way using vextract/vinsert. + // Here we will extract the indices in gpr only once and + // reuse them as there are only 4 of them. + // - avx: we do the same as for sse4.1 but use half of the 64-bits + // registers to store the idx of second half of YMM and half for + // responding XMM. Halfway through the copy we exchange Xmm and + // higher half of Ymm and we get the expected result. + // - avx2: we use vpermps and blend for each coefficient. + // This needs an extra vmm to store the mask + // - avx512: because the table fits in 2 registers, we can use vpermi2d. + + // because tanh(x) = -tanh(-x), we extract sign to make x postive + // and reapply sign at the end + auto y_positive = _mm512_and_ps(y, (__m512)(_mm512_set1_epi32(0x7fffffff))); + + // We compute the indices for the table lookup + auto indices = _mm512_sub_epi32((__m512i)y_positive, _mm512_set1_epi32(0x39800000)); + + indices = _mm512_and_epi32(indices, _mm512_set1_epi32(0xffc00000)); + indices = _mm512_srli_epi32(indices, 22); + // we do the argument reduction + auto y_shift = _mm512_and_ps(y_positive, (__m512)_mm512_set1_epi32(0xffc00000)); + + y_shift = _mm512_sub_ps(y_positive, y_shift); + + static uint32_t OV_DECL_ALIGNED(64) tanh_pol_table[] = { + // coefficients of degree 0 + 0x00000000, + 0x39bfffff, + 0x39ffffff, + 0x3a3ffffe, + 0x3a7ffffb, + 0x3abffff7, + 0x3affffeb, + 0x3b3fffdc, + 0x3b7fffab, + 0x3bbfff70, + 0x3bfffeab, + 0x3c3ffdc0, + 0x3c7ffaab, + 0x3cbff701, + 0x3cffeaad, + 0x3d3fdc08, + 0x3d7faacd, + 0x3dbf7081, + 0x3dfeacc9, + 0x3e3dc7fd, + 0x3e7acbf5, + 0x3eb77a9f, + 0x3eec9a9f, + 0x3f22991f, + 0x3f42f7d6, + 0x3f67b7cc, + 0x3f76ca83, + 0x3f7ebbe9, + 0x3f7fd40c, + 0x3f7fff32, + 0x3f7ffffc, + 0x3f800000, + // coefficients of degree 1 + 0x3f800000, + 0x3f800018, + 0x3f7fffe8, + 0x3f7fffda, + 0x3f7fffdc, + 0x3f7fffdc, + 0x3f7fffac, + 0x3f7fff70, + 0x3f7ffeec, + 0x3f7ffdc0, + 0x3f7ffbed, + 0x3f7ff704, + 0x3f7feff5, + 0x3f7fdbca, + 0x3f7fbfff, + 0x3f7f7041, + 0x3f7f009b, + 0x3f7dc36c, + 0x3f7c0aa8, + 0x3f7734b8, + 0x3f70a4de, + 0x3f5f1fd8, + 0x3f495493, + 0x3f18b9ec, + 0x3ed706cb, + 0x3e390b06, + 0x3d90b11f, + 0x3c21a053, + 0x3aaf7fdb, + 0x37ccc1a3, + 0x355c6733, + 0x00000000, + // coefficients of degree 2 + 0x00000000, + 0xbe4e0ff1, + 0x3d25b1b1, + 0x3d6b6dab, + 0x3c9fb1d5, + 0xbabff06f, + 0x3c07b3f6, + 0xbb3fc1bc, + 0x3a9f5921, + 0xbbbf06f2, + 0xbbb0f402, + 0xbc47db9e, + 0xbc73d5e7, + 0xbca25bda, + 0xbcfca780, + 0xbd40e07c, + 0xbd7dab03, + 0xbdbe4a0f, + 0xbdfb14a5, + 0xbe36cc8d, + 0xbe6bd102, + 0xbe9fe7c5, + 0xbeba0f10, + 0xbec206a8, + 0xbea3c388, + 0xbe277d62, + 0xbd8b7960, + 0xbc209f49, + 0xbaad44ca, + 0xb7c6eeac, + 0xb663aa41, + 0x00000000, + // coefficients of degree 3 + 0x00000000, + 0x45b3ae96, + 0xc414eb20, + 0xc450e02e, + 0xc3152b4e, + 0xbead2f56, + 0xc2162e02, + 0xbeb4bd5a, + 0xc11a59a4, + 0xbed2f507, + 0xc020d32c, + 0x3dd0f506, + 0xbf2a75e2, + 0xbff950e3, + 0xbed47334, + 0xbe809b8c, + 0xbeb64532, + 0xbe961a5b, + 0xbe9b63ac, + 0xbea0d4b2, + 0xbe828a77, + 0xbe378612, + 0xbdc20908, + 0x3d2d3957, + 0x3dd46e89, + 0x3db3f629, + 0x3d2c5e7b, + 0x3bd20403, + 0x3a59dfae, + 0x3770af45, + 0x372cc014, + 0x00000000, + // coefficients of degree 4 + 0x00000000, + 0xcc981a1b, + 0x4a7edd3d, + 0x4ab1007c, + 0x48fedd9c, + 0x41a557b5, + 0x477ee32a, + 0x422557f5, + 0x45ff3ce4, + 0x42a55641, + 0x446e0867, + 0xc33dc19a, + 0x42915214, + 0x43af4fad, + 0x4110fe88, + 0xc1099b75, + 0x3fc8a8dc, + 0xbfbeaef5, + 0xbe365aad, + 0x3f4d9652, + 0x3ddfa08f, + 0x3e34e9b8, + 0x3e2d07a6, + 0x3dc63567, + 0x3cdaeb78, + 0xbcd17537, + 0xbc92829c, + 0xbb43ab99, + 0xb9b471dd, + 0xb6baad5a, + 0xb78bafc7, + 0x00000000, + // coefficients of degree 5 + 0x00000000, + 0x52f688d5, + 0xd0505c72, + 0xd08f98e3, + 0xce505cc9, + 0xc7162b8a, + 0xcc5061d6, + 0xc7162bdf, + 0xca50b37f, + 0xc7162a3a, + 0xc8422086, + 0x471a714e, + 0xc5ece1f1, + 0xc70e3d90, + 0xc3eba94a, + 0x43e0c424, + 0xc21f4552, + 0x42217cc8, + 0x405e7dc4, + 0xc10dd401, + 0x3e96b602, + 0xbd1a6d2f, + 0xbd393883, + 0xbd674682, + 0xbd310016, + 0xb961e269, + 0x3ba32495, + 0x3a7680d5, + 0x38b3173c, + 0x35a9deea, + 0x375c3f2a, + 0x00000000, + // coefficients of degree 6 + 0x00000000, + 0xd8995ed1, + 0x558285ea, + 0x55b2cd69, + 0x53028625, + 0x4bc9991f, + 0x5082898a, + 0x4b4999b3, + 0x4e02c07c, + 0x4ac99764, + 0x4b72c822, + 0xca40c0e1, + 0x489413e4, + 0x49b12224, + 0x46134c4e, + 0xc60c2d57, + 0x43c83910, + 0xc3c872d1, + 0xc186bc9e, + 0x42325bc3, + 0xbf2ffa4a, + 0x3d9a203c, + 0xbc545a43, + 0xbae08fee, + 0x3c80225d, + 0x3b1fd1df, + 0xba36b9d1, + 0xb91de544, + 0xb71f100f, + 0xb408e2ed, + 0xb685fec8, + 0x00000000, + }; + auto pol = _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 6), indices, _mm512_load_ps(tanh_pol_table + 32 * 6 + 16)); + + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 5), indices, _mm512_load_ps(tanh_pol_table + 32 * 5 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 4), indices, _mm512_load_ps(tanh_pol_table + 32 * 4 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 3), indices, _mm512_load_ps(tanh_pol_table + 32 * 3 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 2), indices, _mm512_load_ps(tanh_pol_table + 32 * 2 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 1), indices, _mm512_load_ps(tanh_pol_table + 32 * 1 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 0), indices, _mm512_load_ps(tanh_pol_table + 32 * 0 + 16))); + + // we restore src with cleared sign, and keep sign + auto sign = _mm512_and_ps(y, (__m512)_mm512_set1_epi32(0x80000000)); + + // Now we blend the results + // [saturation_ubound; +inf[ : we return +/- 1 + auto dst = (__m512)_mm512_set1_epi32(0x3f800000); + // [linear_ubound; saturation_lbound] : we return +/- P(x) + auto mask = (__m512)_mm512_set1_epi32(0x41102cb3); + auto mask16 = _mm512_cmp_ps_mask(mask, y_positive, _CMP_GT_OS); + dst = _mm512_mask_blend_ps(mask16, dst, pol); + // [0; linear_ubound] : we return x + mask = (__m512)_mm512_set1_epi32(0x39ddb3d7); + mask16 = _mm512_cmp_ps_mask(mask, y_positive, _CMP_GT_OS); + dst = _mm512_mask_blend_ps(mask16, dst, y_positive); + + // We reapply the sign and return + dst = _mm512_xor_ps(dst, sign); + + // compute 0.5 * x * (1 + tanh(G(x))) + dst = _mm512_add_ps(dst, (__m512)_mm512_set1_epi32(0x3f800000)); + dst = _mm512_mul_ps(dst, (__m512)_mm512_set1_epi32(0x3f000000)); + dst = _mm512_mul_ps(dst, x); + return dst; + } +} \ No newline at end of file diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp new file mode 100644 index 0000000..fd9971c --- /dev/null +++ b/src/mha_gpt_amx.cpp @@ -0,0 +1,394 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "common/simple_parallel.hpp" +#include "common/utility.hpp" +#include "utility_kernel_avx512.hpp" +#include "mm_kernel_common_amx.hpp" +#include "softmax_kernel_avx512.hpp" +#include "transpose_kernel_avx512.hpp" +#include "llm_mha_gpt.hpp" +#include "mha_gpt_amx.hpp" + +using namespace ov::cpu; + +namespace llmdnn { + +struct mha_gpt_impl_amx : public mha_gpt::impl { + bool create(const mha_gpt::create_param& param) override; + void exec(const mha_gpt::exec_param& param) override; + + mha_gpt::create_param _create_param; + + void mha_bf16(const mha_gpt::exec_param ¶m); + void mha_i8(const mha_gpt::exec_param ¶m); + + size_t bufferMatMul0OutSize; + size_t bufferMatMul1OutSize; + + std::shared_ptr bufferMatMul0Out; + std::shared_ptr bufferMatMul1Out; + std::shared_ptr qkvQuantBuf; + + std::vector>> gemAvB_BF16xBF16; + std::vector>> qKtrGemm_BF16xBF16; + std::vector>> qKVGemm_BF16xBF16; + + std::vector>> qKtrGemm_i8xi8; + std::vector>> qKVGemm_u8xi8; + std::vector>> gemAvB_i8xi8; +}; + +bool mha_gpt_impl_amx::create(const mha_gpt::create_param& param) { + if (param.qkv_precision != dnnl_bf16 && param.qkv_precision != dnnl_s8) { + std::cout << "input precision must be bf16 or int8.\n"; + return false; + } + if (param.dst_precision != dnnl_bf16 && param.dst_precision != dnnl_s8) { + std::cout << "dst precision must be bf16 or int8.\n"; + return false; + } + _create_param = param; + + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, maxSeqLen(valid: key_seq_len), head_size] + // v: [batch, num_heads, maxSeqLen(valid: value_seq_len), head_size] + // attention_mask: [batch, 1, 1, maxSeqLen(valid: key_seq_len)] + // matmul1: [batch, num_heads, query_seq_len, head_size] + // attn_output: [batch, query_seq_len, num_heads * head_size] + size_t numThreads = getTotalThreads(); + if (_create_param.qkv_precision == dnnl_s8) { + qKtrGemm_i8xi8.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + qKtrGemm_i8xi8[i] = std::make_shared>(false, true); + } + qKVGemm_u8xi8.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + qKVGemm_u8xi8[i] = std::make_shared>(false, false); + } + gemAvB_i8xi8.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + gemAvB_i8xi8[i] = std::make_shared>(); + } + qkvQuantBuf = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, param.head_size * sizeof(float))), + [](void * p) { ::free(p); }); + memset(qkvQuantBuf.get(), 0, sizeof(param.head_size * sizeof(float))); + } else { + gemAvB_BF16xBF16.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + gemAvB_BF16xBF16[i] = std::make_shared>(); + } + qKtrGemm_BF16xBF16.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + qKtrGemm_BF16xBF16[i] = std::make_shared>(false, true); + } + qKVGemm_BF16xBF16.resize(numThreads); + for (size_t i = 0; i < numThreads; i++) { + qKVGemm_BF16xBF16[i] = std::make_shared>(false, false); + } + } + + bufferMatMul0OutSize = _create_param.max_seq_len * rndup(_create_param.max_seq_len * sizeof(float), 64); + bufferMatMul1OutSize = _create_param.max_seq_len * _create_param.head_size_aligned * sizeof(float); + + bufferMatMul0Out = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul0OutSize)), + [](void * p) { ::free(p); }); + memset(bufferMatMul0Out.get(), 0, numThreads * bufferMatMul0OutSize); + bufferMatMul1Out = std::shared_ptr( + reinterpret_cast(aligned_alloc(64, numThreads * bufferMatMul1OutSize)), + [](void * p) { ::free(p); }); + memset(bufferMatMul1Out.get(), 0, numThreads * bufferMatMul1OutSize); + return true; +} + +void mha_gpt_impl_amx::mha_bf16(const mha_gpt::exec_param ¶m) { + uint8_t* pQIn0 = param.q; + auto& pKIn0 = param.k; + auto& attn_masks = param.attention_mask; + auto& pVIn0 = param.v; + uint8_t* pout = param.attn_output; + + auto outPrcSize = get_precision_size(_create_param.qkv_precision); + auto& gemAvB_ops = gemAvB_BF16xBF16; + auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; + auto& qKVGemm_ops = qKVGemm_BF16xBF16; + bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 32 && _create_param.head_size <= 32 * 6; + size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; + size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; + size_t head_stride_in_attn = _create_param.head_size; + size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; + size_t causal_mask_offset_start = param.key_seq_len - param.query_seq_len; + + if (is_vector) { + parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { + auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q) * get_precision_size(_create_param.qkv_precision); + auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; + + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + // N: key_seq_len, K: head_size + // q[1, K] * transpose(k[N, K]) ==> + // k[N, K] * transpose(q[1, K]) ==> + // k[N, K] * q[K, 1] + (*gemAvB_ops[threadNum])(matK, reinterpret_cast(pQIn0_aux), reinterpret_cast(bufferMatMul0Out_local)); + + float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); + mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, _create_param.normal_factor, pAddIn1_aux, param.key_seq_len); + softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, nullptr); + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; + tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp(matQKV); + (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); + memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); + }); + } else { + auto numThreads = getTotalThreads(); + int seq_cout_all = rndup(param.query_seq_len, 32) / 32; + int work_amount = param.batch * _create_param.num_heads * seq_cout_all; + parallel_for(numThreads, [&](int threadNum) { + int i0; + int i1; + int seq; + int start {0}, end {0}; + splitter(work_amount, static_cast(numThreads), threadNum, start, end); + if (start >= work_amount) return; + + parallel_it_init(start, i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + uint8_t* prev_k = nullptr; + uint8_t* prev_v = nullptr; + for (int iwork = start; iwork < end; ++iwork) { + int seq_start = seq * 32; + int seq_end = std::min(static_cast(seq_start) + 32, param.query_seq_len); + int seq_cout = seq_end - seq_start; + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, key_seq_len, head_size] + // v: [batch, num_heads, value_seq_len, head_size] + auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q + seq_start * _create_param.head_size_aligned) * get_precision_size(_create_param.qkv_precision); + auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + + tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); + amx_kernel::PP::BiasGeluStore pp(matQK); + (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, pKIn0_aux == prev_k); + prev_k = pKIn0_aux; + + auto pMatMul0Out = bufferMatMul0Out_local; + if (param.is_causal_in_attention) { + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len * param.query_seq_len; + // loop along K dimension + for (int m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + softmax_avx512(dst, src, param.key_seq_len, nullptr); + } + } else { + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; + // loop along K dimension + size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; + for (int m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + ov::bfloat16* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + mul_add_f32_avx512(src, src, _create_param.normal_factor, pAddIn1_aux, valid_softmax_items); + softmax_avx512(dst, src, valid_softmax_items, nullptr); + // attn_scores = torch.where(causal_mask, attn_scores, mask_value) + if (param.key_seq_len > valid_softmax_items) { + auto *invalidPtr = dst + valid_softmax_items; + memset(static_cast(invalidPtr), 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); + valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + } + } + } + + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; + tensor2D matQKBF16(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(ov::bfloat16), 64)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(ov::bfloat16)); + tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp2(matQKV); + (*qKVGemm_ops[threadNum])(matQKBF16, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); + prev_v = pVIn0_aux; + memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size * sizeof(ov::bfloat16), nullptr); + parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + } + }); + } +} + +void mha_gpt_impl_amx::mha_i8(const mha_gpt::exec_param ¶m) { + uint8_t* pQIn0 = param.q; + auto& pKIn0 = param.k; + auto& attn_masks = param.attention_mask; + auto& pVIn0 = param.v; + uint8_t* pout = param.attn_output; + + auto outPrcSize = get_precision_size(_create_param.dst_precision); + auto& gemAvB_ops = gemAvB_i8xi8; + auto& qKtrGemm_ops = qKtrGemm_i8xi8; + auto& qKVGemm_ops = qKVGemm_u8xi8; + bool is_vector = param.query_seq_len == 1 && _create_param.head_size >= 64 && _create_param.head_size <= 64 * 6; + // dequant param + auto mul_scales = _create_param.normal_factor * param.q_dequant * param.k_dequant; + // prepare for per channel + assert(param.qkv_quant.size() == 1 || param.qkv_quant.size() == _create_param.head_size); + for (size_t i = 0; i < param.qkv_quant.size(); i++) { + (qkvQuantBuf.get())[i] = param.qkv_quant[i] * param.v_dequant / param.qk_quant; + } + if (param.qkv_quant.size() == 1) { + std::fill(qkvQuantBuf.get() + 1, qkvQuantBuf.get() + _create_param.head_size, *qkvQuantBuf.get()); + } + size_t head_stride_in_q = _create_param.head_size_aligned * param.query_seq_len; + size_t batch_stride_in_q = head_stride_in_q * _create_param.num_heads; + size_t head_stride_in_attn = _create_param.head_size; + size_t batch_stride_in_attn = _create_param.head_size * _create_param.num_heads * param.query_seq_len; + size_t causal_mask_offset_start = param.key_seq_len - param.query_seq_len; + + if (is_vector) { + parallel_for2d(param.batch, _create_param.num_heads, [&](size_t threadNum, size_t i0, size_t i1) { + auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q) * get_precision_size(_create_param.qkv_precision); + auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv * get_precision_size(_create_param.qkv_precision); + + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; + + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + // N: key_seq_len, K: head_size + // q[1, K] * transpose(k[N, K]) ==> + // k[N, K] * transpose(q[1, K]) ==> + // k[N, K] * q[K, 1] + (*gemAvB_ops[threadNum])(matK, reinterpret_cast(pQIn0_aux), reinterpret_cast(bufferMatMul0Out_local)); + cvt_i32_f32_avx512(reinterpret_cast(bufferMatMul0Out_local), reinterpret_cast(bufferMatMul0Out_local), param.key_seq_len); + + float* pMatMul0Out = reinterpret_cast(bufferMatMul0Out_local); + mul_add_f32_avx512(pMatMul0Out, pMatMul0Out, mul_scales, pAddIn1_aux, param.key_seq_len); + softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, param.key_seq_len, param.qk_quant); + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * outPrcSize; + tensor2D matQK(param.query_seq_len, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matQKV(param.query_seq_len, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp(matQKV); + (*qKVGemm_ops[threadNum])(matQK, matV, 0, _create_param.head_size, pp); + memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), param.query_seq_len, + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); + }); + } else { + auto numThreads = getTotalThreads(); + int seq_cout_all = rndup(param.query_seq_len, 32) / 32; + int work_amount = param.batch * _create_param.num_heads * seq_cout_all; + parallel_for(numThreads, [&](int threadNum) { + int i0; + int i1; + int seq; + int start {0}, end {0}; + splitter(work_amount, static_cast(numThreads), threadNum, start, end); + if (start >= work_amount) return; + + parallel_it_init(start, i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + uint8_t* prev_k = nullptr; + uint8_t* prev_v = nullptr; + for (int iwork = start; iwork < end; ++iwork) { + int seq_start = seq * 32; + int seq_end = std::min(static_cast(seq_start) + 32, param.query_seq_len); + int seq_cout = seq_end - seq_start; + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, key_seq_len, head_size] + // v: [batch, num_heads, value_seq_len, head_size] + auto pQIn0_aux = pQIn0 + (i0 * batch_stride_in_q + i1 * head_stride_in_q + seq_start * _create_param.head_size_aligned); + auto pKIn0_aux = pKIn0[i0] + i1 * param.head_stride_in_kv; + auto pVIn0_aux = pVIn0[i0] + i1 * param.head_stride_in_kv; + + auto bufferMatMul0Out_local = reinterpret_cast(bufferMatMul0Out.get() + threadNum * bufferMatMul0OutSize); + auto bufferMatMul1Out_local = reinterpret_cast(bufferMatMul1Out.get() + threadNum * bufferMatMul1OutSize); + + tensor2D matQ(seq_cout, _create_param.head_size, reinterpret_cast(pQIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matK(param.key_seq_len, _create_param.head_size, reinterpret_cast(pKIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matQK(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(float), 64)); + amx_kernel::PP::BiasGeluStore pp(matQK); + (*qKtrGemm_ops[threadNum])(matQ, matK, 0, param.key_seq_len, pp, prev_k == pKIn0_aux); + prev_k = pKIn0_aux; + + auto pMatMul0Out = bufferMatMul0Out_local; + if (param.is_causal_in_attention) { + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len * param.query_seq_len; + // loop along K dimension + for (int m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); + mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux + (m + seq_start) * param.key_seq_len, param.key_seq_len); + softmax_avx512(dst, src, param.key_seq_len, param.qk_quant); + } + } else { + auto pAddIn1_aux = attn_masks + i0 * param.key_seq_len; + // loop along K dimension + size_t valid_softmax_items = causal_mask_offset_start + seq_start + 1; + for (int m = 0; m < seq_cout; m++) { + float* src = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(float), 64)); + uint8_t* dst = reinterpret_cast(pMatMul0Out + m * rndup(param.key_seq_len * sizeof(uint8_t), 64)); + mul_add_f32_avx512(src, src, mul_scales, pAddIn1_aux, valid_softmax_items); + softmax_avx512(dst, src, valid_softmax_items, param.qk_quant); + // attn_scores = torch.where(causal_mask, attn_scores, mask_value) + if (param.key_seq_len > valid_softmax_items) { + auto *invalidPtr = dst + valid_softmax_items; + memset(invalidPtr, 0, (param.key_seq_len - valid_softmax_items) * get_precision_size(_create_param.qkv_precision)); + valid_softmax_items = std::min(valid_softmax_items + 1, param.key_seq_len); + } + } + } + auto pOut_aux = pout + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + + seq_start * head_stride_in_attn * _create_param.num_heads) * outPrcSize; + tensor2D matQKI8(seq_cout, param.key_seq_len, reinterpret_cast(bufferMatMul0Out_local), rndup(param.key_seq_len * sizeof(uint8_t), 64)); + tensor2D matV(param.key_seq_len, _create_param.head_size, reinterpret_cast(pVIn0_aux), _create_param.head_size_aligned * sizeof(int8_t)); + tensor2D matQKV(seq_cout, _create_param.head_size, reinterpret_cast(bufferMatMul1Out_local), _create_param.head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp2(matQKV); + (*qKVGemm_ops[threadNum])(matQKI8, matV, 0, _create_param.head_size, pp2, prev_v == pVIn0_aux); + prev_v = pVIn0_aux; + // matmul1: [batch, num_heads, query_seq_len, head_size] + // attn_output: [batch, query_seq_len, num_heads * head_size] + memcpy2d_stride_avx512(reinterpret_cast(pOut_aux), reinterpret_cast(bufferMatMul1Out_local), seq_cout, + _create_param.head_size, _create_param.head_size_aligned * sizeof(float), _create_param.num_heads * _create_param.head_size, qkvQuantBuf.get()); + parallel_it_step(i0, param.batch, i1, _create_param.num_heads, seq, seq_cout_all); + } + }); + } +} + +void mha_gpt_impl_amx::exec(const mha_gpt::exec_param& param) { + if (_create_param.qkv_precision == dnnl_f32) { + assert(false); + } else if (_create_param.qkv_precision == dnnl_bf16) { + mha_bf16(param); + } else if (_create_param.qkv_precision == dnnl_s8) { + mha_i8(param); + } else { + assert(false && "doesn't support provided input precisions"); + } +} + +std::shared_ptr new_impl_amx() { + return std::make_shared(); +} + +} \ No newline at end of file diff --git a/src/mha_gpt_amx.hpp b/src/mha_gpt_amx.hpp new file mode 100644 index 0000000..9409af9 --- /dev/null +++ b/src/mha_gpt_amx.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_mha_gpt.hpp" + +namespace llmdnn { + +std::shared_ptr new_impl_amx(); + +} diff --git a/src/mha_gpt_api.cpp b/src/mha_gpt_api.cpp new file mode 100644 index 0000000..e702e8e --- /dev/null +++ b/src/mha_gpt_api.cpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "mha_gpt_amx.hpp" + +namespace llmdnn { + +// interface +mha_gpt::mha_gpt(): _impl(new_impl_amx()) { +} + +bool mha_gpt::create(const create_param& param) { + return _impl->create(param); +} + +void mha_gpt::exec(const exec_param& param) { + _impl->exec(param); +} + +} \ No newline at end of file diff --git a/src/mm_kernel_amx.cpp b/src/mm_kernel_amx.cpp new file mode 100644 index 0000000..d4adf8f --- /dev/null +++ b/src/mm_kernel_amx.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_mm.hpp" +#include "mm_kernel_common_amx.hpp" +#include "utility_kernel_avx512.hpp" +#include "mm_kernel_amx.hpp" + +namespace llmdnn { + +using ov::bfloat16; +struct mm_kernel { + std::shared_ptr> bf16xbf16; + std::shared_ptr> i8xi8; + std::shared_ptr> u8xi8; + + std::shared_ptr> i8xi8_gemv; + std::shared_ptr> bf16xbf16_gemv; + + data_type_t dt_a; + data_type_t dt_b; + bool b_is_transpose; +}; + +// interface +bool mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param) { + mm_kernel* m = nullptr; + if (param == nullptr || mm == nullptr) { + std::cout << "mm_kernel_create: invalid input parameter.\n"; + goto ERR; + } + + m = new mm_kernel; + if (param->b_is_gemv) { + if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + m->i8xi8_gemv = std::make_shared>(); + } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + m->bf16xbf16_gemv = std::make_shared>(); + } else { + std::cout << "mm_kernel_create: unsupport gemv input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + } else { + if (param->dt_a == dnnl_s8 && param->dt_b == dnnl_s8) { + m->i8xi8 = std::make_shared>(false, param->b_is_trans); + } else if (param->dt_a == dnnl_u8 && param->dt_b == dnnl_s8) { + m->u8xi8 = std::make_shared>(false, param->b_is_trans); + } else if (param->dt_a == dnnl_bf16 && param->dt_b == dnnl_bf16) { + m->bf16xbf16 = std::make_shared>(false, param->b_is_trans); + } else { + std::cout << "mm_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + } + m->dt_a = param->dt_a; + m->dt_b = param->dt_b; + m->b_is_transpose = param->b_is_trans; + + *mm = m; + return true; +ERR: + delete m; + return false; +} + +void mm_kernel_destroy_amx(const mm_kernel* mm) { + if (mm) { + delete mm; + } +} + +void mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K) { + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8_gemv) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + (*mm->i8xi8_gemv)(a, reinterpret_cast(ptr_b), reinterpret_cast(ptr_c)); + cvt_i32_f32_avx512(reinterpret_cast(ptr_c), reinterpret_cast(ptr_c), M); + } else if (mm->i8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->i8xi8)(a, b, 0, N, pp); + } else if (mm->u8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->u8xi8)(a, b, 0, N, pp); + } else if (mm->bf16xbf16_gemv) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + (*mm->bf16xbf16_gemv)(a, reinterpret_cast(ptr_b), reinterpret_cast(ptr_c)); + } else if (mm->bf16xbf16) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->bf16xbf16)(a, b, 0, N, pp); + } else { + std::cout << "mm_kernel_execute: no valid kernel created, call create first.\n"; + } +} + + +} \ No newline at end of file diff --git a/src/mm_kernel_amx.hpp b/src/mm_kernel_amx.hpp new file mode 100644 index 0000000..1d03818 --- /dev/null +++ b/src/mm_kernel_amx.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_mm.hpp" + +namespace llmdnn { + +bool mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param); + +void mm_kernel_destroy_amx(const mm_kernel* mm); + +void mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K); + +} \ No newline at end of file diff --git a/src/mm_kernel_api.cpp b/src/mm_kernel_api.cpp new file mode 100644 index 0000000..9b2ebff --- /dev/null +++ b/src/mm_kernel_api.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "llm_mm.hpp" +#include "mm_kernel_amx.hpp" + +namespace llmdnn { + +static decltype(&mm_kernel_create) mm_kernel_create_ptr = mm_kernel_create_amx; +static decltype(&mm_kernel_destroy) mm_kernel_destroy_ptr = mm_kernel_destroy_amx; +static decltype(&mm_kernel_execute) mm_kernel_execute_ptr = mm_kernel_execute_amx; + +// interface +bool mm_kernel_create(mm_kernel** mm, const mm_create_param* param) { + return mm_kernel_create_ptr(mm, param); +} + +void mm_kernel_destroy(const mm_kernel* mm) { + mm_kernel_destroy_ptr(mm); +} + +void mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K) { + mm_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, lda, ldb, ldc, M, N, K); +} + +} \ No newline at end of file diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp new file mode 100644 index 0000000..0bb22ce --- /dev/null +++ b/src/mm_kernel_common_amx.hpp @@ -0,0 +1,2396 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "common/bf16.hpp" +#include "common/tensor2d.hpp" +#include "utility_kernel_amx.hpp" +#include "gelu_kernel_avx512.hpp" +#include "llm_fc.hpp" + +#ifdef _WIN32 +#include +#else +#include +#endif + +#ifdef ENABLE_NUMA +#include "numa.h" +#endif + +using namespace llmdnn; + +namespace amx_kernel { + +namespace functional { + + inline void transpose_m512i_16x16(__m512i &r0, __m512i &r1, __m512i &r2, __m512i &r3, + __m512i &r4, __m512i &r5, __m512i &r6, __m512i &r7, + __m512i &r8, __m512i &r9, __m512i &ra, __m512i &rb, + __m512i &rc, __m512i &rd, __m512i &re, __m512i &rf) { + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + + t0 = _mm512_unpacklo_epi32(r0,r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 + t1 = _mm512_unpackhi_epi32(r0,r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 + t2 = _mm512_unpacklo_epi32(r2,r3); // 32 48 33 49 ... + t3 = _mm512_unpackhi_epi32(r2,r3); // 34 50 35 51 ... + t4 = _mm512_unpacklo_epi32(r4,r5); // 64 80 65 81 ... + t5 = _mm512_unpackhi_epi32(r4,r5); // 66 82 67 83 ... + t6 = _mm512_unpacklo_epi32(r6,r7); // 96 112 97 113 ... + t7 = _mm512_unpackhi_epi32(r6,r7); // 98 114 99 115 ... + t8 = _mm512_unpacklo_epi32(r8,r9); // 128 ... + t9 = _mm512_unpackhi_epi32(r8,r9); // 130 ... + ta = _mm512_unpacklo_epi32(ra,rb); // 160 ... + tb = _mm512_unpackhi_epi32(ra,rb); // 162 ... + tc = _mm512_unpacklo_epi32(rc,rd); // 196 ... + td = _mm512_unpackhi_epi32(rc,rd); // 198 ... + te = _mm512_unpacklo_epi32(re,rf); // 228 ... + tf = _mm512_unpackhi_epi32(re,rf); // 230 ... + + r0 = _mm512_unpacklo_epi64(t0,t2); // 0 16 32 48 ... + r1 = _mm512_unpackhi_epi64(t0,t2); // 1 17 33 49 ... + r2 = _mm512_unpacklo_epi64(t1,t3); // 2 18 34 49 ... + r3 = _mm512_unpackhi_epi64(t1,t3); // 3 19 35 51 ... + r4 = _mm512_unpacklo_epi64(t4,t6); // 64 80 96 112 ... + r5 = _mm512_unpackhi_epi64(t4,t6); // 65 81 97 114 ... + r6 = _mm512_unpacklo_epi64(t5,t7); // 66 82 98 113 ... + r7 = _mm512_unpackhi_epi64(t5,t7); // 67 83 99 115 ... + r8 = _mm512_unpacklo_epi64(t8,ta); // 128 144 160 176 ... + r9 = _mm512_unpackhi_epi64(t8,ta); // 129 145 161 178 ... + ra = _mm512_unpacklo_epi64(t9,tb); // 130 146 162 177 ... + rb = _mm512_unpackhi_epi64(t9,tb); // 131 147 163 179 ... + rc = _mm512_unpacklo_epi64(tc,te); // 192 208 228 240 ... + rd = _mm512_unpackhi_epi64(tc,te); // 193 209 229 241 ... + re = _mm512_unpacklo_epi64(td,tf); // 194 210 230 242 ... + rf = _mm512_unpackhi_epi64(td,tf); // 195 211 231 243 ... + + t0 = _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... + t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... + t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... + t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... + t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... + t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... + t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... + t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... + t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... + t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... + ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... + tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... + tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... + td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... + te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... + tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... + + r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 + r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 + r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 + r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 + r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... + r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... + r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... + r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... + r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... + r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... + ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... + rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... + rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... + rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... + re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... + rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 + } + + inline void transpose_epi32_16x16(void * _dst, const void * src, int stride) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_loadu_epi32(pA + 15*stride); + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // 16xN, N<=16, non-valid part is filled with zeros + inline void transpose_epi32_16xN(void * _dst, const void * src, int stride, int valid_bytes) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull >> (64 - valid_bytes); + __mmask64 mask = _cvtu64_mask64(mask_value); + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + inline void transpose_epi32_Mx16(void * _dst, const void * src, int stride, int valid_m) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + switch (valid_m) { + case 15: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_setzero_epi32(); + break; + case 14: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 13: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 12: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 11: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 10: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 9: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 8: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 7: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 6: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 5: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 4: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 3: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 2: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 1: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_setzero_epi32(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + } + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // 16xN, N<=16, non-valid part is on the left, filled with zeros + inline void transpose_epi32_16xN_right_align(void * _dst, const void * src, int stride, int valid_bytes) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + int invalid_bytes = 64 - valid_bytes; + auto * pA = reinterpret_cast(src) - invalid_bytes; + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; + __mmask64 mask = _cvtu64_mask64(mask_value); + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + inline void transpose_epi32_MxN_right_align(void * _dst, const void * src, int stride, int valid_bytes, int valid_m) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + int invalid_bytes = 64 - valid_bytes; + auto * pA = reinterpret_cast(src) - invalid_bytes; + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; + __mmask64 mask = _cvtu64_mask64(mask_value); + switch (valid_m) { + case 15: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_setzero_epi32(); + break; + case 14: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 13: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 12: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 11: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 10: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 9: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 8: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 7: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 6: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 5: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 4: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 3: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 2: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 1: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_setzero_epi32(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + } + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows) { + #define FROM_B(i) ((1<<4)|(i)) + static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), + 1,5,FROM_B(1),FROM_B(5), + 2,6,FROM_B(2),FROM_B(6), + 3,7,FROM_B(3),FROM_B(7)}; + auto midx = _mm512_loadu_epi64(idx); + __mmask16 mask = _cvtu32_mask16(0xFFFFu); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + if (src_rows == 64) { + for (int row = 0; row < 16; row++) { + // each element (a? or b?) is 32-bits, two lanes in each ymm register + auto a256 = _mm256_loadu_epi8(src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 + auto b256 = _mm256_loadu_epi8(src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 + auto c256 = _mm256_loadu_epi8(src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 + auto d256 = _mm256_loadu_epi8(src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } else { + // less than 64 source lines, + int allzero_dst_rows = (64-src_rows)/4; + int allnonzero_dst_rows = src_rows/4; + // padding zeros at the top + auto rowB0 = _mm512_setzero_si512(); + auto rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + int tails_nz = (src_rows & 3); + if (tails_nz) { + auto a256 = _mm256_setzero_si256(); // must be zero + auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 + auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 + auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) + if (tails_nz > 2) { + b256 = _mm256_loadu_epi8 (src); src += stride; + } + if (tails_nz > 1) { + c256 = _mm256_loadu_epi8 (src); src += stride; + } + d256 = _mm256_loadu_epi8 (src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non zeros + for (int i = 0; i < allnonzero_dst_rows; i++) { + auto a256 = _mm256_loadu_epi8 (src); src += stride; + auto b256 = _mm256_loadu_epi8 (src); src += stride; + auto c256 = _mm256_loadu_epi8 (src); src += stride; + auto d256 = _mm256_loadu_epi8 (src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows) { + static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; + auto midx = _mm512_loadu_epi64(idx); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __m512i a,b,rowB0, rowB1; + if (src_rows == 32) { + for (int row = 0; row < 16; row++) { + a = _mm512_loadu_epi16(src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + b = _mm512_loadu_epi16(src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } else { + int allzero_dst_rows = (32-src_rows)/2; + int allnonzero_dst_rows = src_rows/2; + + rowB0 = _mm512_setzero_si512(); + rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + if (src_rows & 1) { + a = _mm512_setzero_si512(); + b = _mm512_loadu_epi16(src); src += stride; + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + auto rowB0 = _mm512_unpacklo_epi16(a, b); + auto rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non-zero rows + for (int i = 0; i < allnonzero_dst_rows; i++) { + a = _mm512_loadu_epi16(src); + b = _mm512_loadu_epi16(src + stride); + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + rowB0 = _mm512_unpacklo_epi16(a, b); + rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows, int valid_n) { + #define FROM_B(i) ((1<<4)|(i)) + static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), + 1,5,FROM_B(1),FROM_B(5), + 2,6,FROM_B(2),FROM_B(6), + 3,7,FROM_B(3),FROM_B(7)}; + auto midx = _mm512_loadu_epi64(idx); + __mmask16 mask = _cvtu32_mask16(0xFFFFu); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __mmask32 mask_n = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); + if (src_rows == 64) { + for (int row = 0; row < 16; row++) { + // each element (a? or b?) is 32-bits, two lanes in each ymm register + auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 + auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 + auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 + auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } else { + // less than 64 source lines, + int allzero_dst_rows = (64-src_rows)/4; + int allnonzero_dst_rows = src_rows/4; + // padding zeros at the top + auto rowB0 = _mm512_setzero_si512(); + auto rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + int tails_nz = (src_rows & 3); + if (tails_nz) { + auto a256 = _mm256_setzero_si256(); // must be zero + auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 + auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 + auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) + if (tails_nz > 2) { + b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + } + if (tails_nz > 1) { + c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + } + d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non zeros + for (int i = 0; i < allnonzero_dst_rows; i++) { + auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows, int valid_n) { + static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; + auto midx = _mm512_loadu_epi64(idx); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __mmask32 mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); + __m512i a,b,rowB0, rowB1; + if (src_rows == 32) { + for (int row = 0; row < 16; row++) { + a = _mm512_maskz_loadu_epi16(mask, src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + b = _mm512_maskz_loadu_epi16(mask, src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } else { + int allzero_dst_rows = (32-src_rows)/2; + int allnonzero_dst_rows = src_rows/2; + + rowB0 = _mm512_setzero_si512(); + rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + if (src_rows & 1) { + a = _mm512_setzero_si512(); + b = _mm512_maskz_loadu_epi16(mask, src); src += stride; + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + auto rowB0 = _mm512_unpacklo_epi16(a, b); + auto rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non-zero rows + for (int i = 0; i < allnonzero_dst_rows; i++) { + a = _mm512_maskz_loadu_epi16(mask, src); + b = _mm512_maskz_loadu_epi16(mask, src + stride); + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + rowB0 = _mm512_unpacklo_epi16(a, b); + rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } + } + + // prepare B matrix for C matrix 2x2 blocking (B matrix + // will be accessed in 1x2) + // given 2x2 blocking scheme, Kx32 blocks are always + // accessed sequentially: + // transpose/repack each 32xK ov::bfloat16 submatrix + // into Kx32 slices (each number is a 16x32 bf16-block): + // 0 2 4 6 ... ... + // 1 3 5 7 ... ... + + inline void get_min_max(tensor2D & matB, float& min, float& max) { + int K = matB.dims[0]; + int N = matB.dims[1]; + auto m_max = _mm512_set1_ps(-__FLT_MAX__); + auto m_min = _mm512_set1_ps(__FLT_MAX__); + for (int k = 0; k < K; k++) { + int n = 0; + for (; n < N / 16 * 16; n += 16) { + auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(&matB(k, n))); + a = _mm512_slli_epi32(a, 16); + m_max = _mm512_max_ps((__m512)a, m_max); + m_min = _mm512_min_ps((__m512)a, m_min); + } + if (n != N) { + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - (N - n))); + auto a = _mm512_cvtepi16_epi32(_mm256_maskz_loadu_epi16(msk, &matB(k, n))); + a = _mm512_slli_epi32(a, 16); + m_max = _mm512_mask_max_ps(m_max, msk, (__m512)a, m_max); + m_min = _mm512_mask_min_ps(m_min, msk, (__m512)a, m_min); + } + } + max = _mm512_reduce_max_ps(m_max); + min = _mm512_reduce_min_ps(m_min); + } + + template + void i8_to_bf16_Kx32(int8_t *&src, ov::bfloat16 *dst) + { + for (int k = 0; k < K; k++) + { + auto a = _mm_load_si128((__m128i *)src); // 16 int8 + auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 + auto a_512 = _mm512_cvtepi8_epi32(a); // 16 int32 + auto b_512 = _mm512_cvtepi8_epi32(b); // 16 int32 + auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps + auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps + auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 + _mm512_store_epi32(dst, (__m512i)reg_out); // + src += 32; // 32 int8_t dequantized into 32 bf16 + dst += 32; + } + } + + inline void bf16_to_i8_tensor(tensor2D& dst, tensor2D& src, float quant_scale) { + dst.resize(src.dims[0], src.dims[1]); + auto scale = _mm512_set1_ps(quant_scale); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + for (int n = 0; n < src.dims[1]; n += 16, p_src += 16, p_dst += 16) { + auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(p_src)); // load packed 16 x bf16 + a = _mm512_slli_epi32(a, 16); // bf16 zero-extend to f32 + auto a_f = _mm512_mul_ps((__m512)a, scale); // scale + a = _mm512_cvtps_epi32(a_f); // ps to dw + auto a_128 = _mm512_cvtsepi32_epi8(a); // saturate convert into int8 + _mm_store_si128((__m128i*)(p_dst), a_128); + } + } + } +}; + +// 2x2 tiles post process kernels + +// 4 tiles located at C matrix (m,n) of size (valid_m, valid_n) +// tC00/tC01 +// tC10/tC11 + +namespace PP { + template + struct is_f32i32 : std::false_type {}; + template<> + struct is_f32i32 : std::true_type {}; + template<> + struct is_f32i32 : std::true_type {}; + + using Steps = postops_types; + + template + struct BiasGeluStore { + static_assert(std::is_same::value || std::is_same::value || std::is_same::value, + "BiasGeluStore only support output data types ov::bfloat16/int8_t/float"); + + BiasGeluStore(tensor2D & C, float * bias = nullptr) : C(C), bias(bias) {} + + tensor2D & C; + float * bias; + void set_bias(float * _bias) { + assert (steps & BIAS); + bias = _bias; + } + + float deq_scale_common = 1.0f; + float * deq_scale_per_oc = nullptr; + void set_deq_scale(float scale = 1.0f) { + assert (steps & DEQUANT); + deq_scale_common = scale; + deq_scale_per_oc = nullptr; + } + void set_deq_scale(float * scale_per_oc) { + assert (steps & DEQUANT); + deq_scale_common = 0; + deq_scale_per_oc = scale_per_oc; + } + + float q_scale_common = 0.0f; + float * q_scale_per_oc = nullptr; + void set_q_scale(float scale) { + assert (steps & QUANT); + q_scale_common = scale; + q_scale_per_oc = nullptr; + } + void set_q_scale(float * scale_per_oc) { + assert (steps & QUANT); + q_scale_common = 0; + q_scale_per_oc = scale_per_oc; + } + + // source buffC can be i32 or f32 + template::value, bool>::type = true> + void operator()(tensor2D & buffC, int m, int n, int valid_m, int valid_n) { + auto * psrc = &buffC(0,0); + int8_t * pdst = reinterpret_cast(&(C(m, n))); + int stride = C.stride; + + __m512 bias0, bias1; + if (steps & BIAS) { + bias0 = _mm512_loadu_ps(bias + n); + bias1 = _mm512_loadu_ps(bias + n + 16); + } + + __m512 m512_q_scale0; + __m512 m512_q_scale1; + __m512 m512_deq_scale0; + __m512 m512_deq_scale1; + if (steps & DEQUANT) { + if (deq_scale_per_oc) { + m512_deq_scale0 = _mm512_loadu_ps(deq_scale_per_oc + n); + m512_deq_scale1 = _mm512_loadu_ps(deq_scale_per_oc + n + 16); + } else { + m512_deq_scale0 = _mm512_set1_ps(deq_scale_common); + m512_deq_scale1 = _mm512_set1_ps(deq_scale_common); + } + } + if (steps & QUANT) { + if (q_scale_per_oc) { + m512_q_scale0 = _mm512_loadu_ps(q_scale_per_oc + n); + m512_q_scale1 = _mm512_loadu_ps(q_scale_per_oc + n + 16); + } else { + m512_q_scale0 = _mm512_set1_ps(q_scale_common); + m512_q_scale1 = _mm512_set1_ps(q_scale_common); + } + } + + __mmask32 kall; + __mmask16 k0, k1; + if (std::is_same::value) { + if (valid_n >= 16) { + k0 = _cvtu32_mask16(0xFFFF); + k1 = _cvtu32_mask16(0xFFFF >> (32-valid_n)); + } else { + k0 = _cvtu32_mask16(0xFFFF >> (16-valid_n)); + k1 = _cvtu32_mask16(0); + } + } else { + kall = _cvtu32_mask32(0xFFFFFFFF >> (32-valid_n)); + } + + for(int i = 0; i < valid_m; i ++) { + auto r0 = _mm512_loadu_ps(psrc); + auto r1 = _mm512_loadu_ps(psrc + 16); + if (std::is_same::value) { + r0 = _mm512_cvtepi32_ps(_mm512_castps_si512(r0)); // cvt i32=>f32 + r1 = _mm512_cvtepi32_ps(_mm512_castps_si512(r1)); // cvt i32=>f32 + } + if (steps & DEQUANT) { + r0 = _mm512_mul_ps(r0, m512_deq_scale0); // dequantize + r1 = _mm512_mul_ps(r1, m512_deq_scale1); // dequantize + } + if (steps & BIAS) { + r0 = _mm512_add_ps(r0, bias0); + r1 = _mm512_add_ps(r1, bias1); + } + if (steps & GELU) { + r0 = gelu_erf_minmax_approx_avx512(r0); + r1 = gelu_erf_minmax_approx_avx512(r1); + } + if (steps & GELU_TANH) { + r0 = gelu_tanh_avx512(r0); + r1 = gelu_tanh_avx512(r1); + } + + // quantize & store + if (steps & QUANT) { + r0 = _mm512_mul_ps(r0, m512_q_scale0); + r1 = _mm512_mul_ps(r1, m512_q_scale1); + } + if (std::is_same::value) { + auto c = _mm512_cvtne2ps_pbh(r1, r0); // convert to bf16 + _mm512_mask_storeu_epi16(pdst, kall, (__m512i)c); // store bf16 + } + if (std::is_same::value) { + auto d0 = _mm512_cvtps_epi32(r0); // convert to dword(i32) + auto d1 = _mm512_cvtps_epi32(r1); // convert to dword(i32) + auto b0 = _mm512_cvtsepi32_epi8 (d0); // dword => int8 with Saturate8 + auto b1 = _mm512_cvtsepi32_epi8 (d1); // dword => int8 with Saturate8 + auto b0b1 = _mm256_inserti32x4(_mm256_castsi128_si256(b0), b1, 1); // combine two int8 xmm into a ymm + _mm256_mask_storeu_epi8(pdst, kall, b0b1); // masked store + } + if (std::is_same::value) { + _mm512_mask_storeu_ps(pdst, k0, r0); // store float + _mm512_mask_storeu_ps(pdst + 64, k1, r1); // store float + } + pdst += stride; + psrc += 32; + } + } + }; +} + +// WA: clang could not find _mm_hint definition +#define prefetch_bytes(bytes, sel, advance, src) { \ + int8_t *p = reinterpret_cast(src); \ + for (int i = 0; i < bytes; i+=64) \ + _mm_prefetch(p + i + advance, sel); \ +} + +// matmul (FC) +// +// constB constrols whether it's FC or not +// store precision for weight compression, only for BF16 AMX + +template +tensor2D getSubMatB(tensor2D & _matB, int n0, int n1, bool transposeB) { + int Bd0 = transposeB ? (n1-n0) : _matB.dims[0]; + int Bd1 = transposeB ? _matB.dims[1] : (n1-n0); + T * pbase = transposeB ? (&_matB(n0, 0)):(&_matB(0, n0)); + return tensor2D(Bd0, Bd1, pbase, _matB.stride); +} + +template +void loop2D_no_bM(int M, int N, F f) { + for(int n=0; n +void loop2D(int M, int N, int mc, F f) { + for(int m0=0; m0= bM) +template +void loop2D_opt_Mtail(int M, int N, int mc, F f) { + assert(M > bM); + for(int m0=0; m0 +void repackB_1x2(const tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { + int K = Bi.dims[transpose ? 1 : 0]; + int N = Bi.dims[transpose ? 0 : 1]; + + // K_padded : round up to multiple of 32/64 + int kStep = 64 / sizeof(T); + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); + + int n = 0; + int n_tail = N % N_unit; + if (transpose) { + for(; n < N - n_tail; n += N_unit) { + // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + auto* src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + dst += 1024; + functional::transpose_epi32_16x16(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + dst += 1024; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); + dst += 1024; + functional::transpose_epi32_16xN_right_align(dst, src0 + 1 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k)*sizeof(T)); + dst += 1024; + } + } + // n_tail: [16, 32) + if (N - n >= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + auto* src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_16x16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_16xN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T)); + } + n += 16; + } + // n_tail: (0, 16) + if (N - n > 0) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 1024 : 0); + auto* src0 = reinterpret_cast(&Bi(n, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_Mx16(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, N - n); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_MxN_right_align(dst, src0 + 0 * 16 * Bi.stride + k * sizeof(T), Bi.stride, (K - k) * sizeof(T), N - n); + } + n = N; + } + // second B tile is untouched, need to set to zero + if (n_tail > 0 && n_tail <= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for (int k = 0; k < K_padded; k += kStep) { + memset(dst + 1024, 0, 1024); + dst += 1024 * 2; + } + } + } else { + // pack & layout sequentially + int n = 0; + int n_tail = N % N_unit; + for(; n < N - n_tail; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + // int8: B0 B1 64x(16+16) => repack as two 16x16x4 + int src_rows = std::min(K - k, kStep); + functional::kpack_tile_B0B1(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows); + dst += 2048; + } + } + // n_tail: (0, 32) + if (N - n > 0) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + // int8: B0 B1 64x(16+16) => repack as two 16x16x4 + int src_rows = std::min(K - k, kStep); + functional::kpack_tile_B0B1_ntail(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows, N - n); + dst += 2048; + } + n += 16; + } + } +} + +template +struct acc_type {}; +template<> +struct acc_type { typedef float type; }; +template<> +struct acc_type { typedef int32_t type; }; +template<> +struct acc_type { typedef int32_t type; }; + +template +using acc_type_t = typename acc_type::type; + +// matrix multiply with vector +// C_Nx1 = A_MxK * b_Kx1 + +template> +struct MatmulVector { + MatmulVector() {} + constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; + constexpr static int kStep = is_i8_mode ? 64 : 32; + +#define TILE_DP(dst, a, b) \ + if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ + if (is_s8s8) _tile_dpbssd(dst, a, b); \ + if (is_s8u8) _tile_dpbsud(dst, a, b); \ + if (is_u8s8) _tile_dpbusd(dst, a, b); \ + if (is_u8u8) _tile_dpbuud(dst, a, b); + + alignas(64) int8_t KtailBuff[64]; + + template + void kernel(int M, int K, const void * pA, int strideA, const void * vB, void * vC) { + static_assert(tmmN >= 1 && tmmN <= 6, "tmmN must be within [1-6] range"); + const auto * pA0 = reinterpret_cast(pA); + int KLastOffBytes = (K - kStep) * sizeof(TA); + const auto * pB0 = reinterpret_cast(vB); + auto * pC0 = reinterpret_cast(vC); + + const auto * pBLast = pB0 + 64*(tmmN - 1); + int Ktail = K & (kStep - 1); + if (Ktail) { + if (bFallbackKtails) { + // if bContainMtails, the last submatrix needs to use special to prevent A matrix read overflow + // K tails is handled by: + // - zero-padding the last tile of vector B, at the top + // - right-align last tile load from matA + __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull << (kStep - Ktail)*sizeof(TB)); + auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); + _mm512_storeu_epi8(KtailBuff, r); + } else { + // each row of A can be read overflow w/o worrying NaN numbers + // zero-padding the last tile of vector B as bottom is enough + __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull >> (kStep - Ktail)*sizeof(TB)); + KLastOffBytes = (K - Ktail)*sizeof(TA); + auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); + _mm512_storeu_epi8(KtailBuff, r); + } + pBLast = KtailBuff; + } + + // load B tiles outside of loop + if (tmmN == 1) { + _tile_loadd(2, pB0, 4); + } + if (tmmN == 2) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pBLast, 4); + } + if (tmmN == 3) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pBLast, 4); + } + if (tmmN == 4) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pBLast, 4); + } + if (tmmN == 5) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pB0 + 64*3, 4); + _tile_loadd(6, pBLast, 4); + } + if (tmmN == 6) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pB0 + 64*3, 4); + _tile_loadd(6, pB0 + 64*4, 4); + _tile_loadd(7, pBLast, 4); + } + //asm("int3"); + for(int m = 0; m < M; m+=16) { + _tile_zero(0); + if (tmmN == 1) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + } + if (tmmN == 2) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 3); + } + if (tmmN == 3) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 4); + } + if (tmmN == 4) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 5); + } + if (tmmN == 5) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 6); + } + if (tmmN == 6) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 7); + } + _tile_stored(0, pC0, 4); pC0 += 16*4; // C is single column, always take 4 bytes + pA0 += 16 * strideA; + } + } + + void operator()(tensor2D & matA, const TB * vB, TC * vC) { + int M = matA.dims[0]; + int K = matA.dims[1]; + TA * pA = &matA[0]; + int strideA = matA.stride; + + // M tails is handled + assert(K >= kStep && K <= 6*kStep); + + int Ktail = K & (kStep - 1); + int Mtail = M & (16 - 1); + int Mbody = M - Mtail; + int numBtiles = (K + kStep - 1)/kStep; + + // if we have Ktails, then it will always be handled in Mtail, so we split + // Mtail out even if it's zero + if (Ktail) { + if (Mtail == 0) { + Mtail = 16; + Mbody -= 16; + } + } + + if (Mbody) { + tileconfig_t tfg(1, 0, { + {16, 4}, // C:0 M x 1 (4b) + {16, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); + // Ktail fallback will always be done at Mtails loop + switch(numBtiles) { + case 1: kernel<1, false>(Mbody, K, pA, strideA, vB, vC); break; + case 2: kernel<2, false>(Mbody, K, pA, strideA, vB, vC); break; + case 3: kernel<3, false>(Mbody, K, pA, strideA, vB, vC); break; + case 4: kernel<4, false>(Mbody, K, pA, strideA, vB, vC); break; + case 5: kernel<5, false>(Mbody, K, pA, strideA, vB, vC); break; + case 6: kernel<6, false>(Mbody, K, pA, strideA, vB, vC); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } + + if (Mtail) { + pA = &matA(Mbody, 0); + tileconfig_t tfg(1, 0, { + {Mtail, 4}, // C:0 M x 1 (4b) + {Mtail, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); + if (Ktail) { + switch(numBtiles) { + case 1: kernel<1, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 2: kernel<2, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 3: kernel<3, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 4: kernel<4, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 5: kernel<5, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 6: kernel<6, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } else { + switch(numBtiles) { + case 1: kernel<1, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 2: kernel<2, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 3: kernel<3, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 4: kernel<4, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 5: kernel<5, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 6: kernel<6, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } + } + } +}; + +template::type> +struct Matmul { + // B matrix is orgnized as tensor2D of shape axb where a=round_up_div(N, 32), b=round_up(K,32/64)*32 + // so b is size of submatrix of Kx32 composed of two columns of B0/B1 tiles. + tensor2D internalB; + + bool constB; + bool transposeB; + + constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; + + // AMX bf16 & int8 has same M(=16) in A,C tile and same N(=16) in B tile + // but only different K(32 vs 64) in A,C & B tiles + constexpr static int kStep = is_i8_mode ? 64 : 32; + + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + tensor2D buffC; + + Matmul(bool constB = false, bool transposeB = false) : + constB(constB), transposeB(transposeB), buffC(32, 32) {} + + // ppkernel is a callable which captures the runtime args + // by itself, so no need to pass in any post-process related + // runtime args through this API + // + // n0/n1 allows us for calculating only partial results, so it + // can be used to run on multi-cores in parallel + // + // ppkernel will be invoked with true (m,n) with n0-offset added + // so ppkernel don't need to know which sub-matrix it's working on. + // + // for most ppkernels w/o runtime state, a single ppkernel can be + // shared among all threads. + // + // but for ppkernels doing reductions, it needs separate instance + // for each thread, also a post-merging process to combine the results. + // + // ppkernels are simple to write, further wrapping or structurelize only + // makes the design more complex, so we stop doing that. + + // I cannot find a way to call TDP intrinsic polymophically using overload or template. + // have to use old-macro-tricks, hopefully these compile-time checks can be optimized + // by compiler. +#define TILE_DP(dst, a, b) \ + if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ + if (is_s8s8) _tile_dpbssd(dst, a, b); \ + if (is_s8u8) _tile_dpbsud(dst, a, b); \ + if (is_u8s8) _tile_dpbusd(dst, a, b); \ + if (is_u8u8) _tile_dpbuud(dst, a, b); + + template + void kernel_slimB(int M, int N, int K, int n0, + tensor2D & A, + void * B, + tensor2D & buffC, + PP ppkernel) { + auto * pB0 = reinterpret_cast(B); + auto * pC0 = &buffC[0]; + int8_t * pA0 = reinterpret_cast(&A[0]); + int strideA = A.stride; + int KlastOffBytes = (K - kStep)* sizeof(TA); + // load B tiles outside of loop + if (tmmN > 0) { + _tile_loadd(2, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 1) { + _tile_loadd(3, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 2) { + _tile_loadd(4, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 3) { + _tile_loadd(5, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 4) { + _tile_loadd(6, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 5) { + _tile_loadd(7, pB0, 64); pB0 += 1024*2; + } + //asm("int3"); + for(int m0 = 0; m0 < M; m0+=16) { + int m = m0; + if (M - m0 < 16) { + // shift up to prevent M-tails + pA0 -= (16 - (M - m0))*A.stride; + m = M - 16; + } + _tile_zero(0); + if (tmmN == 1) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + } + if (tmmN == 2) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 3); + } + if (tmmN == 3) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 4); + } + if (tmmN == 4) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 5); + } + if (tmmN == 5) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 6); + } + if (tmmN == 6) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 7); + } + _tile_stored(0, pC0, buffC.stride); + (ppkernel)(buffC, m, n0, 16, N); + pA0 += 16*A.stride; + } + } + + template + void operator()(tensor2D & matA, + tensor2D & _matB, + int n0, int n1, + PP ppkernel, + bool skip_repack = false) { + int M = matA.dims[0]; + int K = matA.dims[1]; + if (K < kStep) { + int B0, B1; + if (transposeB) { + B0 = _matB.dims[0]; + B1 = kStep; + } else { + B0 = kStep; + B1 = _matB.dims[1]; + } + matA = matA.clone_with_padzero(M, kStep); + _matB = _matB.clone_with_padzero(B0, B1); + K = kStep; + } + auto matB = getSubMatB(_matB, n0, n1, transposeB); + int N = matB.dims[transposeB ? 0 : 1]; + assert(K == matB.dims[transposeB ? 1 : 0]); + // Due to the fact that we load a full tile at tails of K dimension + // we may access memory address beyond the limit of A matrix + // to avoid read in nan values, we backoff to the left to ensure A tile + // contain valid numbers and no overflow access happens, but it requires K>=kStep; + assert(K >= kStep); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int KbackoffBytes = (kStep - Ktails)*sizeof(TA); + + // for non-constB, internalB is updated every time + // for constB, internalB is updated once + if ((!constB && !skip_repack) || (internalB.capacity == 0)) { + repackB_1x2(matB, transposeB, internalB, constB); + } + + // special case when whole B matrix can fit in 6 tiles + // we can load B only once + if (M >= 16 && N <= 16 && K <= 6*kStep) { + // B is zero-padded + // C:0 + // A:1 + // B:2,3,4,5,6,7 + auto * pB0 = reinterpret_cast(&internalB[0]); + tileconfig_t tfg(1, 0, 8, 16, 64); + switch((K + kStep - 1)/kStep) { + case 1: kernel_slimB<1>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 2: kernel_slimB<2>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 3: kernel_slimB<3>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 4: kernel_slimB<4>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 5: kernel_slimB<5>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + case 6: kernel_slimB<6>(M, N, K, n0, matA, pB0, buffC, ppkernel); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + return; + } + + if (M <= 16) { + // register/cache blocking scheme is simplified when M <= 16 + // C_MxN: 0,1 + // A_MxK: 2, + // B_KxN: 3, 4 + tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); + auto * pB0 = reinterpret_cast(&internalB[0]); + auto * const pC0 = &buffC[0]; + int k; + const auto strideA = matA.stride; + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + _tile_zero(0); + _tile_zero(1); + int8_t * pA0 = reinterpret_cast(&matA[0]); + for(k=0; k(&matA(m, 0)); + auto * pA1 = reinterpret_cast(&matA(m + 16, 0)); + auto strideA = matA.stride; + auto * pB = reinterpret_cast(&internalB(n>>5, 0)); + _tile_zero(0); + _tile_zero(1); + _tile_zero(2); + _tile_zero(3); + // 2x2 + for (int k = 0; k < Kbody; k += kStep) { + _tile_loadd(4, pA0, strideA); pA0 += 64; + _tile_loadd(6, pB, 64); pB += 1024; + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); + TILE_DP(0, 4, 6); + + _tile_loadd(5, pA1, strideA); pA1 += 64; + TILE_DP(2, 5, 6); + _tile_loadd(7, pB, 64); pB += 1024; + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); + TILE_DP(1, 4, 7); + + TILE_DP(3, 5, 7); + } + if (Ktails) { + _tile_loadd(4, pA0 - KbackoffBytes, strideA); + _tile_loadd(6, pB, 64); pB += 1024; + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); + TILE_DP(0, 4, 6); + + _tile_loadd(5, pA1 - KbackoffBytes, strideA); + TILE_DP(2, 5, 6); + _tile_loadd(7, pB, 64); pB += 1024; + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); + TILE_DP(1, 4, 7); + + TILE_DP(3, 5, 7); + } + _tile_stored(0, &buffC(0,0), buffC.stride); + _tile_stored(1, &buffC(0,16), buffC.stride); + _tile_stored(2, &buffC(16,0), buffC.stride); + _tile_stored(3, &buffC(16,16), buffC.stride); + (ppkernel)(buffC, m, n + n0, valid_m, valid_n); + }; + + if (M <= 32 && M >16) { + // 2x2 tile, C:0/1/2/3 A:4/5 B:6/7 no blocking along M dimension + tileconfig_t tfg(1, 0, {16,16,M-16,M-16,16,M-16,16,16}, 64); + loop2D_no_bM<32>(M, N, kernel_2x2); + return; + } + + // generic input shapes with M > 32 + // determine cache blocking scheme + int elesz = sizeof(TA); + int L2 = 2048*1024; // 2MB + int slice_size = 32*rndup(K, 32)*elesz; + int mc = std::max(1, L2/slice_size - 1); + + // M > bM + tileconfig_t tfg(1, 0, 8, 16, 64); + loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); + } +}; + +// specialization: +// TA is ov::bfloat16 and TB is int8_t, decompressed on the fly into ov::bfloat16 by simply convert +template<> +struct Matmul { + tensor2D internalBI8; + + // wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly. + tensor2D weiBuff; + + bool constB; + bool transposeB; + + constexpr static int kStep = 32; + + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + tensor2D buffC; + + Matmul(bool constB = false, bool transposeB = false) : + constB(constB), transposeB(transposeB), buffC(32, 32) {} + + float quant_scale_B; + float dequant_scale_B; + + template + void operator()(tensor2D & matA, + tensor2D & _matB, + int n0, int n1, + PP ppkernel) { + auto matB = getSubMatB(_matB, n0, n1, transposeB); + int M = matA.dims[0]; + int K = matA.dims[1]; + int N = matB.dims[transposeB ? 0 : 1]; + assert(K == matB.dims[transposeB ? 1 : 0]); + // Due to the fact that we load a full tile at tails of K dimension + // we may access memory address beyond the limit of A matrix + // to avoid read in nan values, we backoff to the left to ensure A tile + // contain valid numbers and no overflow access happens, but it requires K>=kStep; + assert(K >= kStep); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int Kbackoff = (kStep - Ktails); + + // for non-constB, internalB is updated every time + // for constB, internalB is updated once + if (!constB || (internalBI8.capacity == 0)) { + // this dynamic quantization of weight matrix using minmax + // is time-consuming, should be used only for constB + if (!constB) { + std::cout << "\t WANING: dynamic quantization of weight matrix for non-constB is time-consuming " << std::endl; + } + // float min, max; + // functional::get_min_max(_matB, min, max); + // max = std::max(std::abs(max), std::abs(min)); + // quant_scale_B = 127 / max; + // dequant_scale_B = max / 127; + + tensor2D internalTmpB; + repackB_1x2(matB, transposeB, internalTmpB, constB); + functional::bf16_to_i8_tensor(internalBI8, internalTmpB, quant_scale_B); + } + + ppkernel.set_deq_scale(dequant_scale_B); + + if (M <= 16) { + // C:0/1 A:2 B:3/4 + // dequantize scale is moved into ppkernel + constexpr int prefetch_ahead = 64*1024; + tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); + auto * pBint = reinterpret_cast(&internalBI8[0]); + auto & B2buff = weiBuff; + B2buff.resize(32*2, 32); + auto * const pB = &B2buff[0]; + auto * pBsrc = pB + (32*32) * 0; + auto * pBdst = pB + (32*32) * 1; + functional::i8_to_bf16_Kx32<32>(pBint, pBsrc); + + auto * const pC0 = &buffC[0]; + const auto strideA = matA.stride; + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + // C:Mx32 = A:Mx32 x B:32x32 + _tile_zero(0); + _tile_zero(1); + auto * pA0 = &matA[0]; + for(int k=0; k(pBint, pBdst); + _tile_loadd(3, pBsrc, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); + _tile_loadd(4, pBsrc + 16*32, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + if (Ktails) { + _tile_loadd(2, pA0 - Kbackoff, strideA); // backoff to prevent access beyond the end of A + prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + + functional::i8_to_bf16_Kx32<8>(pBint, pBdst); + _tile_loadd(3, pBsrc, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32); + _tile_loadd(4, pBsrc + 16*32, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint); + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint + 2048); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + return; + } + + // 4 tiles buffC is reused as decompressed bf16 weights + ov::bfloat16 * pBa = reinterpret_cast(&buffC(0,0)); + ov::bfloat16 * pBb = pBa + (16*32)*2; + auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { + auto strideA = matA.stride; + auto * pA0 = &matA(m, 0); + auto * pA1 = &matA(m + 16, 0); + auto * pBint = reinterpret_cast(&internalBI8(n>>5, 0)); + functional::i8_to_bf16_Kx32<32>(pBint, pBb); + + _tile_zero(0); + _tile_zero(1); + _tile_zero(2); + _tile_zero(3); + int k; + for (k = 0; k < Kbody; k += kStep) { + functional::i8_to_bf16_Kx32<16>(pBint, pBa); + + _tile_loadd(4, pA0 + k, strideA); + _tile_loadd(6, pBb, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); + + _tile_loadd(7, pBb + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + if (Ktails) { + functional::i8_to_bf16_Kx32<16>(pBint, pBa); + + _tile_loadd(4, pA0 + k - Kbackoff, strideA); + _tile_loadd(6, pBb, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k - Kbackoff, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32); + + _tile_loadd(7, pBb + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + _tile_stored(0, &buffC(0,0), buffC.stride); + _tile_stored(1, &buffC(0,16), buffC.stride); + _tile_stored(2, &buffC(16,0), buffC.stride); + _tile_stored(3, &buffC(16,16), buffC.stride); + (ppkernel)(buffC, m, n + n0, valid_m, valid_n); + }; + + if (M <= 32 && M > 16) { + // 2x2 C:0/1/2/3 A:4/5 B:6/7 + tileconfig_t tfg(1, 0, {16, 16, M-16, M-16, 16, M-16, 16, 16}, 64); + loop2D_no_bM<32>(M, N, kernel_2x2); + return; + } + + // determine blocking scheme + int elesz = sizeof(uint16_t); + int L2 = 2048*1024; // 2MB + int slice_size = 32*rndup(K, 32)*elesz; + int mc = std::max(1, L2/slice_size - 1); // if 1 32xK slice cannot fit L2, use 1 slice at least + + // main loop + tileconfig_t tfg(1, 0, 8, 16, 64); + loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); + } +}; + +//https://stackoverflow.com/questions/29519222/how-to-transpose-a-16x16-matrix-using-simd-instructions +// vector multiply with matrix: +// mAvB: A(M, K) * B(K, 1) => C(M, 1) +// vAmB: A(1, K) * B(K, N) => C(1, N) +// +// in mAvB form, block of A (16x32) is transposed in register +// in unit of 2 packed bf16, and then vdpbf16ps was used +// to multiply with broadcasted B (2x1) and accumulate into C (16x1) +// +// B is pre-broadcasted in unit of 2 +// +struct GemAvB { + tensor2D Bpadded; + GemAvB() { + } + + void operator()(tensor2D & matA, + ov::bfloat16 * vecB, + float * vecC) { + int M = matA.dims[0]; + int K = matA.dims[1]; + + assert(K >= 32); + + if (K % 32) { + if (K > Bpadded.dims[1]) + Bpadded.resize(1, rndup(K, 32)); + auto newB = &Bpadded(0, 0); + memset(static_cast(newB), 0, Bpadded.stride); + memcpy(newB, vecB, K * sizeof(ov::bfloat16)); + vecB = newB; + } + + for(int m = 0; m < M; m += 16) { + auto * pA = reinterpret_cast(&matA(m, 0)); + auto * pBi32 = reinterpret_cast(vecB); + __m512 regC0 = _mm512_setzero(); + __m512 regC1 = _mm512_setzero(); + __mmask16 kmask = _cvtu32_mask16(0xFFFF); + if (M-m < 16) { + kmask = _cvtu32_mask16(0xFFFF >> (16-(M-m))); + } + for(int k = 0; k < K; k += 32, pA += 64, pBi32 += 16) { + // handle Ab: 16x32 + // transposed in register as 16x16x2 + // r0: (a0,a1)(b0,b1).... + // r1: (a2,a3)(b2,b3).... + // ... + // rf: (a30,a31),(b30,b31).... + // + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto stride = matA.stride; + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_loadu_epi32(pA + 15*stride); + + functional::transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + // vdpbf16ps + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r0), (__m512bh)(_mm512_set1_epi32(pBi32[0]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r1), (__m512bh)(_mm512_set1_epi32(pBi32[1]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r2), (__m512bh)(_mm512_set1_epi32(pBi32[2]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r3), (__m512bh)(_mm512_set1_epi32(pBi32[3]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r4), (__m512bh)(_mm512_set1_epi32(pBi32[4]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r5), (__m512bh)(_mm512_set1_epi32(pBi32[5]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r6), (__m512bh)(_mm512_set1_epi32(pBi32[6]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r7), (__m512bh)(_mm512_set1_epi32(pBi32[7]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r8), (__m512bh)(_mm512_set1_epi32(pBi32[8]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r9), (__m512bh)(_mm512_set1_epi32(pBi32[9]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(ra), (__m512bh)(_mm512_set1_epi32(pBi32[10]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rb), (__m512bh)(_mm512_set1_epi32(pBi32[11]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(rc), (__m512bh)(_mm512_set1_epi32(pBi32[12]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rd), (__m512bh)(_mm512_set1_epi32(pBi32[13]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(re), (__m512bh)(_mm512_set1_epi32(pBi32[14]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rf), (__m512bh)(_mm512_set1_epi32(pBi32[15]))); + } + regC0 = _mm512_add_ps(regC0, regC1); + _mm512_mask_storeu_ps (vecC + m, kmask, regC0); + //auto regOut = _mm512_cvtne2ps_pbh(regC0, regC0); // only 16 ov::bfloat16 results in lower 256bits + //_mm256_storeu_si256(reinterpret_cast<__m256i_u *>(vecC + m), _mm512_extracti64x4_epi64(regOut, 0)); + } + } +}; + +} // namespace amx diff --git a/src/rotary_kernel_avx512.hpp b/src/rotary_kernel_avx512.hpp new file mode 100644 index 0000000..9f1a6ce --- /dev/null +++ b/src/rotary_kernel_avx512.hpp @@ -0,0 +1,135 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + template + void rotary_avx512(size_t N, float* cos, float* sin, T* q_src, T* k_src, T* q_dst, T* k_dst) { + static_assert(std::is_same::value, + "rotary_avx512 only support output data types ov::bfloat16/int8_t"); + auto half = N / 2; + // for (size_t i = 0; i < half; i++) { + // q_dst[i] = q_src[i] * cos[i] - q_src[i + half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] - k_src[i + half] * sin[i]; + // } + // for (size_t i = half; i < N; i++) { + // q_dst[i] = q_src[i] * cos[i] + q_src[i - half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] + k_src[i - half] * sin[i]; + // } + size_t tail = half % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + size_t i; + for (i = 0; i < half - tail; i += 16) { + auto q = _mm256_loadu_epi16(q_src + i + half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_loadu_epi16(k_src + i + half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_loadu_ps(cos + i); + auto sin_f = _mm512_loadu_ps(sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_loadu_epi16(q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_loadu_epi16(k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmsub_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(q_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(k_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + if (tail) { + auto q = _mm256_maskz_loadu_epi16(x_mask, q_src + i + half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_maskz_loadu_epi16(x_mask, k_src + i + half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_maskz_loadu_ps(x_mask, cos + i); + auto sin_f = _mm512_maskz_loadu_ps(x_mask, sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_maskz_loadu_epi16(x_mask, q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_maskz_loadu_epi16(x_mask, k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmsub_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_mask_storeu_epi16(q_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_mask_storeu_epi16(k_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + // second half + q_src += half; + k_src += half; + cos += half; + sin += half; + q_dst += half; + k_dst += half; + for (i = 0; i < half - tail; i += 16) { + auto q = _mm256_loadu_epi16(q_src + i - half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_loadu_epi16(k_src + i - half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_loadu_ps(cos + i); + auto sin_f = _mm512_loadu_ps(sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_loadu_epi16(q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_loadu_epi16(k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmadd_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(q_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(k_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + if (tail) { + auto q = _mm256_maskz_loadu_epi16(x_mask, q_src + i - half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_maskz_loadu_epi16(x_mask, k_src + i - half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_maskz_loadu_ps(x_mask, cos + i); + auto sin_f = _mm512_maskz_loadu_ps(x_mask, sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_maskz_loadu_epi16(x_mask, q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_maskz_loadu_epi16(x_mask, k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmadd_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_mask_storeu_epi16(q_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_mask_storeu_epi16(k_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + } +} \ No newline at end of file diff --git a/src/softmax_kernel_avx512.hpp b/src/softmax_kernel_avx512.hpp new file mode 100644 index 0000000..81c52c5 --- /dev/null +++ b/src/softmax_kernel_avx512.hpp @@ -0,0 +1,226 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + inline void exp_ps_avx512(__m512 & src) { + static __m512 exp_ln_flt_min_f = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); // log(FLT_MIN) + static __m512 exp_ln_flt_max_f = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); // log(FLT_MAX) + static __m512 exp_log2ef = _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) + static __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000)); // 0.5f + static __m512 ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) + static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f + static __m512i exponent_bias = _mm512_set1_epi32(0x0000007f); // 127 + static constexpr int n_mantissa_bits = 23; + static __m512 exp_pol1 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f7ffffb)); // p1 = 0.999999701f + static __m512 exp_pol2 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3efffee3)); // p2 = 0.499991506f + static __m512 exp_pol3 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3e2aad40)); // p3 = 0.166676521f + static __m512 exp_pol4 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3d2b9d0d)); // p4 = 0.0418978221f + static __m512 exp_pol5 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3c07cfce)); // p5 = 0.00828929059f + static __m512 two = _mm512_castsi512_ps(_mm512_set1_epi32(0x40000000)); // 2 + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + // get mask of values lower than log(FLT_MIN) to zero them in the output + auto zero_mask = _mm512_cmp_ps_mask(src, exp_ln_flt_min_f, _CMP_LT_OS); + + // clip src + src = _mm512_min_ps(src, exp_ln_flt_max_f); + src = _mm512_max_ps(src, exp_ln_flt_min_f); + + // aux1 : r + auto aux1 = src; + + // calculate exp(x) + // fx = x * log2(e) + 0.5 + src = _mm512_mul_ps(src, exp_log2ef); + src = _mm512_add_ps(src, half); + + // tmp = floorf(fx) + src = _mm512_floor_ps(src); + + // aux1 = x - fx * ln2 + aux1 = _mm512_fnmadd_ps(src, ln2f, aux1); + + // We do not count 2^n here, because n can reach 128 and 2^128 is not + // representable by fp32, so to get around this problem, instead of computing + // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 + // and 2 are numbers representable in fp32. + + // compute 2^(n-1) + src = _mm512_sub_ps(src, one); + auto aux2_i = _mm512_cvtps_epi32(src); + aux2_i = _mm512_add_epi32(aux2_i, exponent_bias); + aux2_i = _mm512_slli_epi32 (aux2_i, n_mantissa_bits); + + // set zeroes at those points which were < log(FLT_MIN) + auto zero = _mm512_setzero_ps(); + auto aux2 = _mm512_mask_blend_ps(zero_mask, _mm512_castsi512_ps(aux2_i), zero); + + // compute polynomial + src = exp_pol5; + src = _mm512_fmadd_ps(src, aux1, exp_pol4); + src = _mm512_fmadd_ps(src, aux1, exp_pol3); + src = _mm512_fmadd_ps(src, aux1, exp_pol2); + src = _mm512_fmadd_ps(src, aux1, exp_pol1); + src = _mm512_fmadd_ps(src, aux1, one); + + // y = y * 2^n + src = _mm512_mul_ps(src, aux2); + src = _mm512_mul_ps(src, two); + } + + template + void softmax_avx512(D* dst, float* src, int N, QD quant) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + "softmax_avx512 only support output data types ov::bfloat16/uint8_t/int8_t/float"); + + static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f + auto tail = N % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + + // get max + auto x_max = _mm512_set1_ps(std::numeric_limits::lowest()); + int i; + for (i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x_max = _mm512_max_ps(x_max, x); + } + // tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x_max = _mm512_mask_max_ps(x_max, x_mask, x_max, x); + } + auto max = _mm512_reduce_max_ps(x_max); + x_max = _mm512_set1_ps(max); + + // softmax + auto sum_exp = _mm512_setzero_ps(); + for(i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_sub_ps(x, x_max); + exp_ps_avx512(x); // exp(x-x_max) + sum_exp = _mm512_add_ps(sum_exp, x); // sum(exp(x-x_max)) + _mm512_storeu_ps(src + i, x); // save exp(x-x_max) + } + + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_sub_ps(x, x_max); + exp_ps_avx512(x); + x = _mm512_mask_blend_ps(x_mask, _mm512_setzero_ps(), x); + sum_exp = _mm512_add_ps(sum_exp, x); + _mm512_mask_storeu_ps(src + i, x_mask, x); + } + + auto sum = _mm512_reduce_add_ps(sum_exp); + sum_exp = _mm512_set1_ps(sum); + auto reciprocal_sum_exp = _mm512_div_ps(one, sum_exp); // 1/sum_exp + + // divide + if (std::is_same::value) { + for(i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + _mm512_storeu_ps(dst + i, x); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + _mm512_mask_storeu_ps(dst + i, x_mask, x); + } + } + if (std::is_same::value) { + for(i = 0; i < N / 32 * 32; i += 32) { + auto x0 = _mm512_loadu_ps(src + i); + auto x1 = _mm512_loadu_ps(src + i + 16); + x0 = _mm512_mul_ps(x0, reciprocal_sum_exp); + x1 = _mm512_mul_ps(x1, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x1, x0); + _mm512_storeu_epi32(dst + i, (__m512i)out); + } + if (i < N - tail) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + i += 16; + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_mask_storeu_epi16(dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + } + if (std::is_same::value) { + __m512 q; + if constexpr (std::is_same::value) + q = _mm512_set1_ps(quant); + for(i = 0; i < N - tail; i += 16) { + if constexpr (std::is_same::value) + q = _mm512_loadu_ps(quant + i); + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(dst + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + if constexpr (std::is_same::value) + q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(dst + i, x_mask, x_i); + } + } + if (std::is_same::value) { + auto zero = _mm512_setzero_epi32(); + __m512 q; + if constexpr (std::is_same::value) + q = _mm512_set1_ps(quant); + for(i = 0; i < N - tail; i += 16) { + if constexpr (std::is_same::value) + q = _mm512_loadu_ps(quant + i); + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(dst + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + if constexpr (std::is_same::value) + q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(dst + i, x_mask, x_i); + } + } + } +} \ No newline at end of file diff --git a/src/transpose_kernel_avx512.hpp b/src/transpose_kernel_avx512.hpp new file mode 100644 index 0000000..abcd37d --- /dev/null +++ b/src/transpose_kernel_avx512.hpp @@ -0,0 +1,111 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + template + void memcpy2d_stride_avx512(D* dst, S* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr); + + template + void memcpy2d_stride_avx512(D* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + "memcpy2d_stride_avx512 only support output data types ov::bfloat16/uint8_t/int8_t/float"); + + auto tail = width % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + + for (size_t j = 0; j < height; j++) { + size_t i; + if (std::is_same::value) { + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + _mm512_storeu_ps(reinterpret_cast(dst) + i, x); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + _mm512_mask_storeu_ps(reinterpret_cast(dst) + i, x_mask, x); + } + } + + if (std::is_same::value) { + for(i = 0; i < width / 32 * 32; i += 32) { + auto x0 = _mm512_loadu_ps(src + i); + auto x1 = _mm512_loadu_ps(src + i + 16); + auto out = _mm512_cvtne2ps_pbh(x1, x0); + _mm512_storeu_epi32(reinterpret_cast(dst) + i, (__m512i)out); + } + if (i < width - tail) { + auto x = _mm512_loadu_ps(src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(reinterpret_cast(dst) + i), + _mm512_extracti64x4_epi64(out, 0)); + i += 16; + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_mask_storeu_epi16(reinterpret_cast<__m256i*>(reinterpret_cast(dst) + i), + x_mask, _mm512_extracti64x4_epi64(out, 0)); + } + } + + if (std::is_same::value) { + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + auto q = _mm512_loadu_ps(quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(reinterpret_cast(dst) + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(reinterpret_cast(dst) + i, x_mask, x_i); + } + } + + if (std::is_same::value) { + auto zero = _mm512_setzero_epi32(); + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + auto q = _mm512_loadu_ps(quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(reinterpret_cast(dst) + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(reinterpret_cast(dst) + i, x_mask, x_i); + } + } + + src = reinterpret_cast(reinterpret_cast(src) + src_stride); + dst = reinterpret_cast(reinterpret_cast(dst) + dst_stride); + } + } +} \ No newline at end of file diff --git a/src/utility_kernel_amx.hpp b/src/utility_kernel_amx.hpp new file mode 100644 index 0000000..9a9778f --- /dev/null +++ b/src/utility_kernel_amx.hpp @@ -0,0 +1,109 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +struct tileconfig_t { + uint8_t palette_id; + uint8_t startRow; + uint8_t reserved[14]; + uint16_t cols[16]; + uint8_t rows[16]; + tileconfig_t() = default; + + tileconfig_t(int palette, int _startRow, const std::initializer_list> &_rows_columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + i = 0; + for (const auto& ele : _rows_columnsBytes) { + rows[i] = ele.first; + cols[i] = ele.second; + i++; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + + tileconfig_t(int palette, int _startRow, const std::initializer_list &_rows, int columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + i = 0; + for (const auto ele : _rows) { + rows[i] = ele; + cols[i] = columnsBytes; + i++; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + + tileconfig_t(int palette, int _startRow, int numTiles, int _rows, int columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + for(i = 0; i < numTiles; i++) { + rows[i] = _rows; + cols[i] = columnsBytes; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + + ~tileconfig_t() { + _tile_release(); + } + + void __attribute__((noinline)) load() { + _tile_loadconfig(this); + } + + void store() { + _tile_storeconfig(this); + } + + friend std::ostream& operator<<(std::ostream& out, const tileconfig_t& cfg) { + out << " palette_id=" << static_cast(cfg.palette_id); + out << " startRow=" << static_cast(cfg.startRow); + out << " row x colsb=("; + for (int i = 0; i < 16;i++) { + if (cfg.rows[i] == 0 && cfg.cols[i] == 0) + continue; + if (i > 0) out << ","; + out << static_cast(cfg.rows[i]) << "x" << static_cast(cfg.cols[i]); + } + out << ")"; + return out; + } +} __attribute__ ((__packed__)); diff --git a/src/utility_kernel_avx512.hpp b/src/utility_kernel_avx512.hpp new file mode 100644 index 0000000..e47f5c1 --- /dev/null +++ b/src/utility_kernel_avx512.hpp @@ -0,0 +1,114 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "common/bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +namespace llmdnn { + +/// Convert Packed BF16 Data to Packed float Data. +/// +/// \headerfile +/// +/// \param __A +/// A 256-bit vector of [16 x bfloat]. +/// \returns A 512-bit vector of [16 x float] come from convertion of __A +static __inline__ __m512 _mm512_cvtpbh_ps(__m256bh __A) { + return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32( + (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16)); +} + +// Store masks. The highest bit in each byte indicates the byte to store. +alignas(16) const unsigned char masks[16][16] = +{ + { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00 } +}; + +inline void store_n(__m128i mm, unsigned int n, void* storage) +{ + _mm_maskmoveu_si128(mm, reinterpret_cast< const __m128i& >(masks[n]), static_cast< char* >(storage)); +} + +inline void quant_i8_avx512(void* dst, void* src, size_t ele_num, float scale) { + size_t i = 0; + auto* a = reinterpret_cast(src); + int8_t* d = reinterpret_cast(dst); + auto s = _mm512_set1_ps(scale); + for (; i < ele_num / 16 * 16; i += 16) { + auto a0 = _mm256_loadu_epi16(a); + auto a0_f = _mm512_cvtpbh_ps((__m256bh)a0); + auto d_f = _mm512_mul_ps(a0_f, s); + auto d_i = _mm512_cvtps_epi32(d_f); + auto d_i8 = _mm512_cvtsepi32_epi8(d_i); + _mm_storeu_si128(reinterpret_cast<__m128i *>(d), d_i8); + a += 16; + d += 16; + } + if (i != ele_num) { + // https://stackoverflow.com/questions/40391708/convert-16-bit-mask-mmask16-to-m128i-control-byte-mask-on-knl-xeon-phi-72 + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - (ele_num % 16))); + auto a0 = _mm256_maskz_loadu_epi16(msk, a); + auto a0_f = _mm512_cvtpbh_ps((__m256bh)a0); + auto d_f = _mm512_mul_ps(a0_f, s); + auto d_i = _mm512_cvtps_epi32(d_f); + auto d_i8 = _mm512_cvtsepi32_epi8(d_i); + store_n(d_i8, ele_num % 16, d); + } +} + +// NOTE: did not handle tail because there should be enough room +inline void cvt_i32_f32_avx512(float* dst, int32_t* src, size_t ele_num) { + for (size_t i = 0; i < (ele_num + 15) / 16 * 16; i += 16) { + auto a0 = _mm512_load_epi32(src); + auto a_f = _mm512_cvtepi32_ps(a0); + _mm512_storeu_ps(dst, a_f); + src += 16; + dst += 16; + } +} + +inline void mul_add_f32_avx512(float* dst, float* src, float mul, float* add, int ele_num) { + auto mul_f = _mm512_set1_ps(mul); + int i; + auto tail = ele_num % 16; + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + for (i = 0; i < ele_num - tail; i += 16) { + auto a_f = _mm512_loadu_ps(src); + auto add_f = _mm512_loadu_ps(add); + _mm512_storeu_ps(dst, _mm512_fmadd_ps(a_f, mul_f, add_f)); + src += 16; + dst += 16; + add += 16; + } + if (tail) { + auto a_f = _mm512_maskz_loadu_ps(msk, src); + auto add_f = _mm512_maskz_loadu_ps(msk, add); + _mm512_mask_storeu_ps(dst, msk, _mm512_fmadd_ps(a_f, mul_f, add_f)); + } +} + +} \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..8ed51d5 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) +project(cpu_extensions_tests) + +enable_testing() + +if (NOT TARGET gtest_main) + include(FetchContent) + FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.11.0 + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE) + FetchContent_GetProperties(googletest) + if(NOT googletest_POPULATED) + FetchContent_Populate(googletest) + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() +endif() + +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +file(GLOB_RECURSE TEST_SOURCE_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) + +find_package(OpenMP REQUIRED) + +add_executable(cpu_extensions_tests ${TEST_SOURCE_FILES}) +target_include_directories(cpu_extensions_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) +target_link_libraries(cpu_extensions_tests cpu_extensions gtest_main stdc++ OpenMP::OpenMP_CXX) \ No newline at end of file diff --git a/tests/script/README.md b/tests/script/README.md new file mode 100644 index 0000000..b301c08 --- /dev/null +++ b/tests/script/README.md @@ -0,0 +1,23 @@ +# Torch extension to help test + +## usage +prepare python enviroment +``` +python3 -m venv .env +source .env/bin/activate +pip3 install -r requirements.txt +``` + +compile extension +``` +./build.sh +``` +debug version extension(llmdnn needs to config debug version also) +``` +DEBUG_EXT=1 ./build.sh +``` + +run test +``` +pytest +``` \ No newline at end of file diff --git a/tests/script/build.sh b/tests/script/build.sh new file mode 100755 index 0000000..b6bc751 --- /dev/null +++ b/tests/script/build.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +pip uninstall -y llmdnn +cd ../../build/ || exit +make -j 20 +cd - || exit +cd ext || exit +python setup.py clean --all install diff --git a/tests/script/ext/CMakeLists.txt b/tests/script/ext/CMakeLists.txt new file mode 100644 index 0000000..41a0951 --- /dev/null +++ b/tests/script/ext/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required (VERSION 3.13) + +project(cpu_extensions_ext LANGUAGES CXX) + +# ref:https://stackoverflow.com/questions/68401650/how-can-i-make-a-pytorch-extension-with-cmake +execute_process(COMMAND python3 -c "import torch;print(torch.utils.cmake_prefix_path+'/Torch/',end='')" OUTPUT_VARIABLE TORCH_CMAKE_PATH) +set(CMAKE_PREFIX_PATH ${TORCH_CMAKE_PATH}) + +find_package(Python REQUIRED COMPONENTS Development) +find_package(Torch REQUIRED) +find_package(OpenMP REQUIRED) + +add_library(cpu_extensions_ext SHARED + mha_gpt.cpp + module.cpp + ../../src/test_common.cpp +) +set_target_properties(cpu_extensions_ext PROPERTIES + OUTPUT_NAME "cpu_extensions_ext" + POSITION_INDEPENDENT_CODE ON + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../) +install(TARGETS cpu_extensions_ext DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) +target_compile_features(cpu_extensions_ext PRIVATE cxx_std_14) +target_include_directories(cpu_extensions_ext PRIVATE ../../src) +target_link_libraries(cpu_extensions_ext PRIVATE ${TORCH_LIBRARIES} Python::Python cpu_extensions stdc++ OpenMP::OpenMP_CXX) diff --git a/tests/script/ext/attn_gpt.cpp b/tests/script/ext/attn_gpt.cpp new file mode 100644 index 0000000..a87a3a1 --- /dev/null +++ b/tests/script/ext/attn_gpt.cpp @@ -0,0 +1,251 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "alloca.h" +#include "module.hpp" +#include "common/utility.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_emb_gpt.hpp" +#include "llm_mha_gpt.hpp" +#include "test_common.hpp" + +using namespace torch::indexing; + +class attn_gpt { +public: + struct create_param { + size_t num_heads; + size_t head_size; + size_t head_size_aligned; // better to aligned to 64 bytes for best performance, apply for qkv + size_t max_seq_len; // max seq length for computing the size of matmul tmp result + // supported (qkv, dst): (bf16, bf16) + llmdnn::data_type_t qkv_precision; + llmdnn::data_type_t dst_precision; + size_t rotary_dims; + float normal_factor; + bool use_position2d; + }; + struct exec_param { + size_t batch; + size_t query_seq_len; + size_t past_seq_len; + bool is_causal_in_attention; // causal mask is fused in attention mask: chatglm uses it. + uint8_t* q; // shape: [batch, query_seq_len, hidden size], inner stride is ldq + uint8_t* k; // shape: [batch, query_seq_len, hidden size], inner stride is ldk + uint8_t* v; // shape: [batch, query_seq_len, hidden size], inner stride is ldv + size_t ldq; // inner stride of q + size_t ldk; // inner stride of k + size_t ldv; // inner stride of v + uint8_t** layer_past_key_dst; + uint8_t** layer_past_value_dst; + int* position2d_ids; // shape: [batch, 2, query_seq_len] + float* attention_mask; // attention mask, attention_mask[0] shape: + // [batch, 1, 1, key_seq_len], when is_causal_in_attention is false + // [batch, 1, query_seq_len, key_seq_len], when is_causal_in_attention is true + uint8_t* attn_output; + size_t head_stride_in_kv; + float* cos; + float* sin; + }; + + attn_gpt(); + bool create(const create_param& param); + void exec(const exec_param& param); + +private: + create_param _create_param; + std::shared_ptr _emb_gpt; + std::shared_ptr _mha_gpt; + std::shared_ptr _query_dst; + size_t _query_cached_batch = 0; +}; + +attn_gpt::attn_gpt(): _emb_gpt(std::make_shared()), + _mha_gpt(std::make_shared()) { + +} + +bool attn_gpt::create(const attn_gpt::create_param& param) { + _create_param = param; + llmdnn::emb_gpt::create_param emb_param; + emb_param.num_heads = param.num_heads; + emb_param.head_size = param.head_size; + emb_param.head_size_aligned = param.head_size_aligned; + emb_param.qkv_precision = param.qkv_precision; + emb_param.dst_precision = param.dst_precision; + emb_param.rotary_dims = param.rotary_dims; + emb_param.use_position2d = param.use_position2d; + + if (!_emb_gpt->create(emb_param)) + return false; + + llmdnn::mha_gpt::create_param mha_param; + mha_param.num_heads = param.num_heads; + mha_param.head_size = param.head_size; + mha_param.head_size_aligned = param.head_size_aligned; + mha_param.normal_factor = param.normal_factor; + mha_param.qkv_precision = param.qkv_precision; + mha_param.dst_precision = param.dst_precision; + mha_param.max_seq_len = param.max_seq_len; + + return _mha_gpt->create(mha_param); +} + +void attn_gpt::exec(const attn_gpt::exec_param& param) { + if (_query_cached_batch < param.batch) { + auto capacity = param.batch * _create_param.max_seq_len * (_create_param.num_heads * _create_param.head_size_aligned) * + llmdnn::get_precision_size(_create_param.qkv_precision); + _query_dst = std::shared_ptr(reinterpret_cast(aligned_alloc(64, capacity)), + [](void * p) { ::free(p); }); + memset(_query_dst.get(), 0, capacity); + _query_cached_batch = param.batch; + } + + llmdnn::emb_gpt::exec_param emb_param; + emb_param.batch = param.batch; + emb_param.query_seq_len = param.query_seq_len; + emb_param.past_seq_len = param.past_seq_len; + emb_param.q = param.q; + emb_param.k = param.k; + emb_param.v = param.v; + emb_param.ldq = param.ldq; + emb_param.ldk = param.ldk; + emb_param.ldv = param.ldv; + emb_param.query_dst = _query_dst.get(); + emb_param.layer_past_key_src = param.layer_past_key_dst; + emb_param.layer_past_value_src = param.layer_past_value_dst; + emb_param.layer_past_key_dst = param.layer_past_key_dst; + emb_param.layer_past_value_dst = param.layer_past_value_dst; + emb_param.position2d_ids = param.position2d_ids; + emb_param.head_stride_in_kv = param.head_stride_in_kv; + emb_param.cos = param.cos; + emb_param.sin = param.sin; + _emb_gpt->exec(emb_param); + + llmdnn::mha_gpt::exec_param mha_param; + mha_param.batch = param.batch; + mha_param.query_seq_len = param.query_seq_len; + mha_param.key_seq_len = param.query_seq_len + param.past_seq_len; + mha_param.q = emb_param.query_dst; + mha_param.attn_output = param.attn_output; + mha_param.head_stride_in_kv = param.head_stride_in_kv; + mha_param.is_causal_in_attention = param.is_causal_in_attention; + mha_param.attention_mask = param.attention_mask; + mha_param.k = emb_param.layer_past_key_dst; + mha_param.v = emb_param.layer_past_value_dst; + _mha_gpt->exec(mha_param); +} + +void regclass_attn_gpt(pybind11::module m) { + py::class_> cls(m, "attn_gpt"); + cls.def(py::init<>()); + cls.def("create", [] (attn_gpt& self, + const size_t num_heads, + const size_t head_size, + const size_t head_size_aligned, + float normal_factor, + const std::string qkv_precision_name, + const std::string dst_precision_name, + const size_t max_seq_len, + const size_t rotary_dims, + bool use_position2d) { + attn_gpt::create_param param; + param.num_heads = num_heads; + param.head_size = head_size; + param.head_size_aligned = head_size_aligned; + param.normal_factor = normal_factor; + param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); + param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); + param.max_seq_len = max_seq_len; + param.rotary_dims = rotary_dims; + param.use_position2d = use_position2d; + if (param.qkv_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); + if (param.dst_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect dst type " + dst_precision_name); + if (!self.create(param)) + throw pybind11::type_error("Incorrect param"); + }, + py::arg("num_heads"), + py::arg("head_size"), + py::arg("head_size_aligned"), + py::arg("normal_factor"), + py::arg("qkv_precision_name"), + py::arg("dst_precision_name"), + py::arg("max_seq_len"), + py::arg("rotary_dims"), + py::arg("use_position2d") = false, + R"( + Create emb + + :param num_heads: heads number. + :type num_heads: int + )"); + cls.def("exec_position", [] (attn_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_dst, + const torch::Tensor& layer_past_value_dst, int64_t past_seq_len, const torch::Tensor& attn_mask, + const torch::Tensor& position2d_ids, const torch::Tensor& cos, const torch::Tensor& sin) { + // qkv: [batch, seq_len, (num_heads * 3 * head_size)] + // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + // past_seq_len: past_seq_len==layer_past.shape[-2] + // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] + // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] + AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && attn_mask.dim() == 4 && + qkv.size(0) == layer_past_key_dst.size(0) && + layer_past_key_dst.dim() == layer_past_value_dst.dim()); + auto batch = qkv.size(0); + auto num_heads = layer_past_key_dst.size(1); + auto query_seq_len = qkv.size(1); + auto head_size = qkv.size(2) / 3 / num_heads; + auto head_size_aligned = layer_past_key_dst.size(3); + auto max_seq_len = layer_past_key_dst.size(2); + AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && + query_seq_len <= layer_past_key_dst.size(2)); + + attn_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.past_seq_len = past_seq_len; + param.q = reinterpret_cast(qkv.data_ptr()); + param.k = param.q + head_size * sizeof(ov::bfloat16); + param.v = param.k + head_size * sizeof(ov::bfloat16); + param.ldq = param.ldk = param.ldv = head_size * 3; + param.layer_past_key_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + for (int i = 0; i < batch; i++) { + param.layer_past_key_dst[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); + param.layer_past_value_dst[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); + } + param.position2d_ids = reinterpret_cast(position2d_ids.data_ptr()); + + param.is_causal_in_attention = attn_mask.size(2) != 1; + param.attention_mask = attn_mask.data_ptr(); + param.head_stride_in_kv = max_seq_len * head_size_aligned; + auto out = qkv.new_empty({batch, query_seq_len, num_heads * head_size}); + param.attn_output = reinterpret_cast(out.data_ptr()); + param.cos = reinterpret_cast(cos.data_ptr()); + param.sin = reinterpret_cast(sin.data_ptr()); + + self.exec(param); + + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + return out; + }, + py::arg("qkv"), + py::arg("layer_past_key_dst"), + py::arg("layer_past_value_dst"), + py::arg("past_seq_len"), + py::arg("attn_mask"), + py::arg("position2d_ids"), + py::arg("cos"), + py::arg("sin"), + R"( + exec emb + + :param num_heads: heads number. + :type num_heads: int + )"); +} \ No newline at end of file diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp new file mode 100644 index 0000000..cbf4aa8 --- /dev/null +++ b/tests/script/ext/emb_gpt.cpp @@ -0,0 +1,184 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "alloca.h" +#include "module.hpp" +#include "common/utility.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_emb_gpt.hpp" +#include "test_common.hpp" + +using namespace torch::indexing; + +void regclass_emb_gpt(pybind11::module m) { + py::class_> cls(m, "emb_gpt"); + cls.def(py::init<>()); + cls.def("create", [] (llmdnn::emb_gpt& self, + const size_t num_heads, + const size_t head_size, + const size_t head_size_aligned, + const std::string qkv_precision_name, + const std::string dst_precision_name, + const size_t rotary_dims, + bool use_position2d) { + llmdnn::emb_gpt::create_param param; + param.num_heads = num_heads; + param.head_size = head_size; + param.head_size_aligned = head_size_aligned; + param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); + param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); + param.rotary_dims = rotary_dims; + param.use_position2d = use_position2d; + if (param.qkv_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); + if (param.dst_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect dst type " + dst_precision_name); + if (!self.create(param)) + throw pybind11::type_error("Incorrect param"); + }, + py::arg("num_heads"), + py::arg("head_size"), + py::arg("head_size_aligned"), + py::arg("qkv_precision_name"), + py::arg("dst_precision_name"), + py::arg("rotary_dims"), + py::arg("use_position2d") = false, + R"( + Create emb + + :param num_heads: heads number. + :type num_heads: int + )"); + // torch::List + cls.def("exec", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_dst, + const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor& cos, const torch::Tensor& sin) { + // qkv: [batch, seq_len, (num_heads * 3 * head_size)] + // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + // past_seq_len: past_seq_len==layer_past.shape[-2] + // query_padded: [batch, num_heads, query_seq_len, head_size_aligned] + // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] + AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && + qkv.size(0) == layer_past_key_dst.size(0) && + layer_past_key_dst.dim() == layer_past_value_dst.dim()); + AT_ASSERT(query_padded.dim() == 4 && query_padded.size(0) == qkv.size(0) && + query_padded.size(1) == layer_past_key_dst.size(1) && query_padded.size(2) == qkv.size(1) && + query_padded.size(3) == layer_past_key_dst.size(3)); + auto batch = qkv.size(0); + auto num_heads = layer_past_key_dst.size(1); + auto query_seq_len = qkv.size(1); + auto head_size = qkv.size(2) / 3 / num_heads; + auto head_size_aligned = layer_past_key_dst.size(3); + auto max_seq_len = layer_past_key_dst.size(2); + AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && + query_seq_len <= layer_past_key_dst.size(2)); + + llmdnn::emb_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.past_seq_len = past_seq_len; + param.q = reinterpret_cast(qkv.data_ptr()); + param.k = param.q + head_size * sizeof(ov::bfloat16); + param.v = param.k + head_size * sizeof(ov::bfloat16); + param.ldq = param.ldk = param.ldv = head_size * 3; + param.query_dst = reinterpret_cast(query_padded.data_ptr()); + param.layer_past_key_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_key_dst = param.layer_past_key_src; + param.layer_past_value_dst = param.layer_past_value_src; + param.head_stride_in_kv = max_seq_len * head_size_aligned; + param.cos = reinterpret_cast(cos.data_ptr()); + param.sin = reinterpret_cast(sin.data_ptr()); + for (int i = 0; i < batch; i++) { + param.layer_past_key_src[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); + param.layer_past_value_src[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); + } + + self.exec(param); + + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + }, + py::arg("qkv"), + py::arg("layer_past_key_dst"), + py::arg("layer_past_value_dst"), + py::arg("query_padded"), + py::arg("past_seq_len"), + py::arg("cos"), + py::arg("sin"), + R"( + exec emb + + :param num_heads: heads number. + :type num_heads: int + )"); + cls.def("exec_position", [] (llmdnn::emb_gpt& self, const torch::Tensor& qkv, const torch::Tensor& layer_past_key_src, const torch::Tensor& layer_past_value_src, + const torch::Tensor& layer_past_key_dst, const torch::Tensor& layer_past_value_dst, const torch::Tensor& query_padded, int64_t past_seq_len, const torch::Tensor position2d_ids, const torch::Tensor& cos, const torch::Tensor& sin) { + // qkv: [batch, seq_len, (num_heads * 3 * head_size)] + // layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + // past_seq_len: past_seq_len==layer_past.shape[-2] + // query_padded: [batch, num_heads, query_seq_len, head_size_aligned] + // key/value: [batch, num_heads, query_seq_len+past_seq_len, head_size_aligned] + AT_ASSERT(qkv.dim() == 3 && layer_past_key_dst.dim() == 4 && layer_past_value_dst.dim() == 4 && + qkv.size(0) == layer_past_key_dst.size(0) && + layer_past_key_dst.dim() == layer_past_value_dst.dim()); + AT_ASSERT(query_padded.dim() == 4 && query_padded.size(0) == qkv.size(0) && + query_padded.size(1) == layer_past_key_dst.size(1) && query_padded.size(2) == qkv.size(1) && + query_padded.size(3) == layer_past_key_dst.size(3)); + auto batch = qkv.size(0); + auto num_heads = layer_past_key_dst.size(1); + auto query_seq_len = qkv.size(1); + auto head_size = qkv.size(2) / 3 / num_heads; + auto head_size_aligned = layer_past_key_dst.size(3); + auto max_seq_len = layer_past_key_dst.size(2); + AT_ASSERT(past_seq_len <= layer_past_key_dst.size(2) && head_size <= layer_past_key_dst.size(3) && + query_seq_len <= layer_past_key_dst.size(2)); + + llmdnn::emb_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.past_seq_len = past_seq_len; + param.q = reinterpret_cast(qkv.data_ptr()); + param.k = param.q + head_size * sizeof(ov::bfloat16); + param.v = param.k + head_size * sizeof(ov::bfloat16); + param.ldq = param.ldk = param.ldv = head_size * 3; + param.query_dst = reinterpret_cast(query_padded.data_ptr()); + param.layer_past_key_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_src = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_key_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.layer_past_value_dst = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.head_stride_in_kv = max_seq_len * head_size_aligned; + param.cos = reinterpret_cast(cos.data_ptr()); + param.sin = reinterpret_cast(sin.data_ptr()); + for (int i = 0; i < batch; i++) { + param.layer_past_key_src[i] = past_seq_len == 0 ? nullptr : reinterpret_cast(layer_past_key_src[i].data_ptr()); + param.layer_past_value_src[i] = past_seq_len == 0 ? nullptr : reinterpret_cast(layer_past_value_src[i].data_ptr()); + param.layer_past_key_dst[i] = reinterpret_cast(layer_past_key_dst[i].data_ptr()); + param.layer_past_value_dst[i] = reinterpret_cast(layer_past_value_dst[i].data_ptr()); + } + param.position2d_ids = reinterpret_cast(position2d_ids.data_ptr()); + + self.exec(param); + + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + }, + py::arg("qkv"), + py::arg("layer_past_key_src"), + py::arg("layer_past_value_src"), + py::arg("layer_past_key_dst"), + py::arg("layer_past_value_dst"), + py::arg("query_padded"), + py::arg("past_seq_len"), + py::arg("position2d_ids"), + py::arg("cos"), + py::arg("sin"), + R"( + exec emb + + :param num_heads: heads number. + :type num_heads: int + )"); +} \ No newline at end of file diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp new file mode 100644 index 0000000..704a1fb --- /dev/null +++ b/tests/script/ext/mha_gpt.cpp @@ -0,0 +1,168 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "alloca.h" +#include "module.hpp" +#include "common/utility.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_mha_gpt.hpp" +#include "test_common.hpp" + +void regclass_mha_gpt(pybind11::module m) { + py::class_> cls(m, "mha_gpt"); + cls.def(py::init<>()); + cls.def("create", [] (llmdnn::mha_gpt& self, + const size_t num_heads, + const size_t head_size, + const size_t head_size_aligned, + const float normal_factor, + const std::string qkv_precision_name, + const std::string dst_precision_name, + const size_t max_seq_len) { + llmdnn::mha_gpt::create_param param; + param.num_heads = num_heads; + param.head_size = head_size; + param.head_size_aligned = head_size_aligned; + param.normal_factor = normal_factor; + param.qkv_precision = llmdnn::get_dt_from_str(qkv_precision_name); + param.dst_precision = llmdnn::get_dt_from_str(dst_precision_name); + param.max_seq_len = max_seq_len; + if (param.qkv_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect qkv type " + qkv_precision_name); + if (param.dst_precision == llmdnn::dnnl_data_type_undef) + throw pybind11::type_error("Incorrect dst type " + dst_precision_name); + if (!self.create(param)) + throw pybind11::type_error("Incorrect param"); + }, + py::arg("num_heads"), + py::arg("head_size"), + py::arg("head_size_aligned"), + py::arg("normal_factor"), + py::arg("qkv_precision_name"), + py::arg("dst_precision_name"), + py::arg("max_seq_len"), + R"( + Create mha + + :param num_heads: heads number. + :type num_heads: int + )"); + cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask, int64_t head_size, int64_t key_seq_len) { + // q: [batch, num_heads, query_seq_len, head_size_aligned] + // k: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // v: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] + // out: [batch, query_seq_len, num_heads * head_size] + AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 4); + auto batch = q.size(0); + auto num_heads = q.size(1); + auto query_seq_len = q.size(2); + auto head_size_aligned = q.size(3); + auto max_seq_len = k.size(2); + auto attn_len = attn_mask.size(3); + AT_ASSERT(max_seq_len == v.size(2) && + batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && + num_heads == k.size(1) && num_heads == v.size(1) && + head_size_aligned == k.size(3) && head_size_aligned == v.size(3)); + + llmdnn::mha_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.key_seq_len = key_seq_len == 0 ? max_seq_len : key_seq_len; + head_size = head_size == 0 ? head_size_aligned : head_size; + auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}); + AT_ASSERT((int64_t)param.key_seq_len == attn_len); + param.q = reinterpret_cast(q.data_ptr()); + param.attn_output = reinterpret_cast(out.data_ptr()); + param.head_stride_in_kv = max_seq_len * head_size_aligned; + param.is_causal_in_attention = attn_mask.size(2) != 1; + param.attention_mask = attn_mask.data_ptr(); + param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + for (int i = 0; i < batch; i++) { + param.k[i] = reinterpret_cast(k[i].data_ptr()); + param.v[i] = reinterpret_cast(v[i].data_ptr()); + } + + self.exec(param); + return out; + }, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("attn_mask"), + py::arg("head_size") = 0, + py::arg("key_seq_len") = 0, + R"( + exec mha + + :param num_heads: heads number. + :type num_heads: int + )"); + cls.def("exec_quant", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& attn_mask, + float q_dequant, float k_dequant, float v_dequant, float qk_quant, const std::vector& qkv_quant, int64_t head_size, int64_t key_seq_len) { + // q: [batch, num_heads, query_seq_len, head_size_aligned] + // k: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // v: [batch, num_heads, max_seq_len, head_size_aligned] valid in max_seq_len: key_seq_len + // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] + // out: [batch, query_seq_len, num_heads * head_size] + AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 4); + auto batch = q.size(0); + auto num_heads = q.size(1); + auto query_seq_len = q.size(2); + auto head_size_aligned = q.size(3); + auto max_seq_len = k.size(2); + auto attn_len = attn_mask.size(3); + AT_ASSERT(max_seq_len == v.size(2) && + batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && + num_heads == k.size(1) && num_heads == v.size(1) && + head_size_aligned == k.size(3) && head_size_aligned == v.size(3)); + + llmdnn::mha_gpt::exec_param param; + param.batch = batch; + param.query_seq_len = query_seq_len; + param.key_seq_len = key_seq_len == 0 ? max_seq_len : key_seq_len; + head_size = head_size == 0 ? head_size_aligned : head_size; + auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}, torch::TensorOptions(torch::kInt8)); + AT_ASSERT((int64_t)param.key_seq_len == attn_len); + param.q = reinterpret_cast(q.data_ptr()); + param.attn_output = reinterpret_cast(out.data_ptr()); + param.head_stride_in_kv = max_seq_len * head_size_aligned; + param.k = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.v = reinterpret_cast(alloca(batch * sizeof(uint8_t*))); + param.is_causal_in_attention = attn_mask.size(2) != 1; + param.attention_mask = attn_mask.data_ptr(); + param.q_dequant = q_dequant; + param.k_dequant = k_dequant; + param.v_dequant = v_dequant; + param.qk_quant = qk_quant; + param.qkv_quant = qkv_quant; + for (int i = 0; i < batch; i++) { + param.k[i] = reinterpret_cast(k[i].data_ptr()); + param.v[i] = reinterpret_cast(v[i].data_ptr()); + } + + self.exec(param); + return out; + }, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("attn_mask"), + py::arg("q_dequant"), + py::arg("k_dequant"), + py::arg("v_dequant"), + py::arg("qk_quant"), + py::arg("qkv_quant"), + py::arg("head_size") = 0, + py::arg("key_seq_len") = 0, + R"( + exec mha quant + + :param num_heads: heads number. + :type num_heads: int + )"); +} \ No newline at end of file diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp new file mode 100644 index 0000000..58c0ba1 --- /dev/null +++ b/tests/script/ext/module.cpp @@ -0,0 +1,20 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "module.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_mha_gpt.hpp" +#include "test_common.hpp" + +PYBIND11_MODULE(llmdnn, m) { + static bool initAMX = initXTILE(); + if (!initAMX) { + std::cout << "init amx failed.\n"; + } + regclass_mha_gpt(m); + regclass_emb_gpt(m); + regclass_attn_gpt(m); +} \ No newline at end of file diff --git a/tests/script/ext/module.hpp b/tests/script/ext/module.hpp new file mode 100644 index 0000000..12efeed --- /dev/null +++ b/tests/script/ext/module.hpp @@ -0,0 +1,11 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +void regclass_mha_gpt(pybind11::module m); +void regclass_emb_gpt(pybind11::module m); +void regclass_attn_gpt(pybind11::module m); \ No newline at end of file diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py new file mode 100644 index 0000000..2780c94 --- /dev/null +++ b/tests/script/ext/setup.py @@ -0,0 +1,48 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from setuptools import setup, Extension +from torch.utils import cpp_extension +import sys +import os + +''' +using intel compiler: +source ~/intel/oneapi/setvars.sh +export CXX=icx +export CC=icx +''' +debug = False +if 'DEBUG_EXT' in os.environ: + debug = True if os.environ['DEBUG_EXT'] == '1' else False +extra_args = ['-fopenmp', + '-march=native'] +cpu_extensions_lib_dir = f'{os.getcwd()}/../../../build/lib' +if debug: + cpu_extensions_lib_dir = f'{os.getcwd()}/../../../build/lib' + extra_args += ['-g', '-O0'] + print('install debug version') +else: + print('install release version') + +setup(name='llmdnn', + ext_modules=[ + cpp_extension.CppExtension( + 'llmdnn', + ['module.cpp', f'../../src/test_common.cpp', + 'mha_gpt.cpp', + 'emb_gpt.cpp', + 'attn_gpt.cpp', + ], + extra_compile_args=extra_args, + include_dirs=[f'{os.getcwd()}/../../src', + f'{os.getcwd()}/../../../include', + f'{os.getcwd()}/../../../src'], + library_dirs=[f'{sys.prefix}/lib', + cpu_extensions_lib_dir], + libraries=['cpu_extensions', + 'stdc++']), + ], + cmdclass={'build_ext': cpp_extension.BuildExtension} + ) \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/LICENSE b/tests/script/models/chatglm-6b/LICENSE new file mode 100644 index 0000000..ac4aee5 --- /dev/null +++ b/tests/script/models/chatglm-6b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright Zhengxiao Du + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/MODEL_LICENSE b/tests/script/models/chatglm-6b/MODEL_LICENSE new file mode 100644 index 0000000..f8e2731 --- /dev/null +++ b/tests/script/models/chatglm-6b/MODEL_LICENSE @@ -0,0 +1,65 @@ +The ChatGLM-6B License + +一、定义 + +“许可方”是指分发其软件的 ChatGLM-6B 模型团队。 + +“软件”是指根据本许可提供的 ChatGLM-6B 模型参数。 + +2. 许可授予 + +根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可,仅用于您的非商业研究目的。 + +上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 + +3.限制 + +您不得出于任何商业、军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 + +您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 + +4.免责声明 + +本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 + +5. 责任限制 + +除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 + +6.争议解决 + +本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 + +请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 glm-130b@googlegroups.com 与我们联系。 + +1. Definitions + +“Licensor” means the ChatGLM-6B Model Team that distributes its Software. + +“Software” means the ChatGLM-6B model parameters made available under this license. + +2. License Grant + +Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +3. Restriction + +You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. + +You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. + +4. Disclaimer + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +5. Limitation of Liability + +EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +6. Dispute Resolution + +This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. + +Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com. diff --git a/tests/script/models/chatglm-6b/README.md b/tests/script/models/chatglm-6b/README.md new file mode 100644 index 0000000..857edd7 --- /dev/null +++ b/tests/script/models/chatglm-6b/README.md @@ -0,0 +1,89 @@ +--- +language: +- zh +- en +tags: +- glm +- chatglm +- thudm +--- +# ChatGLM-6B +

+ 🌐 Blog • 💻 Github Repo • 🐦 Twitter • 📃 [GLM@ACL 22] [GitHub] • 📃 [GLM-130B@ICLR 23] [GitHub]
+

+ +

+ 👋 Join our Slack and WeChat +

+ +## 介绍 +ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。 + +ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference. + +## 软件依赖 + +```shell +pip install protobuf==3.20.0 transformers==4.27.1 icetk cpm_kernels +``` + +## 代码调用 + +可以通过如下代码调用 ChatGLM-6B 模型来生成对话: + +```ipython +>>> from transformers import AutoTokenizer, AutoModel +>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) +>>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda() +>>> response, history = model.chat(tokenizer, "你好", history=[]) +>>> print(response) +你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。 +>>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history) +>>> print(response) +晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法: + +1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。 +2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。 +3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。 +4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。 +5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。 +6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。 + +如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。 +``` + +关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。 + +For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B). + +## Change Log +* v1.1.0 ([942945d](https://huggingface.co/THUDM/chatglm-6b/commit/942945df047dee66f653c68ae0e56655045f1741)): 更新 v1.1 版本 checkpoint +* v0.1.0 ([f831824](https://huggingface.co/THUDM/chatglm-6b/commit/f83182484538e663a03d3f73647f10f89878f438)) + +## 协议 + +本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。 + +## 引用 + +如果你觉得我们的工作有帮助的话,请考虑引用下列论文: + +``` +@inproceedings{ + zeng2023glm-130b, + title={{GLM}-130B: An Open Bilingual Pre-trained Model}, + author={Aohan Zeng and Xiao Liu and Zhengxiao Du and Zihan Wang and Hanyu Lai and Ming Ding and Zhuoyi Yang and Yifan Xu and Wendi Zheng and Xiao Xia and Weng Lam Tam and Zixuan Ma and Yufei Xue and Jidong Zhai and Wenguang Chen and Zhiyuan Liu and Peng Zhang and Yuxiao Dong and Jie Tang}, + booktitle={The Eleventh International Conference on Learning Representations (ICLR)}, + year={2023}, + url={https://openreview.net/forum?id=-Aw0rrrPUF} +} +``` +``` +@inproceedings{du2022glm, + title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling}, + author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie}, + booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, + pages={320--335}, + year={2022} +} +``` \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/config.json b/tests/script/models/chatglm-6b/config.json new file mode 100644 index 0000000..7cc6e70 --- /dev/null +++ b/tests/script/models/chatglm-6b/config.json @@ -0,0 +1,28 @@ +{ + "_name_or_path": "THUDM/chatglm-6b", + "architectures": [ + "ChatGLMModel" + ], + "auto_map": { + "AutoConfig": "configuration_chatglm.ChatGLMConfig", + "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration" + }, + "bos_token_id": 130004, + "eos_token_id": 130005, + "mask_token_id": 130000, + "gmask_token_id": 130001, + "pad_token_id": 3, + "hidden_size": 4096, + "inner_hidden_size": 16384, + "layernorm_epsilon": 1e-05, + "max_sequence_length": 2048, + "model_type": "chatglm", + "num_attention_heads": 32, + "num_layers": 28, + "position_encoding_2d": true, + "torch_dtype": "float32", + "transformers_version": "4.23.1", + "use_cache": true, + "vocab_size": 130528 +} diff --git a/tests/script/models/chatglm-6b/configuration_chatglm.py b/tests/script/models/chatglm-6b/configuration_chatglm.py new file mode 100644 index 0000000..78f3425 --- /dev/null +++ b/tests/script/models/chatglm-6b/configuration_chatglm.py @@ -0,0 +1,103 @@ +""" ChatGLM model configuration """ + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class ChatGLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~ChatGLMModel`]. + It is used to instantiate an ChatGLM model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 150528): + Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~ChatGLMModel`] or + [`~TFChatGLMModel`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + inner_hidden_size (`int`, *optional*, defaults to 16384): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + layernorm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values attentions (not used by all models). + Example: + + ```python + >>> from configuration_chatglm import ChatGLMConfig + >>> from modeling_chatglm import ChatGLMModel + + >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration + >>> configuration = ChatGLMConfig() + + >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration + >>> model = ChatGLMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` +""" + model_type = "chatglm" + + def __init__( + self, + vocab_size=150528, + hidden_size=4096, + num_layers=28, + num_attention_heads=32, + layernorm_epsilon=1e-5, + use_cache=False, + bos_token_id=150004, + eos_token_id=150005, + mask_token_id=150000, + gmask_token_id=150001, + pad_token_id=0, + max_sequence_length=2048, + inner_hidden_size=16384, + position_encoding_2d=True, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.layernorm_epsilon = layernorm_epsilon + self.inner_hidden_size = inner_hidden_size + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.mask_token_id = mask_token_id + self.gmask_token_id = gmask_token_id + self.position_encoding_2d = position_encoding_2d + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs + ) diff --git a/tests/script/models/chatglm-6b/modeling_chatglm.org.py b/tests/script/models/chatglm-6b/modeling_chatglm.org.py new file mode 100644 index 0000000..d24bf52 --- /dev/null +++ b/tests/script/models/chatglm-6b/modeling_chatglm.org.py @@ -0,0 +1,1450 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import os +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm-6b", + # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm +] + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) + + +def gelu(x): + return gelu_impl(x) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + + +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, +): + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + matmul_result = torch.zeros( + 1, 1, 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + # [b, np, sk, hn] -> [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, present, attention_probs) + + return outputs + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class SelfAttention(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + if empty_init: + init_method = skip_init + else: + init_method = default_init + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + # Strided linear layer. + self.query_key_value = init_method( + torch.nn.Linear, + hidden_size, + 3 * self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.dense = init_method( + torch.nn.Linear, + self.inner_hidden_size, + hidden_size, + bias=bias, + dtype=params_dtype, + ) + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + # [batch, seq_len, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache + ) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs,) + + return outputs # output, present, attention_probs + + +class GEGLU(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation_fn = F.gelu + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class GLU(torch.nn.Module): + def __init__(self, hidden_size, inner_hidden_size=None, + layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): + super(GLU, self).__init__() + if empty_init: + init_method = skip_init + else: + init_method = default_init + self.layer_id = layer_id + self.activation_func = activation_func + + # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.dense_h_to_4h = init_method( + torch.nn.Linear, + self.hidden_size, + self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + # Project back to h. + self.dense_4h_to_h = init_method( + torch.nn.Linear, + self.inner_hidden_size, + self.hidden_size, + bias=bias, + dtype=params_dtype, + ) + + def forward(self, hidden_states): + """ + hidden_states: [seq_len, batch, hidden_size] + """ + + # [seq_len, batch, inner_hidden_size] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + + intermediate_parallel = self.activation_func(intermediate_parallel) + + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class GLMBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True, + empty_init=True + ): + super(GLMBlock, self).__init__() + # Set output layer initialization if not provided. + + self.layer_id = layer_id + + # Layernorm on the input data. + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.position_encoding_2d = position_encoding_2d + + # Self attention. + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + bias=use_bias, + params_dtype=params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init + ) + + # Layernorm on the input data. + self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.num_layers = num_layers + + # GLU + self.mlp = GLU( + hidden_size, + inner_hidden_size=inner_hidden_size, + bias=use_bias, + layer_id=layer_id, + params_dtype=params_dtype, + empty_init=empty_init + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # Layer norm at the begining of the transformer layer. + # [batch, seq_len, hidden_size] + attention_input = self.input_layernorm(hidden_states) + + # Self attention. + attention_outputs = self.attention( + attention_input, + position_ids, + attention_mask=attention_mask, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] + + # Residual connection. + alpha = (2 * self.num_layers) ** 0.5 + hidden_states = attention_input * alpha + attention_output + + mlp_input = self.post_attention_layernorm(hidden_states) + + # MLP. + mlp_output = self.mlp(mlp_input) + + # Second residual connection. + output = mlp_input * alpha + mlp_output + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): + batch_size, seq_length = input_ids.shape + if use_gmasks is None: + use_gmasks = [False] * batch_size + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + if self.position_encoding_2d: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [torch.cat(( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 + )) for context_length in context_lengths] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + if not use_gmasks[i]: + position_ids[i, context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + +CHATGLM_6B_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHATGLM_6B_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ChatGLM6BTokenizer`]. + See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", + CHATGLM_6B_START_DOCSTRING, +) +class ChatGLMModel(ChatGLMPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. + To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + # recording parameters + self.max_sequence_length = config.max_sequence_length + self.hidden_size = config.hidden_size + self.params_dtype = torch.half + self.num_attention_heads = config.num_attention_heads + self.vocab_size = config.vocab_size + self.num_layers = config.num_layers + self.layernorm_epsilon = config.layernorm_epsilon + self.inner_hidden_size = config.inner_hidden_size + self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads + self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + + self.word_embeddings = init_method( + torch.nn.Embedding, + num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, + dtype=self.params_dtype + ) + self.gradient_checkpointing = False + + def get_layer(layer_id): + return GLMBlock( + self.hidden_size, + self.num_attention_heads, + self.layernorm_epsilon, + layer_id, + inner_hidden_size=self.inner_hidden_size, + hidden_size_per_attention_head=self.hidden_size_per_attention_head, + layernorm=LayerNorm, + use_bias=True, + params_dtype=self.params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init + ) + + self.layers = torch.nn.ModuleList( + [get_layer(layer_id) for layer_id in range(self.num_layers)] + ) + + # Final layer norm before output. + self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) + + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values + + @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if past_key_values is None: + if self.pre_seq_len is not None: + past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, + dtype=inputs_embeds.dtype) + else: + past_key_values = tuple([None] * len(self.layers)) + + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + + + if position_ids is None: + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + position_ids = self.get_position_ids( + input_ids, + mask_positions=mask_positions, + device=input_ids.device, + use_gmasks=use_gmasks + ) + + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( + attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) + + # [seq_len, batch, hidden_size] + # hidden_states = inputs_embeds.transpose(0, 1) + # xxx + hidden_states = inputs_embeds + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if attention_mask is None: + attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() + else: + attention_mask = attention_mask.to(hidden_states.device) + + for i, layer in enumerate(self.layers): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_past = past_key_values[i] + + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + position_ids, + attention_mask, + torch.tensor(i), + layer_past, + use_cache, + output_attentions + ) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + hidden_states = layer_ret[0] + + if use_cache: + presents = presents + (layer_ret[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + + # self.hidden_size = config.hidden_size + # self.params_dtype = torch.half + # self.vocab_size = config.vocab_size + self.max_sequence_length = config.max_sequence_length + + self.position_encoding_2d = config.position_encoding_2d + + self.transformer = ChatGLMModel(config, empty_init=empty_init) + + self.lm_head = init_method( + nn.Linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=torch.half + ) + + self.config = config + + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = False + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, new_attention_mask], dim=2 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs + ) -> dict: + batch_size, seq_length = input_ids.shape + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + # only last token for input_ids if past is not None + if past is not None or past_key_values is not None: + last_token = input_ids[:, -1].unsqueeze(-1) + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] + else: + attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + else: + context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] + if self.position_encoding_2d: + position_ids = torch.tensor( + [[mask_position, seq_length - context_length] for mask_position, context_length in + zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + else: + position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, + device=input_ids.device).unsqueeze(-1) + + if past is None: + past = past_key_values + return { + "input_ids": last_token, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + else: + if attention_mask is not None and attention_mask.dtype != torch.bool: + logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + attention_mask = None + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, + device=input_ids.device, + mask_positions=mask_positions, + use_gmasks=use_gmasks + ) + + return { + "input_ids": input_ids, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + @torch.no_grad() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + yield input_ids + + def quantize(self, bits: int, empty_init=False, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) + return self diff --git a/tests/script/models/chatglm-6b/modeling_chatglm.py b/tests/script/models/chatglm-6b/modeling_chatglm.py new file mode 100644 index 0000000..5d851ff --- /dev/null +++ b/tests/script/models/chatglm-6b/modeling_chatglm.py @@ -0,0 +1,1490 @@ +""" PyTorch ChatGLM model. """ + +import math +import copy +import os +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any +import llmdnn as ld + +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput + +from .configuration_chatglm import ChatGLMConfig + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm-6b", + # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm +] + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) + + +def gelu(x): + return gelu_impl(x) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + + +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, +): + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + matmul_result = torch.zeros( + 1, 1, 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + # [b, np, sk, hn] -> [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, present, attention_probs) + + return outputs + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class SelfAttention(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True, max_sequence_length=2048): + if empty_init: + init_method = skip_init + else: + init_method = default_init + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + # Strided linear layer. + self.query_key_value = init_method( + torch.nn.Linear, + hidden_size, + 3 * self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.dense = init_method( + torch.nn.Linear, + self.inner_hidden_size, + hidden_size, + bias=bias, + dtype=params_dtype, + ) + + self.attn = ld.attn_gpt() + head_size = hidden_size // num_attention_heads + self.head_size_aligned = (head_size + 31) // 32 * 32 + self.max_sequence_length = max_sequence_length + normal_factor = 1.0 / math.sqrt(head_size) + rotary_ndims = int(head_size * 0.5) + self.attn.create(num_attention_heads, head_size, self.head_size_aligned, + normal_factor, 'bf16', 'bf16', max_sequence_length, rotary_ndims, True) + self.layer_past_key_padded = None + self.layer_past_value_padded = None + self.past_seq_len = 0 + inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_sequence_length + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + position_ids = position_ids.contiguous() + if self.layer_past_key_padded is None: + shape = (hidden_states.size(0), self.num_attention_heads, self.max_sequence_length, self.head_size_aligned) + self.layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) + self.layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) + if layer_past is None: + self.past_seq_len = 0 + context_layer = self.attn.exec_position(mixed_raw_layer, self.layer_past_key_padded, self.layer_past_value_padded, self.past_seq_len, attention_mask, position_ids, self.cos_cached, self.sin_cached) + present = (self.layer_past_key_padded, self.layer_past_value_padded) + attention_probs = None + self.past_seq_len += mixed_raw_layer.size(1) + + if 0: + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + # [batch, seq_len, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache + ) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs,) + + return outputs # output, present, attention_probs + + +class GEGLU(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation_fn = F.gelu + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class GLU(torch.nn.Module): + def __init__(self, hidden_size, inner_hidden_size=None, + layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): + super(GLU, self).__init__() + if empty_init: + init_method = skip_init + else: + init_method = default_init + self.layer_id = layer_id + self.activation_func = activation_func + + # Project to 4h. + self.hidden_size = hidden_size + if inner_hidden_size is None: + inner_hidden_size = 4 * hidden_size + self.inner_hidden_size = inner_hidden_size + self.dense_h_to_4h = init_method( + torch.nn.Linear, + self.hidden_size, + self.inner_hidden_size, + bias=bias, + dtype=params_dtype, + ) + # Project back to h. + self.dense_4h_to_h = init_method( + torch.nn.Linear, + self.inner_hidden_size, + self.hidden_size, + bias=bias, + dtype=params_dtype, + ) + + def forward(self, hidden_states): + """ + hidden_states: [seq_len, batch, hidden_size] + """ + + # [seq_len, batch, inner_hidden_size] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + + intermediate_parallel = self.activation_func(intermediate_parallel) + + output = self.dense_4h_to_h(intermediate_parallel) + + return output + + +class GLMBlock(torch.nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True, + empty_init=True, + max_sequence_length=2048 + ): + super(GLMBlock, self).__init__() + # Set output layer initialization if not provided. + + self.layer_id = layer_id + + # Layernorm on the input data. + self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.position_encoding_2d = position_encoding_2d + + # Self attention. + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=hidden_size_per_attention_head, + bias=use_bias, + params_dtype=params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init, + max_sequence_length=max_sequence_length + ) + + # Layernorm on the input data. + self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) + + self.num_layers = num_layers + + # GLU + self.mlp = GLU( + hidden_size, + inner_hidden_size=inner_hidden_size, + bias=use_bias, + layer_id=layer_id, + params_dtype=params_dtype, + empty_init=empty_init + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # Layer norm at the begining of the transformer layer. + # [batch, seq_len, hidden_size] + attention_input = self.input_layernorm(hidden_states) + + # Self attention. + attention_outputs = self.attention( + attention_input, + position_ids, + attention_mask=attention_mask, + layer_id=layer_id, + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] + + # Residual connection. + alpha = (2 * self.num_layers) ** 0.5 + hidden_states = attention_input * alpha + attention_output + + mlp_input = self.post_attention_layernorm(hidden_states) + + # MLP. + mlp_output = self.mlp(mlp_input) + + # Second residual connection. + output = mlp_input * alpha + mlp_output + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): + batch_size, seq_length = input_ids.shape + if use_gmasks is None: + use_gmasks = [False] * batch_size + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + if self.position_encoding_2d: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [torch.cat(( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 + )) for context_length in context_lengths] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + if not use_gmasks[i]: + position_ids[i, context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + + +CHATGLM_6B_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CHATGLM_6B_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ChatGLM6BTokenizer`]. + See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert *input_ids* indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", + CHATGLM_6B_START_DOCSTRING, +) +class ChatGLMModel(ChatGLMPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, + Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the + `is_decoder` argument of the configuration set to `True`. + To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` + argument and `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + # recording parameters + self.max_sequence_length = config.max_sequence_length + self.hidden_size = config.hidden_size + self.params_dtype = torch.half + self.num_attention_heads = config.num_attention_heads + self.vocab_size = config.vocab_size + self.num_layers = config.num_layers + self.layernorm_epsilon = config.layernorm_epsilon + self.inner_hidden_size = config.inner_hidden_size + self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads + self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + + self.word_embeddings = init_method( + torch.nn.Embedding, + num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, + dtype=self.params_dtype + ) + self.gradient_checkpointing = False + + def get_layer(layer_id): + return GLMBlock( + self.hidden_size, + self.num_attention_heads, + self.layernorm_epsilon, + layer_id, + inner_hidden_size=self.inner_hidden_size, + hidden_size_per_attention_head=self.hidden_size_per_attention_head, + layernorm=LayerNorm, + use_bias=True, + params_dtype=self.params_dtype, + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init, + max_sequence_length=self.max_sequence_length + ) + + self.layers = torch.nn.ModuleList( + [get_layer(layer_id) for layer_id in range(self.num_layers)] + ) + + # Final layer norm before output. + self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) + + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.word_embeddings = new_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values + + @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if past_key_values is None: + if self.pre_seq_len is not None: + past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, + dtype=inputs_embeds.dtype) + else: + past_key_values = tuple([None] * len(self.layers)) + + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + + + if position_ids is None: + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + position_ids = self.get_position_ids( + input_ids, + mask_positions=mask_positions, + device=input_ids.device, + use_gmasks=use_gmasks + ) + + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( + attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) + + # [seq_len, batch, hidden_size] + # hidden_states = inputs_embeds.transpose(0, 1) + # xxx + hidden_states = inputs_embeds + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if attention_mask is None: + attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() + else: + attention_mask = attention_mask.to(hidden_states.device) + + for i, layer in enumerate(self.layers): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_past = past_key_values[i] + + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + position_ids, + attention_mask, + torch.tensor(i), + layer_past, + use_cache, + output_attentions + ) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions + ) + + hidden_states = layer_ret[0] + + if use_cache: + presents = presents + (layer_ret[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + + # self.hidden_size = config.hidden_size + # self.params_dtype = torch.half + # self.vocab_size = config.vocab_size + self.max_sequence_length = config.max_sequence_length + + self.position_encoding_2d = config.position_encoding_2d + + self.transformer = ChatGLMModel(config, empty_init=empty_init) + + self.lm_head = init_method( + nn.Linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=torch.half + ) + + self.config = config + + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if attention_mask is not None: # and attention_mask.dtype == torch.bool: + right = torch.empty((*attention_mask.shape[:3], 1), dtype=torch.float32) + right[:] = -10000.0 + attention_mask = torch.cat( + [attention_mask, right], dim=3) + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = 0 + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, new_attention_mask], dim=2 + ) + + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs + ) -> dict: + batch_size, seq_length = input_ids.shape + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) + + # only last token for input_ids if past is not None + if past is not None or past_key_values is not None: + last_token = input_ids[:, -1].unsqueeze(-1) + if attention_mask is not None: # and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] + # else: + # attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + # else: + # context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] + # if self.position_encoding_2d: + # position_ids = torch.tensor( + # [[mask_position, seq_length - context_length] for mask_position, context_length in + # zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + # else: + # position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, + # device=input_ids.device).unsqueeze(-1) + + if past is None: + past = past_key_values + return { + "input_ids": last_token, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + else: + # if attention_mask is not None and attention_mask.dtype != torch.bool: + # logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + # attention_mask = None + # if attention_mask is None: + # attention_mask = self.get_masks( + # input_ids, + # device=input_ids.device + # ) + # if position_ids is None: + # position_ids = self.get_position_ids( + # input_ids, + # device=input_ids.device, + # mask_positions=mask_positions, + # use_gmasks=use_gmasks + # ) + + return { + "input_ids": input_ids, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. + """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + @torch.no_grad() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, + do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if not history: + prompt = query + else: + prompt = "" + for i, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + response = tokenizer.decode(outputs) + response = self.process_response(response) + new_history = history + [(query, response)] + yield response, new_history + + @torch.no_grad() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + yield input_ids + + def quantize(self, bits: int, empty_init=False, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) + return self diff --git a/tests/script/models/chatglm-6b/pytorch_model.bin.index.json b/tests/script/models/chatglm-6b/pytorch_model.bin.index.json new file mode 100644 index 0000000..b8ada2b --- /dev/null +++ b/tests/script/models/chatglm-6b/pytorch_model.bin.index.json @@ -0,0 +1,375 @@ +{ + "metadata": { + "total_size": 13744473856 + }, + "weight_map": { + "lm_head.weight": "pytorch_model-00008-of-00008.bin", + "transformer.final_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.final_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.0.attention.dense.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.attention.dense.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.input_layernorm.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.mlp.dense_4h_to_h.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.dense.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.dense.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.input_layernorm.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.1.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin", + "transformer.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin", + "transformer.layers.10.attention.dense.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.attention.dense.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.10.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.11.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.11.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.12.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.13.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.14.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.dense.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.dense.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.query_key_value.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.query_key_value.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.mlp.dense_4h_to_h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.mlp.dense_h_to_4h.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.post_attention_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.15.post_attention_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.16.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.attention.rotary_emb.inv_freq": "pytorch_model-00004-of-00008.bin", + "transformer.layers.16.input_layernorm.bias": "pytorch_model-00004-of-00008.bin", + "transformer.layers.16.input_layernorm.weight": "pytorch_model-00004-of-00008.bin", + "transformer.layers.16.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.16.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.17.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.18.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.mlp.dense_4h_to_h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.19.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.2.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.2.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.20.attention.dense.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.attention.dense.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.attention.query_key_value.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.attention.query_key_value.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.attention.rotary_emb.inv_freq": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.input_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.input_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.20.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.20.mlp.dense_h_to_4h.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.post_attention_layernorm.bias": "pytorch_model-00005-of-00008.bin", + "transformer.layers.20.post_attention_layernorm.weight": "pytorch_model-00005-of-00008.bin", + "transformer.layers.21.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.21.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.22.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.23.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.mlp.dense_4h_to_h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.mlp.dense_h_to_4h.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.24.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.dense.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.dense.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.query_key_value.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.query_key_value.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.attention.rotary_emb.inv_freq": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.input_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.input_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.25.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.25.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.25.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.25.post_attention_layernorm.bias": "pytorch_model-00006-of-00008.bin", + "transformer.layers.25.post_attention_layernorm.weight": "pytorch_model-00006-of-00008.bin", + "transformer.layers.26.attention.dense.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.attention.dense.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.attention.query_key_value.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.attention.query_key_value.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.attention.rotary_emb.inv_freq": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.input_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.input_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.post_attention_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.26.post_attention_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.dense.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.dense.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.query_key_value.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.query_key_value.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.attention.rotary_emb.inv_freq": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.input_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.input_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.mlp.dense_4h_to_h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.mlp.dense_h_to_4h.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.post_attention_layernorm.bias": "pytorch_model-00007-of-00008.bin", + "transformer.layers.27.post_attention_layernorm.weight": "pytorch_model-00007-of-00008.bin", + "transformer.layers.3.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.3.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.4.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.mlp.dense_h_to_4h.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.5.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.dense.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.dense.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.query_key_value.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.query_key_value.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.input_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.input_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.6.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.6.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.6.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.6.post_attention_layernorm.bias": "pytorch_model-00002-of-00008.bin", + "transformer.layers.6.post_attention_layernorm.weight": "pytorch_model-00002-of-00008.bin", + "transformer.layers.7.attention.dense.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.attention.dense.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.7.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.dense.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.dense.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.8.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.dense.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.dense.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.query_key_value.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.input_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.input_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.mlp.dense_4h_to_h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.mlp.dense_h_to_4h.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin", + "transformer.layers.9.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin", + "transformer.word_embeddings.weight": "pytorch_model-00001-of-00008.bin" + } +} diff --git a/tests/script/models/chatglm-6b/quantization.py b/tests/script/models/chatglm-6b/quantization.py new file mode 100644 index 0000000..6f469f6 --- /dev/null +++ b/tests/script/models/chatglm-6b/quantization.py @@ -0,0 +1,201 @@ +from torch.nn import Linear +from torch.nn.parameter import Parameter + +import bz2 +import torch +import base64 +import ctypes +from transformers.utils import logging + +from typing import List +from functools import partial + +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + + quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" + + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + "int4WeightCompression", + "int4WeightExtractionFloat", + "int4WeightExtractionHalf", + "int8WeightExtractionFloat", + "int8WeightExtractionHalf", + ], + ) +except Exception as exception: + kernels = None + logger.warning("Failed to load cpm_kernels:" + str(exception)) + + +class W8A16Linear(torch.autograd.Function): + @staticmethod + def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): + ctx.inp_shape = inp.size() + ctx.weight_bit_width = weight_bit_width + out_features = quant_w.size(0) + inp = inp.contiguous().view(-1, inp.size(-1)) + weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() + output = inp.mm(weight.t()) + ctx.save_for_backward(inp, quant_w, scale_w) + return output.view(*(ctx.inp_shape[:-1] + (out_features,))) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + inp, quant_w, scale_w = ctx.saved_tensors + weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) + grad_output = grad_output.contiguous().view(-1, weight.size(0)) + grad_input = grad_output.mm(weight) + grad_weight = grad_output.t().mm(inp) + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None + + +def compress_int4_weight(weight: torch.Tensor): # (n, m) + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + assert m % 2 == 0 + m = m // 2 + out = torch.empty(n, m, dtype=torch.int8, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], + ) + return out + + +def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): + if source_bit_width == 8: + func = kernels.int8WeightExtractionHalf + elif source_bit_width == 4: + func = kernels.int4WeightExtractionHalf + else: + assert False, "Unsupported bit-width" + + with torch.cuda.device(weight.device): + n, m = weight.size(0), weight.size(1) + out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda") + stream = torch.cuda.current_stream() + + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) + + func( + gridDim, + blockDim, + 0, + stream, + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(scale_list.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], + ) + return out + + +class QuantizedLinear(Linear): + def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs): + super(QuantizedLinear, self).__init__(*args, **kwargs) + self.weight_bit_width = weight_bit_width + + shape = self.weight.shape + del self.weight + + if weight_tensor is None or empty_init: + self.weight = torch.empty( + shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] + ) + self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) + else: + self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() + self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + + self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) + self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) + if bias_tensor is not None: + self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False) + else: + self.bias = None + + def forward(self, input): + output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) + if self.bias is not None: + output = output + self.bias + return output + + +def quantize(model, weight_bit_width, empty_init=False, **kwargs): + """Replace fp16 linear with quantized linear""" + + for layer in model.layers: + layer.attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()), + bias_tensor=layer.attention.query_key_value.bias, + in_features=layer.attention.query_key_value.in_features, + out_features=layer.attention.query_key_value.out_features, + bias=True, + dtype=torch.half, + device=layer.attention.query_key_value.weight.device, + empty_init=empty_init + ) + layer.attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()), + bias_tensor=layer.attention.dense.bias, + in_features=layer.attention.dense.in_features, + out_features=layer.attention.dense.out_features, + bias=True, + dtype=torch.half, + device=layer.attention.dense.weight.device, + empty_init=empty_init + ) + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), + bias_tensor=layer.mlp.dense_h_to_4h.bias, + in_features=layer.mlp.dense_h_to_4h.in_features, + out_features=layer.mlp.dense_h_to_4h.out_features, + bias=True, + dtype=torch.half, + device=layer.mlp.dense_h_to_4h.weight.device, + empty_init=empty_init + ) + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), + bias_tensor=layer.mlp.dense_4h_to_h.bias, + in_features=layer.mlp.dense_4h_to_h.in_features, + out_features=layer.mlp.dense_4h_to_h.out_features, + bias=True, + dtype=torch.half, + device=layer.mlp.dense_4h_to_h.weight.device, + empty_init=empty_init + ) + return model diff --git a/tests/script/models/chatglm-6b/test_modeling_chatglm.py b/tests/script/models/chatglm-6b/test_modeling_chatglm.py new file mode 100644 index 0000000..814c8bd --- /dev/null +++ b/tests/script/models/chatglm-6b/test_modeling_chatglm.py @@ -0,0 +1,245 @@ +import datetime +import math +import unittest +import torch +import random +import time + +from transformers import AutoTokenizer, AutoModel +from transformers.testing_utils import require_torch, slow, torch_device + +from torch.profiler import profile, record_function, ProfilerActivity + +def set_random_seed(seed): + import random + + random.seed(seed) + + # pytorch RNGs + import torch + + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # numpy RNG + import numpy as np + + np.random.seed(seed) + + + +def ids_tensor(shape, vocab_size): + # Creates a random int32 tensor of the shape within the vocab size + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(random.randint(0, vocab_size - 1)) + + return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() + + +def get_model_and_tokenizer(): + model = AutoModel.from_pretrained("/home/luocheng/openvino/src/plugins/intel_cpu/thirdparty/llmdnn/tests/script/models/chatglm-6b", trust_remote_code=True) + model.to(torch_device, dtype=torch.bfloat16) + print(f"torch_device={torch_device}") + model.eval() + tokenizer = AutoTokenizer.from_pretrained("/home/luocheng/openvino/src/plugins/intel_cpu/thirdparty/llmdnn/tests/script/models/chatglm-6b", trust_remote_code=True) + return model, tokenizer + + +@require_torch +class ChatGLMGenerationTest(unittest.TestCase): + def get_generation_kwargs(self): + pass + + def ntest_chat(self): + print("======================test_chat") + model, tokenizer = get_model_and_tokenizer() + prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] + history = [] + set_random_seed(42) + expected_responses = [ + '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', + '清华大学创建于 1911 年。' + ] + for (prompt, expected_response) in zip(prompts, expected_responses): + response, history = model.chat(tokenizer, prompt, history=history) + print(repr(response)) + self.assertEqual(expected_response, response) + + def ntest_stream_chat(self): + print("======================test_stream_chat") + model, tokenizer = get_model_and_tokenizer() + prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] + history = [] + expected_responses = [ + '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', + '清华大学创建于 1911 年。' + ] + set_random_seed(42) + for prompt, expected_response in zip(prompts, expected_responses): + response = "" + for idx, (response, history) in enumerate(model.stream_chat(tokenizer, prompt, history=history)): + pass + print(repr(response)) + self.assertEqual(expected_response, response) + + def ntest_generation(self): + print("======================test_generation") + model, tokenizer = get_model_and_tokenizer() + sentence = "晚上睡不着怎么办" + parameters = [(False, 2048, 1), + (False, 64, 1), + (True, 2048, 1), + (True, 64, 1), + (True, 2048, 4)] + expected_out_sentences = [ + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。\n\n3. 避免刺激性物质:避免饮用含咖啡因的饮料,如咖啡、茶和可乐,并尽可能减少饮酒。\n\n4. 放松身心:尝试进行放松的活动,如冥想、深呼吸、瑜伽或听轻柔的音乐。\n\n5. 避免在床上做其他事情:例如看电视、使用电脑或智能手机等。\n\n6. 练习放松技巧:例如渐进性肌肉松弛法、冥想或深呼吸练习。\n\n7. 寻求帮助:如果长时间都无法正常入睡,可以考虑咨询医生或专业心理医生,寻求更进一步的帮助。\n\n希望这些方法能有助于入睡。', + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。', + '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体释放褪黑素,进而导致难以入睡。建议你在睡前一小时停止使用这些设备。\n\n3. 创建舒适的睡眠环境:确保卧室安静、黑暗、凉爽,舒适的床垫和枕头,保持卧室温度适宜,这有助于让你更容易入睡。\n\n4. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽或轻松的散步,减轻压力和焦虑,让你更容易入睡。\n\n5. 避免咖啡因和酒精:咖啡因和酒精会让大脑更加兴奋,进而干扰身体入睡过程。建议在睡前几小时避免饮用这些物质。\n\n6. 做一些安静的活动:阅读一本书、听轻柔的音乐、绣或者绘画等安静的活动,有助于自己放松身心,进而更容易入睡。\n\n如果采取以上这些方法仍然无法入睡,建议咨询医生或专业的睡眠专家,获取更好的建议和帮助。', + '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体', + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 建立规律的睡眠时间表:尽量在同一时间入睡和起床,即使在周末和假期也要尽量保持一致。\n\n2. 创造舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,使用舒适的床垫和枕头等。\n\n3. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽、听轻柔的音乐等,缓解压力和紧张情绪。\n\n4. 避免刺激性物质:避免饮用咖啡、茶、可乐等含咖啡因的饮料,避免吸烟和饮酒等刺激性物质。\n\n5. 避免躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,就不要躺在床上翻来覆去,而是起床去做一些放松的活动,直到感到困倦为止。\n\n6. 练习放松技巧:如果感到焦虑或紧张,可以尝试进行一些放松技巧,如渐进性肌肉松弛、冥想等。\n\n7. 改善睡眠障碍:如果已经尝试了上述方法仍然无法入睡,可以考虑咨询医生,了解是否存在其他睡眠障碍问题,并接受相应的治疗。'] + for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): + set_random_seed(42) + inputs = tokenizer([sentence,], return_tensors="pt", padding=True) + inputs = inputs.to(torch_device) + print(inputs) + outputs = model.generate( + **inputs, + do_sample=do_sample, + max_length=max_length, + num_beams=num_beams + ) + print(outputs) + outputs = outputs.tolist()[0] + out_sentence = tokenizer.decode(outputs, skip_special_tokens=True) + print(out_sentence) + self.assertEqual(expected_output_sentence, out_sentence) + + def test_batch_generation(self): + print("======================test_batch_generation") + model, tokenizer = get_model_and_tokenizer() + sentences = [ + "你好", + "介绍一下清华大学" + ] + parameters = [(False, 2048, 1), + (False, 64, 1), + (True, 2048, 1), + (True, 64, 1), + (True, 2048, 4)] + expected_out_sentences = [ + ['你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回清华园。新中国成立后,清华学校更名为清华大学。\n\n清华大学是中国最顶尖的大学之一,在工程、科学、技术、经济、管理等领域都有很高的学术声誉和影响力。学校拥有世界一流的教学设施和科学研究平台,有多个学院和研究中心,包括工程学院、自然科学学院、人文学院、社会科学学院、经济管理学院、法学院、美术学院、医学院、器学院等。\n\n清华大学的本科生招生始于2000年,实行全面二孩政策后,本科生招生规模不断扩大。截至2022年,清华大学共有本科生近3万人,研究生近2万人,其中国际学生占比约为10%。清华大学的本科生教育注重通识教育和个性化培养,强调实践、创新、国际化和综合素质。'], + [ + '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回' + ], + [ + '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后闭校。1949 年 10 月开学复校,成为我国第一个社会主义大学生活了的高校。截至 2023 年,清华学校共管辖 2 个学院、13 个系,有本科专业 60 个,研究生专业 190 个。' + ], + [ + '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后' + ], + [ + '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', + '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,与北京大学、南开大学组建国立长沙临时大学,1938年迁至 昆明改名为国立西南联合大学,1946年迁回北京。新中国成立后,清华学校更名为清华大学。' + ] + ] + for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): + set_random_seed(42) + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + inputs = inputs.to(torch_device) + print(inputs) + outputs = model.generate( + **inputs, + do_sample=do_sample, + max_length=max_length, + num_beams=num_beams + ) + print(outputs) + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + print(batch_out_sentence) + self.assertListEqual(expected_output_sentence, batch_out_sentence) + +def mytest(use_jit = False): + model, tokenizer = get_model_and_tokenizer() + + # import intel_extension_for_pytorch as ipex + # model = ipex.optimize(model, dtype=torch.bfloat16) + + + sentence = "世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?世界羽毛球史上最伟大的球员都有谁?" + parameters = [(False, 2048, 1), + #(True, 2048, 1), + #(True, 2048, 4) + ] + expected_out_sentences = [ + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。\n\n3. 避免刺激性物质:避免饮用含咖啡因的饮料,如咖啡、茶和可乐,并尽可能减少饮酒。\n\n4. 放松身心:尝试进行放松的活动,如冥想、深呼吸、瑜伽或听轻柔的音乐。\n\n5. 避免在床上做其他事情:例如看电视、使用电脑或智能手机等。\n\n6. 练习放松技巧:例如渐进性肌肉松弛法、冥想或深呼吸练习。\n\n7. 寻求帮助:如果长时间都无法正常入睡,可以考虑咨询医生或专业心理医生,寻求更进一步的帮助。\n\n希望这些方法能有助于入睡。', + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。', + '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体释放褪黑素,进而导致难以入睡。建议你在睡前一小时停止使用这些设备。\n\n3. 创建舒适的睡眠环境:确保卧室安静、黑暗、凉爽,舒适的床垫和枕头,保持卧室温度适宜,这有助于让你更容易入睡。\n\n4. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽或轻松的散步,减轻压力和焦虑,让你更容易入睡。\n\n5. 避免咖啡因和酒精:咖啡因和酒精会让大脑更加兴奋,进而干扰身体入睡过程。建议在睡前几小时避免饮用这些物质。\n\n6. 做一些安静的活动:阅读一本书、听轻柔的音乐、绣或者绘画等安静的活动,有助于自己放松身心,进而更容易入睡。\n\n如果采取以上这些方法仍然无法入睡,建议咨询医生或专业的睡眠专家,获取更好的建议和帮助。', + '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体', + '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 建立规律的睡眠时间表:尽量在同一时间入睡和起床,即使在周末和假期也要尽量保持一致。\n\n2. 创造舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,使用舒适的床垫和枕头等。\n\n3. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽、听轻柔的音乐等,缓解压力和紧张情绪。\n\n4. 避免刺激性物质:避免饮用咖啡、茶、可乐等含咖啡因的饮料,避免吸烟和饮酒等刺激性物质。\n\n5. 避免躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,就不要躺在床上翻来覆去,而是起床去做一些放松的活动,直到感到困倦为止。\n\n6. 练习放松技巧:如果感到焦虑或紧张,可以尝试进行一些放松技巧,如渐进性肌肉松弛、冥想等。\n\n7. 改善睡眠障碍:如果已经尝试了上述方法仍然无法入睡,可以考虑咨询医生,了解是否存在其他睡眠障碍问题,并接受相应的治疗。'] + + jit_model_generated = False + f = open('result.torch.txt', 'w') + for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): + set_random_seed(42) + inputs = tokenizer([sentence,], return_tensors="pt", padding=True) + inputs = inputs.to(torch_device) + #print(inputs) + inputs.data['position_ids'] = inputs.data['position_ids'].to(torch.int32) + attn_mask = torch.zeros_like(inputs.data['attention_mask'], dtype=torch.float32) + inputs.data['attention_mask'] = attn_mask.masked_fill_(inputs.data['attention_mask'], -10000.0) + + if not jit_model_generated and use_jit: + print("generating jit model...") + with torch.no_grad(), torch.cpu.amp.autocast(): + model = torch.jit.trace(model, inputs) + model = torch.jit.freeze(model) + jit_model_generated = True + print("done") + + for repeat in range(1): + t0 = time.time() + + # with profile(activities=[ProfilerActivity.CPU], record_shapes=False) as prof: + # with record_function("model_inference"): + with torch.no_grad(), torch.cpu.amp.autocast(): + outputs = model.generate( + **inputs, + do_sample=do_sample, + max_length=max_length, + num_beams=num_beams + ) + t1 = time.time() + + #prof.export_chrome_trace("trace.json") + #print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + + #print(outputs) + outputs = outputs.tolist()[0] + + if repeat == 0: + out_sentence = tokenizer.decode(outputs, skip_special_tokens=True) + print(f"{out_sentence}") + print(f" #tokens={len(outputs)},do_sample={do_sample},max_length={max_length},num_beams={num_beams}") + f.write(out_sentence) + + print(f" [{repeat}] ::: {(t1-t0)*1e3/len(outputs)} ms/token") + + f.close() + +# numactl -C 0-46 python ./test_modeling_chatglm.py +if __name__ == '__main__': + mytest() + #unittest.main() \ No newline at end of file diff --git a/tests/script/models/chatglm-6b/tokenization_chatglm.py b/tests/script/models/chatglm-6b/tokenization_chatglm.py new file mode 100644 index 0000000..69ee85c --- /dev/null +++ b/tests/script/models/chatglm-6b/tokenization_chatglm.py @@ -0,0 +1,443 @@ +"""Tokenization classes for ChatGLM.""" +from typing import List, Optional, Union +import os + +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging, PaddingStrategy +from transformers.tokenization_utils_base import EncodedInput, BatchEncoding +from typing import Dict +import sentencepiece as spm +import numpy as np + +logger = logging.get_logger(__name__) + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "THUDM/chatglm-6b": 2048, +} + + +class TextTokenizer: + def __init__(self, model_path): + self.sp = spm.SentencePieceProcessor() + self.sp.Load(model_path) + self.num_tokens = self.sp.vocab_size() + + def encode(self, text): + return self.sp.EncodeAsIds(text) + + def decode(self, ids: List[int]): + return self.sp.DecodeIds(ids) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_string(self, tokens): + return self.sp.DecodePieces(tokens) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + def __len__(self): + return self.num_tokens + + +class SPTokenizer: + def __init__( + self, + vocab_file, + num_image_tokens=20000, + max_blank_length=80, + byte_fallback=True, + ): + assert vocab_file is not None + self.vocab_file = vocab_file + self.num_image_tokens = num_image_tokens + self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] + self.max_blank_length = max_blank_length + self.byte_fallback = byte_fallback + self.text_tokenizer = TextTokenizer(vocab_file) + + def _get_text_tokenizer(self): + return self.text_tokenizer + + @staticmethod + def get_blank_token(length: int): + assert length >= 2 + return f"<|blank_{length}|>" + + @staticmethod + def get_tab_token(): + return f"<|tab|>" + + @property + def num_text_tokens(self): + return self.text_tokenizer.num_tokens + + @property + def num_tokens(self): + return self.num_image_tokens + self.num_text_tokens + + @staticmethod + def _encode_whitespaces(text: str, max_len: int = 80): + text = text.replace("\t", SPTokenizer.get_tab_token()) + for i in range(max_len, 1, -1): + text = text.replace(" " * i, SPTokenizer.get_blank_token(i)) + return text + + def _preprocess(self, text: str, linebreak=True, whitespaces=True): + if linebreak: + text = text.replace("\n", "") + if whitespaces: + text = self._encode_whitespaces(text, max_len=self.max_blank_length) + return text + + def encode( + self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True + ) -> List[int]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = "" + text + tmp = self._get_text_tokenizer().encode(text) + tokens = [x + self.num_image_tokens for x in tmp] + return tokens if add_dummy_prefix else tokens[2:] + + def postprocess(self, text): + text = text.replace("", "\n") + text = text.replace(SPTokenizer.get_tab_token(), "\t") + for i in range(2, self.max_blank_length + 1): + text = text.replace(self.get_blank_token(i), " " * i) + return text + + def decode(self, text_ids: List[int]) -> str: + ids = [int(_id) - self.num_image_tokens for _id in text_ids] + ids = [_id for _id in ids if _id >= 0] + text = self._get_text_tokenizer().decode(ids) + text = self.postprocess(text) + return text + + def decode_tokens(self, tokens: List[str]) -> str: + text = self._get_text_tokenizer().convert_tokens_to_string(tokens) + text = self.postprocess(text) + return text + + def tokenize( + self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True + ) -> List[str]: + """ + @param text: Text to encode. + @param linebreak: Whether to encode newline (\n) in text. + @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. + @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. + @param add_dummy_prefix: Whether to add dummy blank space in the beginning. + """ + text = self._preprocess(text, linebreak, whitespaces) + if not add_dummy_prefix: + text = "" + text + tokens = self._get_text_tokenizer().tokenize(text) + return tokens if add_dummy_prefix else tokens[2:] + + def __getitem__(self, x: Union[int, str]): + if isinstance(x, int): + if x < self.num_image_tokens: + return "".format(x) + else: + return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens) + elif isinstance(x, str): + if x.startswith("") and x[7:-1].isdigit(): + return int(x[7:-1]) + else: + return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens + else: + raise ValueError("The key should be str or int.") + + +class ChatGLMTokenizer(PreTrainedTokenizer): + """ + Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ + + vocab_files_names = {"vocab_file": "ice_text.model"} + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask", "position_ids"] + + def __init__( + self, + vocab_file, + do_lower_case=False, + remove_space=False, + bos_token='', + eos_token='', + end_token='', + mask_token='[MASK]', + gmask_token='[gMASK]', + padding_side="left", + pad_token="", + unk_token="", + num_image_tokens=20000, + **kwargs + ) -> None: + super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + padding_side=padding_side, + bos_token=bos_token, + eos_token=eos_token, + end_token=end_token, + mask_token=mask_token, + gmask_token=gmask_token, + pad_token=pad_token, + unk_token=unk_token, + num_image_tokens=num_image_tokens, + **kwargs + ) + + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.vocab_file = vocab_file + + self.bos_token = bos_token + self.eos_token = eos_token + self.end_token = end_token + self.mask_token = mask_token + self.gmask_token = gmask_token + + self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens) + + """ Initialisation """ + + @property + def gmask_token_id(self) -> Optional[int]: + if self.gmask_token is None: + return None + return self.convert_tokens_to_ids(self.gmask_token) + + @property + def end_token_id(self) -> Optional[int]: + """ + `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been + set. + """ + if self.end_token is None: + return None + return self.convert_tokens_to_ids(self.end_token) + + @property + def vocab_size(self): + """ Returns vocab size """ + return self.sp_tokenizer.num_tokens + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def preprocess_text(self, inputs): + if self.remove_space: + outputs = " ".join(inputs.strip().split()) + else: + outputs = inputs + + if self.do_lower_case: + outputs = outputs.lower() + + return outputs + + def _tokenize(self, text, **kwargs): + """ Returns a tokenized string. """ + text = self.preprocess_text(text) + + seq = self.sp_tokenizer.tokenize(text) + + return seq + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.sp_tokenizer.decode_tokens(tokens) + + def _decode( + self, + token_ids: Union[int, List[int]], + **kwargs + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + if len(token_ids) == 0: + return "" + if self.pad_token_id in token_ids: # remove pad + token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) + return super()._decode(token_ids, **kwargs) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.sp_tokenizer[token] + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_tokenizer[index] + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, self.vocab_files_names["vocab_file"] + ) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, "wb") as writer: + writer.write(proto_str) + + return (vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + gmask_id = self.sp_tokenizer[self.gmask_token] + eos_id = self.sp_tokenizer[self.eos_token] + token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [eos_id] + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + bos_token_id = self.sp_tokenizer[self.bos_token] + mask_token_id = self.sp_tokenizer[self.mask_token] + gmask_token_id = self.sp_tokenizer[self.gmask_token] + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if max_length is not None: + if "attention_mask" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + attention_mask = np.ones((1, seq_length, seq_length)) + attention_mask = np.tril(attention_mask) + attention_mask[:, :, :context_length] = 1 + attention_mask = np.bool_(attention_mask < 0.5) + encoded_inputs["attention_mask"] = attention_mask + + if "position_ids" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + position_ids = np.arange(seq_length, dtype=np.int64) + mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id + if mask_token in required_input: + mask_position = required_input.index(mask_token) + position_ids[context_length:] = mask_position + block_position_ids = np.concatenate( + [np.zeros(context_length, dtype=np.int64), + np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) + encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode='constant', constant_values=True) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], + pad_width=[(0, 0), (difference, 0)]) + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/tests/script/models/chatglm-6b/tokenizer_config.json b/tests/script/models/chatglm-6b/tokenizer_config.json new file mode 100644 index 0000000..f8221e0 --- /dev/null +++ b/tests/script/models/chatglm-6b/tokenizer_config.json @@ -0,0 +1,20 @@ +{ + "name_or_path": "THUDM/chatglm-6b", + "bos_token": "", + "eos_token": "", + "end_token": "", + "gmask_token": "[gMASK]", + "mask_token": "[MASK]", + "pad_token": "", + "unk_token": "", + "remove_space": false, + "do_lower_case": false, + "tokenizer_class": "ChatGLMTokenizer", + "num_image_tokens": 0, + "auto_map": { + "AutoTokenizer": [ + "tokenization_chatglm.ChatGLMTokenizer", + null + ] + } +} diff --git a/tests/script/pytest.ini b/tests/script/pytest.ini new file mode 100644 index 0000000..e68d012 --- /dev/null +++ b/tests/script/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --ignore=models \ No newline at end of file diff --git a/tests/script/requirements.txt b/tests/script/requirements.txt new file mode 100644 index 0000000..b526b7c --- /dev/null +++ b/tests/script/requirements.txt @@ -0,0 +1,5 @@ +-f https://download.pytorch.org/whl/torch_stable.html +numpy==1.24.2 +torch==2.0.1+cpu +pytest +ninja \ No newline at end of file diff --git a/tests/script/test_attn_chatglm.py b/tests/script/test_attn_chatglm.py new file mode 100644 index 0000000..2caf02a --- /dev/null +++ b/tests/script/test_attn_chatglm.py @@ -0,0 +1,452 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +# copy from chatglm-6b/modeling_chatglm.py +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + #inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] + + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + +class SelfAttention(torch.nn.Module): + def __init__(self, max_position_embeddings, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.head_size = self.hidden_size // self.num_attention_heads + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + max_position_embeddings, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + layer_past=None, + attention_mask=None + ): + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + #query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + #key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=2) + + # [b, np, sk, hn] -> [b * np, sk, hn] + #value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + # return query_layer, key_layer, value_layer + past = (key_layer, value_layer) + + # from test_mha_chatglm.py/forward + query_layer = query_layer / self.norm_factor + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) + + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.reshape(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.reshape(output_size[0] * output_size[1], output_size[3], -1) + + matmul_result = torch.zeros( + 1, 1, 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # if not (attention_mask == 0).all(): + # # if auto-regressive, skip + # attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores = attention_scores + attention_mask + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3)) + + # [b, np, sk, hn] -> [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = (output_size[0], output_size[2], output_size[1] * output_size[3]) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, past + + + def forward( + self, + qkv: torch.Tensor, # [batch, seq_len, 3 * hidden_size] + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]], + attention_mask: torch.Tensor, # [batch, 1, query_seq_len, key_seq_len] + position_ids # [batch, 2, query_seq_len] + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = qkv # self.query_key_value(hidden_states) + + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + query_layer = query_layer.to(dtype=torch.bfloat16) + key_layer = key_layer.to(dtype=torch.bfloat16) + + # [batch, seq_len, hidden_size] + attn, past = self.attention_fn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + layer_past=layer_past, + attention_mask=attention_mask + ) + + return attn, past + + +class GPTAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): + self.attn = ld.attn_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + normal_factor = 1.0 / math.sqrt(head_size) + + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + rotary_ndims = int(head_size * rotary_pct) + self.attn.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + dst_precision_name, max_seq_len, rotary_ndims, True) + + inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + #inv_freq = inv_freq.half() + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # attn_mask: [batch, 1, query_seq_len, key_seq_len] + # position_ids: [batch, 2, query_seq_len] + # return: + # 0: qkv [batch, seq_len, (num_heads * 3 * head_size)] + # 1: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 2: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids): + return self.attn.exec_position(qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, position_ids, self.cos_cached, self.sin_cached) + + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 96 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.5 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + ref_net = SelfAttention(MAX_POSITION_EMBEDDINGS, HIDDEN_SIZE, HEAD_NUM, 0, None) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_attn(): + inputs = [ + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + # attn: [batch, 1, query_seq_len, key_seq_len] + (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 200, 200], dtype=np.float32)), + (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 201], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = GPTAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value, attn_mask = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + past_seq_len = layer_past_key.shape[-2] + shape = list(layer_past_key.shape) + + if qkv.size(1) > 1: + seq_batch1 = torch.arange(end=qkv.size(1) - 1, dtype=torch.int32) + seq_batch1 = torch.concat((seq_batch1, seq_batch1[-1:])) + block_batch1 = torch.concat((torch.zeros(qkv.size(1) -1, dtype=torch.int32), torch.ones(1, dtype=torch.int32))) + else: + seq_batch1 = torch.tensor([3], dtype=torch.int32) + block_batch1 = torch.tensor([5], dtype=torch.int32) + seq_ids = torch.empty((qkv.size(0), 2, qkv.size(1)), dtype=torch.int32) + seq_ids[:, 0, :] = seq_batch1 + seq_ids[:, 1, :] = block_batch1 + output_ref, (key_ref, value_ref) = ref_net.forward(qkv, (layer_past_key, layer_past_value), attn_mask, seq_ids) + + shape[-2] = MAX_SEQ_LEN + shape[-1] = SIZE_PER_HEAD_ALIGN + layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) + layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) + layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key + layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value + key_ref = key_ref.to(dtype=torch.bfloat16) + output = net.forward(qkv, layer_past_key_padded, layer_past_value_padded, past_seq_len, attn_mask, seq_ids) + key, value = layer_past_key_padded, layer_past_value_padded + # check output + if not torch.allclose(output_ref, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{output_ref} \ncur:\n {output} ") + assert(False) + # check key + if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_attn() \ No newline at end of file diff --git a/tests/script/test_mha_chatglm.py b/tests/script/test_mha_chatglm.py new file mode 100644 index 0000000..59179e6 --- /dev/null +++ b/tests/script/test_mha_chatglm.py @@ -0,0 +1,259 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F + +# copy from chatglm-6b/modeling_chatglm.py +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + + def forward(self, + query_layer, # [b, np, sq, hn] + key_layer, # [b, np, sk, hn] + value_layer, # [b, np, sk, hn] + attention_mask # [b, np, s, s] + ): + return self.attention_fn(query_layer, key_layer, value_layer, attention_mask) + + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask + ): + query_layer = query_layer / self.norm_factor + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) + + # [b, np, sq, hn] -> [b * np, sq, hn] + query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + matmul_result = torch.zeros( + 1, 1, 1, + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_result, + query_layer, # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # if not (attention_mask == 0).all(): + # # if auto-regressive, skip + # attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores = attention_scores + attention_mask + dtype = attention_scores.dtype + attention_scores = attention_scores.float() + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3)) + + # [b, np, sk, hn] -> [b * np, sk, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [b, sq, hp] + new_context_layer_shape = (output_size[0], output_size[2], output_size[1] * output_size[3]) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, is_int8=False): + self.mha = ld.mha_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + head_size_aligned = head_size_aligned + normal_factor = 1.0 / math.sqrt(head_size) + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + dst_precision_name, max_seq_len) + + def forward(self, query, key, value, attention_mask, head_size, key_seq_len): + return self.mha.exec(query, key, value, attention_mask, head_size, key_seq_len) + + def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): + # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant + return self.mha.exec_quant(query, key, value, attention_mask, 1.0 / q_quant, 1.0 / k_quant, 1.0 / v_quant, qk_quant, requant) + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 96 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +def get_ref_model(): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_chatglm_neox(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [batch, 1, query_seq_len, key_seq_len] + (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 2, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 200, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + ref_output = ref_net.forward(q, k, v, attn_mask) + + shape = list(k.shape) + shape[-2] = MAX_POSITION_EMBEDDINGS + shape[-1] = SIZE_PER_HEAD_ALIGN + key_padded = torch.zeros(shape, dtype=torch.bfloat16) + value_padded = torch.zeros(shape, dtype=torch.bfloat16) + query_shape = list(q.shape) + query_shape[-1] = SIZE_PER_HEAD_ALIGN + query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) + key_padded[:,:,:k.shape[-2],:k.shape[-1]] = k + value_padded[:,:,:k.shape[-2],:k.shape[-1]] = v + query_padded[:,:,:,:q.shape[-1]] = q + + output = net.forward(query_padded, key_padded, value_padded, attn_mask, k.size(3), k.size(2)) + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +def todotest_gpt_neox_int8(): + low = -4 + high = 4 + range_ = high - low + q_quant = 127.0 / high + qs = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + low = -2 + high = 2 + range_ = high - low + k_quant = 127.0 / high + ks = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + low = -8 + high = 8 + range_ = high - low + v_quant = 127.0 / high + vs = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [1, MAX_POSITION_EMBEDDINGS] + inputs = [] + for i in range(len(qs)): + inputs.append((qs[i], ks[i], vs[i], np.zeros([1, ks[i].shape[-2]], dtype=np.float32))) + + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, True) + qk_quant, requant = 255.0, 10.0 + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q) + k = torch.from_numpy(k) + v = torch.from_numpy(v) + attn_mask = torch.from_numpy(attn_mask) + q = (q * q_quant).round().clamp(-128, 127).to(torch.int8) + k = (k * k_quant).round().clamp(-128, 127).to(torch.int8) + v = (v * v_quant).round().clamp(-128, 127).to(torch.int8) + ref_output = ref_net.forward(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, requant) + output = net.forward_quant(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, [requant,]) + if (torch.abs(ref_output- output) > 2).any(): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_chatglm_neox() \ No newline at end of file diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py new file mode 100644 index 0000000..ac99f99 --- /dev/null +++ b/tests/script/test_mha_gpt.py @@ -0,0 +1,243 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn + +# copy from transformers/models/gpt_neox/modeling_gpt_neox.py +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + ) + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + + def forward(self, query, key, value, attention_mask, q_quant=None, k_quant=None, qk_quant=None, v_quant=None, requant=None): + if q_quant: + # quant + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant) + if q_quant: + attn_output = (attn_output * requant).round().clamp(-128, 127).to(torch.int8) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + + return attn_output + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + if q_quant: + norm_factor = torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor * (1 / q_quant) * (1 / k_quant) + else: + norm_factor = torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=(norm_factor), + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + mask_value = torch.finfo(attn_scores.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) + attn_scores = torch.where(causal_mask, attn_scores, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + if q_quant: + attn_weights = (attn_weights * qk_quant).round().clamp(0, 255).to(torch.uint8).to(torch.float32) + attn_output = torch.matmul(attn_weights, value) + if q_quant: + attn_output = attn_output * ((1 / qk_quant) * (1 / v_quant)) + return attn_output, attn_weights + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, is_int8=False): + self.mha = ld.mha_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + head_size_aligned = head_size + normal_factor = 1.0 / math.sqrt(head_size) + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + self.mha.create(num_heads, head_size, head_size_aligned, normal_factor, qkv_precision_name, + dst_precision_name, max_seq_len) + + def forward(self, query, key, value, attention_mask): + return self.mha.exec(query, key, value, attention_mask) + + def forward_quant(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant, requant): + # q_dequant, k_dequant, v_dequant, qk_quant, std::vector& qkv_quant + return self.mha.exec_quant(query, key, value, attention_mask, 1.0 / q_quant, 1.0 / k_quant, 1.0 / v_quant, qk_quant, requant) + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +def get_ref_model(): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_gpt_neox(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [2, 1, 1, key_seq_len] + (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + ref_output = ref_net.forward(q, k, v, attn_mask) + output = net.forward(q, k, v, attn_mask) + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +def test_gpt_neox_int8(): + low = -4 + high = 4 + range_ = high - low + q_quant = 127.0 / high + qs = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM , 1, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + low = -2 + high = 2 + range_ = high - low + k_quant = 127.0 / high + ks = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + low = -8 + high = 8 + range_ = high - low + v_quant = 127.0 / high + vs = [ + np.random.random(size=[2, HEAD_NUM, 900, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 901, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + np.random.random(size=[2, HEAD_NUM, 902, SIZE_PER_HEAD]).astype(np.float32)*range_+low, + ] + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [batch, 1, 1, key_seq_len] + inputs = [] + for i in range(len(qs)): + inputs.append((qs[i], ks[i], vs[i], np.zeros([ks[i].shape[0], 1, 1, ks[i].shape[-2]], dtype=np.float32))) + + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, True) + qk_quant, requant = 255.0, 10.0 + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q) + k = torch.from_numpy(k) + v = torch.from_numpy(v) + attn_mask = torch.from_numpy(attn_mask) + q = (q * q_quant).round().clamp(-128, 127).to(torch.int8) + k = (k * k_quant).round().clamp(-128, 127).to(torch.int8) + v = (v * v_quant).round().clamp(-128, 127).to(torch.int8) + ref_output = ref_net.forward(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, requant) + output = net.forward_quant(q, k, v, attn_mask, q_quant, k_quant, qk_quant, v_quant, [requant,]) + if (torch.abs(ref_output- output) > 2).any(): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_gpt_neox() + test_gpt_neox_int8() \ No newline at end of file diff --git a/tests/script/test_rotary_pastkv.py b/tests/script/test_rotary_pastkv.py new file mode 100644 index 0000000..402720f --- /dev/null +++ b/tests/script/test_rotary_pastkv.py @@ -0,0 +1,225 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn + +# copy from transformers/models/gpt_neox/modeling_gpt_neox.py +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos = cos[..., offset : q.shape[-2] + offset, :] + sin = sin[..., offset : q.shape[-2] + offset, :] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.rotary_ndims = int(self.head_size * config.rotary_pct) + self.rotary_emb = RotaryEmbedding( + self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base + ) + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + # return: (key, value), key/value: [batch, num_attention_heads, seq_len+layer_past[0].shape[-2], head_size] + def forward(self, qkv, layer_past): + has_layer_past = layer_past is not None + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + offset = 0 + if has_layer_past: + offset = layer_past[0].shape[-2] + seq_len += offset + cos, sin = self.rotary_emb(value, seq_len=seq_len) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset) + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) + + return present, query, key, value + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): + self.emd = ld.emb_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + rotary_ndims = int(head_size * rotary_pct) + self.emd.create(num_heads, head_size, head_size_aligned, qkv_precision_name, + dst_precision_name, rotary_ndims) + + inv_freq = 1.0 / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # return: + # 0: (k, v): ([batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned], [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned]) + # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] + # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len): + self.emd.exec(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, self.cos_cached, self.sin_cached) + + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 96 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.25 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.rotary_pct = ROTARY_PCT + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + self.rotary_emb_base = ROTARY_EMB_BASE + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_gpt_neox(): + inputs = [ + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32)), + (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + past_seq_len = layer_past_key.shape[-2] + shape = list(layer_past_key.shape) + shape[-2] = MAX_SEQ_LEN + shape[-1] = SIZE_PER_HEAD_ALIGN + layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) + layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) + query_shape = list(shape) + query_shape[2] = qkv.shape[1] + query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key + layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value + ref_output, query_ref, key_ref, value_ref = ref_net.forward(qkv, (layer_past_key, layer_past_value)) + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + net.forward(qkv, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len) + output, query, key, value = (layer_past_key_padded, layer_past_value_padded), query_padded, layer_past_key_padded, layer_past_value_padded + # check output + if not torch.allclose(ref_output[0].to(dtype=torch.bfloat16), output[0][:,:,:ref_output[0].shape[-2],:ref_output[0].shape[-1]], rtol=0.001, atol=0.01): + print(f"error at past key index {i} ref:\n{ref_output[0]} \ncur:\n {output[0]} ") + assert(False) + if not torch.allclose(ref_output[1], output[1][:,:,:ref_output[1].shape[-2],:ref_output[1].shape[-1]], rtol=0.001, atol=0.01): + print(f"error at past value index {i} ref:\n{ref_output[1]} \ncur:\n {output[1]} ") + assert(False) + # check query + if not torch.allclose(query_ref, query[:,:,:,:query_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at query index {i} ref:\n{query_ref} \ncur:\n {query} ") + assert(False) + # check key + if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_gpt_neox() \ No newline at end of file diff --git a/tests/script/test_rotary_pastkv_chatglm.py b/tests/script/test_rotary_pastkv_chatglm.py new file mode 100644 index 0000000..612c7d4 --- /dev/null +++ b/tests/script/test_rotary_pastkv_chatglm.py @@ -0,0 +1,389 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +# copy from chatglm-6b/modeling_chatglm.py +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + #inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] + + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + +class SelfAttention(torch.nn.Module): + def __init__(self, max_position_embeddings, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + max_position_embeddings, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + layer_past=None + ): + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + #query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + #key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=2) + + # [b, np, sk, hn] -> [b * np, sk, hn] + #value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + return query_layer, key_layer, value_layer + + def forward( + self, + qkv: torch.Tensor, # [batch, seq_len, 3 * hidden_size] + position_ids, # [batch, 2, query_seq_len] + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = qkv # self.query_key_value(hidden_states) + + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + query_layer = query_layer.to(dtype=torch.bfloat16) + key_layer = key_layer.to(dtype=torch.bfloat16) + + # [batch, seq_len, hidden_size] + query_layer, key_layer, value_layer = self.attention_fn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + layer_past=layer_past, + ) + + return query_layer, key_layer, value_layer + + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): + self.emd = ld.emb_gpt() + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + qkv_precision_name = 's8' if is_int8 else 'bf16' + dst_precision_name = 's8' if is_int8 else 'bf16' + rotary_ndims = int(head_size * rotary_pct) + self.emd.create(num_heads, head_size, head_size_aligned, qkv_precision_name, + dst_precision_name, rotary_ndims, True) + + inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + #inv_freq = inv_freq.half() + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # return: + # 0: (k, v): ([batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned], [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned]) + # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] + # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids): + self.emd.exec_position(qkv, layer_past_key_src, layer_past_value_src, layer_past_key_dst, layer_past_value_dst, query_padded, past_seq_len, position_ids, self.cos_cached, self.sin_cached) + + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 96 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.5 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + ref_net = SelfAttention(MAX_POSITION_EMBEDDINGS, HIDDEN_SIZE, HEAD_NUM, 0, None) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_chatglm(): + inputs = [ + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32)), + (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + net_seq = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + past_seq_len = layer_past_key.shape[-2] + shape = list(layer_past_key.shape) + + if qkv.size(1) > 1: + seq_batch1 = torch.arange(end=qkv.size(1) - 1, dtype=torch.int32) + seq_batch1 = torch.concat((seq_batch1, seq_batch1[-1:])) + block_batch1 = torch.concat((torch.zeros(qkv.size(1) -1, dtype=torch.int32), torch.ones(1, dtype=torch.int32))) + else: + seq_batch1 = torch.tensor([3], dtype=torch.int32) + block_batch1 = torch.tensor([5], dtype=torch.int32) + seq_ids = torch.empty((qkv.size(0), 2, qkv.size(1)), dtype=torch.int32) + seq_ids[:, 0, :] = seq_batch1 + seq_ids[:, 1, :] = block_batch1 + query_ref, key_ref, value_ref = ref_net.forward(qkv, seq_ids, (layer_past_key, layer_past_value)) + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + + # no prealloc past kv + layer_past_key_seq = torch.zeros(key_ref.shape, dtype=torch.bfloat16) + layer_past_value_seq = torch.zeros(value_ref.shape, dtype=torch.bfloat16) + query_seq = torch.zeros(query_ref.shape, dtype=torch.bfloat16) + net_seq.forward(qkv, layer_past_key, layer_past_value, layer_past_key_seq, layer_past_value_seq, query_seq, past_seq_len, seq_ids) + query, key, value = query_seq, layer_past_key_seq, layer_past_value_seq + # check query + if not torch.allclose(query_ref, query, rtol=0.001, atol=0.01): + print(f"error at sequence query index {i} ref:\n{query_ref} \ncur:\n {query} ") + #assert(False) + # check key + if not torch.allclose(key_ref, key, rtol=0.001, atol=0.01): + print(f"error at sequence key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value, rtol=0.001, atol=0.01): + print(f"error at sequence value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + # prealloc past kv + shape[-2] = MAX_SEQ_LEN + shape[-1] = SIZE_PER_HEAD_ALIGN + layer_past_key_padded = torch.zeros(shape, dtype=torch.bfloat16) + layer_past_value_padded = torch.zeros(shape, dtype=torch.bfloat16) + query_shape = list(shape) + query_shape[2] = qkv.shape[1] + query_padded = torch.zeros(query_shape, dtype=torch.bfloat16) + layer_past_key_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_key + layer_past_value_padded[:,:,:layer_past_key.shape[-2],:layer_past_key.shape[-1]] = layer_past_value + net.forward(qkv, layer_past_key_padded, layer_past_value_padded, layer_past_key_padded, layer_past_value_padded, query_padded, past_seq_len, seq_ids) + query, key, value = query_padded, layer_past_key_padded, layer_past_value_padded + # check query + if not torch.allclose(query_ref, query[:,:,:,:query_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at query index {i} ref:\n{query_ref} \ncur:\n {query} ") + assert(False) + # check key + if not torch.allclose(key_ref, key[:,:,:key_ref.shape[-2],:key_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value[:,:,:value_ref.shape[-2],:value_ref.shape[-1]], rtol=0.001, atol=0.01): + print(f"error at value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_chatglm() \ No newline at end of file diff --git a/tests/src/test_common.cpp b/tests/src/test_common.cpp new file mode 100644 index 0000000..a96a268 --- /dev/null +++ b/tests/src/test_common.cpp @@ -0,0 +1,80 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/simple_parallel.hpp" +#include "test_common.hpp" + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE /* See feature_test_macros(7) */ +#endif +#include +#include /* For SYS_xxx definitions */ + +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 + +using namespace std; +using namespace llmdnn; + +std::string dtype_to_str(data_type_t type) { + switch (type) { + case dnnl_data_type_undef: return "undef"; + case dnnl_f16: return "f16"; + case dnnl_bf16: return "bf16"; + case dnnl_f32: return "f32"; + case dnnl_s32: return "s32"; + case dnnl_s8: return "s8"; + case dnnl_u8: return "u8"; + case dnnl_f64: return "f64"; + default: return "unkown"; + } +} + +bool initXTILE() { + unsigned long bitmask = 0; + long status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (0 != status) return false; + if (bitmask & XFEATURE_MASK_XTILEDATA) return true; + + status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + if (0 != status) + return false; // XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed + status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + + // XFEATURE_XTILEDATA setup is failed, can't use TMUL + if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) return false; + + // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed + return true; +} + +namespace ov { +namespace cpu { + +size_t getTotalThreads() { + return omp_get_max_threads(); +} + +void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function& fn) { + #pragma omp parallel for + for(std::ptrdiff_t i = 0; i < total; i++) { + fn(i); + } +} + +} +} \ No newline at end of file diff --git a/tests/src/test_common.hpp b/tests/src/test_common.hpp new file mode 100644 index 0000000..f9acf1e --- /dev/null +++ b/tests/src/test_common.hpp @@ -0,0 +1,220 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" +#include "llm_fc.hpp" +#include "common/tensor2d.hpp" +#include "common/bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#endif +#include + +#define rndup(x, n) (((x + n - 1)/n)*n) + +std::string dtype_to_str(llmdnn::data_type_t type); + +using func_act = std::function; + +bool initXTILE(); + +template +void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + int k; + for (k = 0; (k + 32) <= K; k += 32) { + float psum0 = 0; + float psum1 = 0; + for(int p = 0; p < 32; p+=2) { + psum0 += static_cast(A(m,k+p)) * static_cast(B(k+p,n)); + psum1 += static_cast(A(m,k+p+1)) * static_cast(B(k+p+1,n)); + } + sum += (psum0 + psum1); + } + for(; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + //std::cout << m << "," << n << std::endl; + C(m,n) = sum; + } + } +} + +inline void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + for(int k = 0; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + C(m,n) = sum; + } + } +} + +template +void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + for(int k = 0; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (dq) { + sum *= dq[n]; + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + if (q) { + sum *= q[n]; + sum = std::min(static_cast(std::numeric_limits::max()), sum); + sum = std::max(static_cast(std::numeric_limits::min()), sum); + } + C(m,n) = sum; + } + } +} + +struct ANSIcolor { + const char * code; + ANSIcolor(const char * code = "0") : code(code){ + } + friend std::ostream& operator<<(std::ostream& out, const ANSIcolor& obj) { + out << "\033[" << obj.code << "m"; + return out; + } +}; + +struct pretty_size { + double sz; + std::string txt; + pretty_size(double sz, const char * unit = "") : sz(sz) { + std::stringstream ss; + ss << std::setprecision(5) << std::setw(5); + if (sz < 1024) + ss << sz; + else if (sz < 1024 * 1024) + ss << (sz / 1024) << " K"; + else if (sz < 1024 * 1024 * 1024) + ss << (sz / 1024/1024) << " M"; + else + ss << (sz / 1024 / 1024/1024) << " G"; + ss << unit; + txt = ss.str(); + } + friend std::ostream& operator<<(std::ostream& os, const pretty_size& ps) { + os << ps.txt; + return os; + } +}; + +inline int readenv(const char * name) { + int v = 0; + auto * p = std::getenv(name); + if (p) + v = std::atoi(p); + std::cout << ANSIcolor("32") << "ENV: " << name << " = " << v << std::endl << ANSIcolor(); + return v; +} + +template +struct TypeName { + static const char* get(){return typeid(T).name();} +}; + +// a specialization for each type of those you want to support +// and don't like the string returned by typeid +template <> +struct TypeName{ + static const char* get(){return "int32_t";} +}; +template <> +struct TypeName{ + static const char* get(){return "foat";} +}; +template <> +struct TypeName{ + static const char* get(){return "bfloat16";} +}; +template <> +struct TypeName{ + static const char* get(){return "int8_t";} +}; + +inline std::ostream & operator<<(std::ostream & os, const llmdnn::postops_types & steps) { + if (steps == llmdnn::NONE) + os << "NONE"; + if (steps & llmdnn::DEQUANT) + os << "_DEQUANT"; + if (steps & llmdnn::BIAS) + os << "_BIAS"; + if (steps & llmdnn::GELU) + os << "_GELU"; + if (steps & llmdnn::GELU_TANH) + os << "_GELU_TANH"; + if (steps & llmdnn::QUANT) + os << "_QUANT"; + return os; +} diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp new file mode 100644 index 0000000..0046541 --- /dev/null +++ b/tests/src/test_fc_kernel_amx.cpp @@ -0,0 +1,232 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_fc.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using FCKernelTestShape = std::tuple; +using FCKernelTestDTPost = std::tuple; +using FCKernelTestParamSet = std::tuple< + FCKernelTestDTPost, // a, b, c data type, postops + bool, // b needs transpose + FCKernelTestShape // M, N, K + >; + +class FCKernelTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + FCKernelTestDTPost types; + bool is_transpose; + postops_types postops_type; + data_type_t dt_a, dt_b, dt_c; + FCKernelTestShape shape; + int M, N, K; + std::tie(types, is_transpose, shape) = obj.param; + std::tie(M, N, K) = shape; + std::tie(dt_a, dt_b, dt_c, postops_type) = types; + + std::ostringstream result; + result << "A_" << dtype_to_str(dt_a) << "_B_" << dtype_to_str(dt_b) + << "_C_" << dtype_to_str(dt_c) << (is_transpose ? "_transpose" : "") + << "_postops_" << postops_type << "_M_" << M << "_N_" << N << "_K_" << K; + return result.str(); + } + +protected: + virtual void SetUp() override { + initXTILE(); + + FCKernelTestShape shape; + FCKernelTestDTPost types; + std::tie(types, _is_transpose, shape) = GetParam(); + std::tie(_M, _N, _K) = shape; + std::tie(_dt_a, _dt_b, _dt_c, _postops_type) = types; + }; + + template + void do_test() { + fc_kernel* fc; + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + ASSERT_TRUE(fc_kernel_create(&fc, ¶m)); + auto gemm = std::shared_ptr(fc, [](fc_kernel* p) { fc_kernel_destroy(p); }); + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D q(1, _N); + tensor2D bias(1, _N); + + fill_rnd(A); + fill_rnd(B); + dq = 2; + q = 2; + fill_rnd(bias); + bias = 1; + + tensor2D BT = B.Tr(); + TB* ptr_B; + size_t ldb; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + } else { + ptr_B = B.data; + ldb = B.stride; + } + fc_kernel_execute(gemm.get(), A.data, ptr_B, C.data, A.stride, ldb, + C.stride, _M, _N, _K, 0, _N, dq.data, q.data, bias.data); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + if (_postops_type & DEQUANT) { + ptr_dq = dq.data; + } + if (_postops_type & QUANT) { + ptr_q = q.data; + } + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } + + matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + + int _M, _N, _K; + bool _is_transpose; + postops_types _postops_type; + data_type_t _dt_a, _dt_b, _dt_c; +}; + +TEST_P(FCKernelTest, Func) { + if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_s8) { + do_test(); + } else if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_bf16) { + do_test(); + } else if (_dt_a == dnnl_s8 && _dt_b == dnnl_s8 && _dt_c == dnnl_f32) { + do_test(); + } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_bf16 && _dt_c == dnnl_bf16) { + do_test(); + } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_bf16 && _dt_c == dnnl_f32) { + do_test(); + } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_s8 && _dt_c == dnnl_f32) { + do_test(); + } else if (_dt_a == dnnl_bf16 && _dt_b == dnnl_s8 && _dt_c == dnnl_bf16) { + do_test(); + } else { + ASSERT_TRUE(false); + } +} + +// supported: +// (s8,s8,s8),dq,[bias],[gelu],q +// (s8,s8,bf16),dq,[bias],[gelu] +// (s8,s8,f32),dq,[bias],[gelu] +// (bf16,bf16,bf16),[bias],[gelu] +// (bf16,bf16,f32),[bias],[gelu] +// (bf16,s8,f32),dq,[bias],[gelu] +// (bf16,s8,bf16),dq,[bias],[gelu] +const std::vector types = { + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_GELU_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_GELU_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_GELU_TANH_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_GELU }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_GELU_TANH }, + { dnnl_s8, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU_TANH }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_GELU }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_GELU_TANH }, + { dnnl_s8, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU_TANH }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, NONE }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS_GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, GELU_TANH }, + { dnnl_bf16, dnnl_bf16, dnnl_bf16, BIAS_GELU_TANH }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, NONE }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS_GELU }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, GELU_TANH }, + { dnnl_bf16, dnnl_bf16, dnnl_f32, BIAS_GELU_TANH }, + // TODO: support weight compression + // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT }, + // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_BIAS }, + // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_GELU }, + // { dnnl_bf16, dnnl_s8, dnnl_f32, DEQUANT_BIAS_GELU }, + // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT }, + // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_BIAS }, + // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_GELU }, + // { dnnl_bf16, dnnl_s8, dnnl_bf16, DEQUANT_BIAS_GELU }, +}; + +// M, N, K +const std::vector shapes = { + // normal + {256, 48, 448}, + // k tail + {256, 48, 449}, + // M tail == unroll 8 + {256 + 8, 48, 449}, + // M tail == unroll 8 + 2 + {256 + 10, 48, 449}, + // N tail + {256, 40, 448}, + // all tail + {256 + 9, 47, 449}, + // gemv, K <= 64(32)*6 + {256, 1, 80}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_FCKernel, FCKernelTest, + ::testing::Combine(ValuesIn(types), + Values(true, false), + ValuesIn(shapes)), + FCKernelTest::getTestCaseName); diff --git a/tests/src/test_gelu_kernel_avx512.cpp b/tests/src/test_gelu_kernel_avx512.cpp new file mode 100644 index 0000000..70a9a27 --- /dev/null +++ b/tests/src/test_gelu_kernel_avx512.cpp @@ -0,0 +1,97 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "gelu_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using GeluTestParamSet = std::tuple< + std::string // data type + >; + +class GeluTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + std::string types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << types; + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + void gelu_ref(float* s, float* d, int n) { + if (_types == "tanh") { + for (int i = 0; i < n; i++) { + auto x = s[i]; + d[i] = 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + } + } else { + for (int i = 0; i < n; i++) { + auto x = s[i]; + d[i] = 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); + } + } + } + + void test(float thresh) { + tensor2D src(1, 16, true); + tensor2D dst(1, 16, true); + tensor2D ref(1, 16, true); + for (int i = 0; i < 16; i++) { + src[i] = std::sqrt(i) - 2; + } + __m512 s = _mm512_loadu_ps(src.data); + __m512 d; + if (_types == "tanh") + d = gelu_tanh_avx512(s); + else + d = gelu_erf_minmax_approx_avx512(s); + _mm512_storeu_ps(dst.data, d); + gelu_ref(src.data, ref.data, 16); + for (int i = 0; i < src.dims[1]; i++) { + float r = ref[i]; + float c = dst[i]; + if (std::abs(r - c) > thresh) { + FAIL() << " cur is not equal, pos: " << i << " opt: " << c << " ref: " << r; + } + } + } + + std::string _types; +}; + +TEST_P(GeluTest, Gelu) { + test(0.01f); +} + +const std::vector types = { + "tanh", "erf" +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Gelu, GeluTest, + ::testing::Combine(ValuesIn(types)), + GeluTest::getTestCaseName); diff --git a/tests/src/test_mm_kernel_amx.cpp b/tests/src/test_mm_kernel_amx.cpp new file mode 100644 index 0000000..ce0fcb2 --- /dev/null +++ b/tests/src/test_mm_kernel_amx.cpp @@ -0,0 +1,139 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using MMKernelTestShape = std::tuple; +using MMKernelTestParamSet = std::tuple< + std::pair, // a, b data type + bool, // b needs transpose + MMKernelTestShape // M, N, K + >; + +class GemmKernelTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + std::pair types; + bool is_transpose; + MMKernelTestShape shape; + int M, N, K; + std::tie(types, is_transpose, shape) = obj.param; + std::tie(M, N, K) = shape; + + std::ostringstream result; + result << "A_" << dtype_to_str(types.first) << "_B_" << dtype_to_str(types.second) << "_" + << (is_transpose ? "transpose_" : "") + << "M_" << M << "_N_" << N << "_K_" << K; + return result.str(); + } + +protected: + virtual void SetUp() override { + initXTILE(); + + MMKernelTestShape shape; + std::tie(_types, _is_transpose, shape) = GetParam(); + std::tie(_M, _N, _K) = shape; + }; + + template + void test() { + if (_N == 1 && (_is_transpose || _types.first == dnnl_u8)) { + GTEST_SKIP() << "gemv does not support transpose or u8s8."; + } + mm_kernel* mm; + mm_create_param param = { + _types.first, _types.second, + _N == 1, _is_transpose + }; + ASSERT_TRUE(mm_kernel_create(&mm, ¶m)); + auto gemm = std::shared_ptr(mm, [](mm_kernel* p) { mm_kernel_destroy(p); }); + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + + fill_rnd(A); + fill_rnd(B); + tensor2D BT = B.Tr(); + TB* ptr_B; + size_t ldb; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + } else { + ptr_B = B.data; + ldb = B.stride; + } + mm_kernel_execute(gemm.get(), A.data, ptr_B, C.data, A.stride, ldb, + C.stride, _M, _N, _K); + C_Ref = 0; + matmul(A, B, C_Ref); + ASSERT_TRUE(C_Ref == C); + } + + int _M, _N, _K; + std::pair _types; + bool _is_transpose; +}; + +TEST_P(GemmKernelTest, Func) { + if (_types.first == dnnl_u8 && _types.second == dnnl_s8) { + test(); + } else if (_types.first == dnnl_s8 && _types.second == dnnl_s8) { + test(); + } else { + test(); + } +} + +const std::vector> types = { + { dnnl_u8, dnnl_s8 }, + { dnnl_s8, dnnl_s8 }, + { dnnl_bf16, dnnl_bf16 }, +}; + +// M, N, K +const std::vector shapes = { + // normal + {256, 48, 448}, + // k < 32 + {256, 48, 15}, + // k tail + {256, 48, 449}, + // M tail == unroll 8 + {256 + 8, 48, 449}, + // M tail == unroll 8 + 2 + {256 + 10, 48, 449}, + // N tail + {256, 40, 448}, + // all tail + {256 + 9, 47, 449}, + // gemv, K <= 64(32)*6 + {256, 1, 160}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_GemmKernel, GemmKernelTest, + ::testing::Combine(ValuesIn(types), + Values(true, false), + ValuesIn(shapes)), + GemmKernelTest::getTestCaseName); diff --git a/tests/src/test_rotary_kernel_avx512.cpp b/tests/src/test_rotary_kernel_avx512.cpp new file mode 100644 index 0000000..6be5829 --- /dev/null +++ b/tests/src/test_rotary_kernel_avx512.cpp @@ -0,0 +1,109 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "rotary_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using RotaryTestParamSet = std::tuple< + data_type_t // data type + >; + +class RotaryTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void rotary_emb(size_t rotaryNdims, float* cos, float* sin, T* q_src, T* k_src, T* q_dst, T* k_dst) { + auto halfRotaryNdims = rotaryNdims / 2; + for (size_t i = 0; i < halfRotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] - q_src[i + halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] - k_src[i + halfRotaryNdims] * sin[i]; + } + for (size_t i = halfRotaryNdims; i < rotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] + q_src[i - halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] + k_src[i - halfRotaryNdims] * sin[i]; + } + } + + template + void test(float thresh) { + for (int n = 6; n < 129; n += 2) { + tensor2D q_src(1, n, true); + tensor2D k_src(1, n, true); + tensor2D q_dst(1, n, true); + tensor2D k_dst(1, n, true); + tensor2D q_dst_ref(1, n, true); + tensor2D k_dst_ref(1, n, true); + tensor2D cos(1, n, true); + tensor2D sin(1, n, true); + for (int i = 0; i < n; i++) { + q_src[i] = i % 19 - 10; + k_src[i] = i % 19 - 9; + cos[i] = i % 19 - 8; + sin[i] = i % 19 - 7; + } + rotary_emb(n, cos.data, sin.data, q_src.data, k_src.data, q_dst_ref.data, k_dst_ref.data); + rotary_avx512(n, cos.data, sin.data, q_src.data, k_src.data, q_dst.data, k_dst.data); + for (int i = 0; i < n; i++) { + float q = q_dst[i]; + float q_ref = q_dst_ref[i]; + float k = k_dst[i]; + float k_ref = k_dst_ref[i]; + if (std::abs(q - q_ref) > thresh) { + FAIL() << " q is not equal, N: " << n << " pos: " << i << " opt: " << q << " ref: " << q_ref; + } + if (std::abs(k - k_ref) > thresh) { + FAIL() << " k is not equal, N: " << n << " pos: " << i << " opt: " << k << " ref: " << k_ref; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(RotaryTest, rotary) { + if (_types == dnnl_s8) { + ASSERT_TRUE(false); + } else { + test(0.01f); + } +} + +const std::vector types = { + dnnl_bf16 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Rotary, RotaryTest, + ::testing::Combine(ValuesIn(types)), + RotaryTest::getTestCaseName); diff --git a/tests/src/test_softmax_kernel_avx512.cpp b/tests/src/test_softmax_kernel_avx512.cpp new file mode 100644 index 0000000..a78fa1f --- /dev/null +++ b/tests/src/test_softmax_kernel_avx512.cpp @@ -0,0 +1,133 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "softmax_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using SoftmaxTestParamSet = std::tuple< + data_type_t // data type + >; + +class SoftmaxTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(tensor2D& x, tensor2D& out, tensor2D& quant) { + tensor2D y = x.clone(); + float x_max = std::numeric_limits::lowest(); + for(int i = 0; i < x.dims[1]; i++) { + x_max = std::max(x_max, x[i]); + } + float sum = 0; + for(int i = 0; i < x.dims[1]; i++) { + y[i] = expf(x[i] - x_max); + sum += y[i]; + } + for(int i = 0; i < x.dims[1]; i++) { + y[i] = y[i] / sum; + } + out.resize(x.dims[0], x.dims[1], true); + if (std::is_same::value) { + memcpy(static_cast(out.data), y.data, x.dims[0] * x.dims[1] * sizeof(float)); + } + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + out[i] = y[i]; + } + } +#define CLIP(x, low, high) \ + (x < low ? low : (x > high ? high : x)) + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + auto tmp = y[i] * quant[i]; + out[i] = static_cast(CLIP(tmp, -128, 127)); + } + } + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + auto tmp = y[i] * quant[i]; + out[i] = static_cast(CLIP(tmp, 0, 255)); + } + } +#undef CLIP + } + + template + void test(float thresh) { + for (int n = 1; n < 129; n++) { + tensor2D A(1, n, true), A_scalar(1, n, true); + tensor2D quant(1, n, true); + tensor2D out(1, n, true), out_ref, out_scalar(1, n, true); + for (int i = 0; i < n; i++) { + A[i] = static_cast(i) - n / 2; + A_scalar[i] = A[i]; + } + quant = 128.f; + gen_ref(A, out_ref, quant); + llmdnn::softmax_avx512(out.data, A.data, n, quant.data); + for (int i = 0; i < n; i++) { + float a = out[i]; + float b = out_ref[i]; + if (std::abs(a - b) > thresh) { + FAIL() << " N: " << n << " pos: " << i << " opt: " << a << " ref: " << b; + } + } + // input is scalar + llmdnn::softmax_avx512(out_scalar.data, A_scalar.data, n, quant.data[0]); + ASSERT_TRUE(out == out_scalar); + } + } + + data_type_t _types; +}; + +TEST_P(SoftmaxTest, Func) { + if (_types == dnnl_s8) { + test(1.1f); + } else if (_types == dnnl_u8) { + test(1.1f); + } else if (_types == dnnl_f32) { + test(0.00001f); + } else { + test(0.01f); + } +} + +const std::vector types = { + dnnl_s8, dnnl_bf16, dnnl_u8, dnnl_f32 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Softmax, SoftmaxTest, + ::testing::Combine(ValuesIn(types)), + SoftmaxTest::getTestCaseName); diff --git a/tests/src/test_transpose_kernel_avx512.cpp b/tests/src/test_transpose_kernel_avx512.cpp new file mode 100644 index 0000000..46035b5 --- /dev/null +++ b/tests/src/test_transpose_kernel_avx512.cpp @@ -0,0 +1,131 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "transpose_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using TransposeTestParamSet = std::tuple< + data_type_t // data type + >; + +class TransposeTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(T* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant) { + for (size_t j = 0; j < height; j++) { + if (std::is_same::value) { + memcpy(static_cast(dst), src, width * sizeof(float)); + } + if (std::is_same::value) { + for(size_t i = 0; i < width; i++) { + dst[i] = src[i]; + } + } + #define CLIP(x, low, high) \ + (x < low ? low : (x > high ? high : x)) + if (std::is_same::value) { + for(size_t i = 0; i < width; i++) { + auto tmp = src[i] * quant[i]; + dst[i] = static_cast(CLIP(tmp, -128, 127)); + } + } + if (std::is_same::value) { + for(size_t i = 0; i < width; i++) { + auto tmp = src[i] * quant[i]; + dst[i] = static_cast(CLIP(tmp, 0, 255)); + } + } + #undef CLIP + src = reinterpret_cast(reinterpret_cast(src) + src_stride); + dst = reinterpret_cast(reinterpret_cast(dst) + dst_stride); + } + } + + template + void test(float thresh) { + // [num_heads, query_seq_len, head_size] => [query_seq_len, num_heads * head_size] + int num_heads = 2, query_seq_len = 10; + for (int head_size = 1; head_size < 129; head_size++) { + tensor2D src(num_heads, head_size * query_seq_len, true); + tensor2D quant(1, head_size, true); + tensor2D dst(query_seq_len, num_heads * head_size, true); + tensor2D dst_ref(query_seq_len, num_heads * head_size, true); + for (int i = 0; i < num_heads * head_size * query_seq_len; i++) { + src[i] = i % 253 - 127; + } + quant = 1.28f; + auto* dst_p = dst.data; + auto* dst_p_ref = dst_ref.data; + for (int i = 0; i < num_heads; i++) { + auto* src_p = &src(i, 0); + llmdnn::memcpy2d_stride_avx512(dst_p, src_p, query_seq_len, head_size, + head_size * sizeof(float), num_heads * head_size * sizeof(T), quant.data); + gen_ref(dst_p_ref, src_p, query_seq_len, head_size, + head_size * sizeof(float), num_heads * head_size * sizeof(T), quant.data); + dst_p += head_size; + dst_p_ref += head_size; + } + for (int i = 0; i < num_heads * head_size * query_seq_len; i++) { + float a = dst[i]; + float b = dst_ref[i]; + if (std::abs(a - b) > thresh) { + FAIL() << " N: " << head_size << " pos: " << i << " opt: " << a << " ref: " << b; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(TransposeTest, memcpy2d) { + if (_types == dnnl_s8) { + test(1.1f); + } else if (_types == dnnl_u8) { + test(1.1f); + } else if (_types == dnnl_f32) { + test(0.00001f); + } else { + test(0.01f); + } +} + +const std::vector types = { + dnnl_s8, dnnl_bf16, dnnl_u8, dnnl_f32 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Transpose, TransposeTest, + ::testing::Combine(ValuesIn(types)), + TransposeTest::getTestCaseName); diff --git a/tests/src/test_utility_kernel_avx512.cpp b/tests/src/test_utility_kernel_avx512.cpp new file mode 100644 index 0000000..615a5b0 --- /dev/null +++ b/tests/src/test_utility_kernel_avx512.cpp @@ -0,0 +1,37 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "utility_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::Values; +using ::testing::ValuesIn; + +TEST(smoke_Utility, muladd) { + float normal_factor = 1.2f; + for (size_t len = 1; len < 129; len++) { + std::vector x(len), x_out(len), bias(len), ref(len); + for (size_t i = 0; i < x.size(); i++) { + x[i] = -10.0f + i; + bias[i] = -100.0f + i; + ref[i] = x[i] * normal_factor + bias[i]; + } + mul_add_f32_avx512(x_out.data(), x.data(), normal_factor, bias.data(), len); + for (size_t i = 0; i < x.size(); i++) { + ASSERT_TRUE(std::abs(x_out[i] - ref[i]) < 0.0001f) << " length: " << len << " pos: " << i << " cur: " << x[i] << " ref: " << ref[i]; + } + } +} \ No newline at end of file diff --git a/tests/src/test_utility_kernel_repack1x2_avx512.cpp b/tests/src/test_utility_kernel_repack1x2_avx512.cpp new file mode 100644 index 0000000..f140ba4 --- /dev/null +++ b/tests/src/test_utility_kernel_repack1x2_avx512.cpp @@ -0,0 +1,147 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "mm_kernel_common_amx.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using RepackTestParamSet = std::tuple< + data_type_t // data type + >; + +class RepackTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(tensor2D& in_t, tensor2D& out) { + int K = in_t.dims[1]; + int N = in_t.dims[0]; + int kStep = 64 / sizeof(T); + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + out.resize(N_padded / N_unit, K_padded * N_unit, true, true); + + tensor2D in_padded; + in_padded.resize(N_padded, K_padded, true, true); + for (int i = 0; i < N; i++) { + // full range + memcpy(&in_padded(i, 0), &in_t(i, 0), (K - Ktails) * sizeof(T)); + // k tail needs to right aligned + memcpy(&in_padded(i, K_padded - Ktails), &in_t(i, K - Ktails), Ktails * sizeof(T)); + } + for (int n = 0; n < N_padded; n += N_unit) { + for (int k = 0; k < K_padded; k += kStep) { + // bf16 as example: + // [N, K], 2*[16(N), 32(K)] => + // [K, N], 2*[16(K), 16(N)*2(K)] + for (int m = 0; m < 2; m++) { + auto* src = reinterpret_cast(&in_padded(n + m * 16, k)); + auto* dst = reinterpret_cast(&out(n / N_unit, k * N_unit)); + dst += m * (1024 / sizeof(int)); + for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + dst[i * 16 + j] = src[j * in_padded.stride / sizeof(int) + i]; + } + } + } + } + } + } + + template + void test() { + auto testone = [] (int k, int n, std::string prefix) { + tensor2D A(k, n, true); + + fill_rnd(A); + tensor2D AT = A.Tr(true); + tensor2D A_out, AT_out, A_ref; + amx_kernel::repackB_1x2(A, false, A_out, false); + amx_kernel::repackB_1x2(AT, true, AT_out, false); + gen_ref(AT, A_ref); + ASSERT_TRUE(A_out == A_ref) << " " << prefix << " without transform K: " << k << " N: " << n; + ASSERT_TRUE(AT_out == A_ref) << " " << prefix << " with transform K: " << k << " N: " << n; + }; + // n tail: transpose case needs from 1 to 31, without transpose needs one + int k = 32; + int n; + for (n = 1; n < 32; n++) { + testone(k, n, "ntail"); + } + for (n = 32 + 1; n < 32 + 32; n++) { + testone(k, n, "ntail"); + } + // k tail: transpose case needs 1, without transpose needs from 1 to 31 + n = 32; + for (k = 1; k < 32; k++) { + testone(k, n, "ktail"); + } + for (k = 32 + 1; k < 32 + 32; k++) { + testone(k, n, "ktail"); + } + // k, n normal + testone(32, 32, "normal"); + testone(64, 128, "normal"); + // k, n tail + testone(64, 128 + 5, "ntail"); + testone(64 + 3, 128, "ktail"); + testone(64 + 3, 128 + 5, "alltail"); + testone(64, 128 + 16 + 5, "ntail"); + testone(64 + 16 + 3, 128, "ktail"); + testone(64 + 16 + 3, 128 + 16 + 5, "alltail"); + } + + data_type_t _types; +}; + +TEST_P(RepackTest, Func) { + if (_types == dnnl_s8) { + test(); + } else { + test(); + } +} + +const std::vector types = { + dnnl_s8, dnnl_bf16 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Repack, RepackTest, + ::testing::Combine(ValuesIn(types)), + RepackTest::getTestCaseName);