diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5e92026 --- /dev/null +++ b/.gitignore @@ -0,0 +1,56 @@ +# Compiled Object files +**/.DS_Store +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +**/cmake-build-debug +**/CMakeCache.txt +**/cmake_install.cmake +**/install_manifest.txt +**/CMakeFiles/ +**/CTestTestfile.cmake +**/Makefile +**/*.cbp +**/CMakeScripts +**/compile_commands.json + + +## Local + +build/**/* +**/build/**/* +out/* +lib/* +bin/* +test/test_runner +.vs +.cache +__pycache__ +dist +*.egg-info \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..068837a --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,59 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) + +project(root) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +option(CPU_EXTENSIONS_BUILD_TESTS "Build with tests" ON) +option(CPU_EXTENSIONS_ENABLE_LOG "Enable log" ON) + +message(INFO "--------------------------------") +message(STATUS "Build with tests: ${CPU_EXTENSIONS_BUILD_TESTS}") +message(INFO "--------------------------------") + +if(MSVC) + # TODO: validate + if(MSVC_VERSION VERSION_LESS 1928) + message(FATAL_ERROR "Insufficient msvc compiler version, current ${MSVC_VERSION}, minimum 1928.") + endif() +elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.2") + message(FATAL_ERROR "Insufficient gcc compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 11.2.") + endif() + set(EXTRA_CXX_FLAGS -march=sapphirerapids -flax-vector-conversions) +elseif(OV_COMPILER_IS_CLANG) + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "12") + message(FATAL_ERROR "Insufficient clang compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 12.") + endif() + set(EXTRA_CXX_FLAGS -march=sapphirerapids -flax-vector-conversions) +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "2023.0") + message(FATAL_ERROR "Insufficient intel compiler version, current ${CMAKE_CXX_COMPILER_VERSION}, minimum 2023.0.") + endif() + set(EXTRA_CXX_FLAGS -march=sapphirerapids) +endif() + +if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY) + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +endif() +add_subdirectory(src) +if (CPU_EXTENSIONS_BUILD_TESTS) + add_subdirectory(tests) +endif() + +# Get the latest commit hash +execute_process( + COMMAND git rev-parse HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + OUTPUT_VARIABLE GIT_HASH + OUTPUT_STRIP_TRAILING_WHITESPACE + ) +file(WRITE ${CMAKE_BINARY_DIR}/git-state.txt ${GIT_HASH}) +install(FILES + ${CMAKE_BINARY_DIR}/git-state.txt + DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/README.md b/README.md new file mode 100644 index 0000000..0129645 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# About CPU_Extensions +CPU_Extensions is a compute library containing processor optimized kernels code. + +# Unit tests for CPU_Extensions +## Tests for kernels +Tests for kernels are written in gtest under tests\src, use ./cpu_extensions_tests to run it. + +## Tests for complex features +Some features have many steps and the reference could not be easily written using gtest. For these features can use python to generate the reference. The directory tests\script contains these test, please refer [test in python](./tests/script/README.md). \ No newline at end of file diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..eb482d9 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,12 @@ +# Security Policy + +## Report a Vulnerability + +Please report security issues or vulnerabilities to the [Intel® Security Center]. + +For more information on how Intel® works to resolve security issues, see +[Vulnerability Handling Guidelines]. + +[Intel® Security Center]:https://www.intel.com/security + +[Vulnerability Handling Guidelines]:https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html diff --git a/include/llm_emb_gpt.hpp b/include/llm_emb_gpt.hpp new file mode 100644 index 0000000..f74e3e6 --- /dev/null +++ b/include/llm_emb_gpt.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_tensor.hpp" + +namespace llmdnn { + +status_t emb_gpt(const tensor& q_src, // q shape: [batch, query_seq_len, head_num, head_size] or + // [batch, query_seq_len, num_kv_heads, head_num/num_kv_heads, head_size] + const tensor& k_src, // k shape: [batch, query_seq_len, head_num, head_size] or + // [batch, query_seq_len, num_kv_heads, 1, head_size] + const tensor& v_src, // v shape: [batch, query_seq_len, head_num, head_size] or + // [batch, query_seq_len, num_kv_heads, 1, head_size] + const tensor& k_past, // k_past shape: [batch, num_heads, past_seq_len, head_size] + const tensor& v_past, // v_past shape: [batch, num_heads, past_seq_len, head_size] + const tensor& q_dst, // q_dst, shape: [batch, num_heads, query_seq_len, head_size] + const tensor& k_dst, // k_past shape: [batch, num_heads, query_seq_len+past_seq_len, head_size] + // if k_past!=k_past_dst, will copy k_past to k_past_dst + const tensor& v_dst, // v_past shape: [batch, num_heads, query_seq_len+past_seq_len, head_size] + const tensor& cos, // cos lookup table, shape: [1, 1, max_seq_len, rotary_dims] + const tensor& sin, // sin lookup table, shape: [1, 1, max_seq_len, rotary_dims] + const tensor& position2d_ids); // shape: [batch, 2, query_seq_len] + +} // namespace llmdnn diff --git a/include/llm_fc.hpp b/include/llm_fc.hpp new file mode 100644 index 0000000..c29371c --- /dev/null +++ b/include/llm_fc.hpp @@ -0,0 +1,98 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "llm_types.hpp" +#include "llm_tensor.hpp" + +namespace llmdnn { + +typedef enum { + NONE = 0, + DEQUANT = 1 << 0, + BIAS = 1 << 1, + GELU_ERF = 1 << 2, + GELU_TANH = 1 << 3, + QUANT = 1 << 4, + GELU = GELU_ERF, // default is ERF + + BIAS_GELU = BIAS | GELU, + DEQUANT_BIAS_GELU = DEQUANT | BIAS_GELU, + DEQUANT_BIAS_GELU_QUANT = DEQUANT_BIAS_GELU | QUANT, + DEQUANT_BIAS_QUANT = DEQUANT | BIAS | QUANT, + DEQUANT_GELU_QUANT = DEQUANT | GELU | QUANT, + DEQUANT_QUANT = DEQUANT | QUANT, + + DEQUANT_GELU = DEQUANT | GELU, + DEQUANT_BIAS = DEQUANT | BIAS, + + BIAS_GELU_TANH = BIAS | GELU_TANH, + DEQUANT_BIAS_GELU_TANH = DEQUANT | BIAS_GELU_TANH, + DEQUANT_BIAS_GELU_TANH_QUANT = DEQUANT_BIAS_GELU_TANH | QUANT, + DEQUANT_GELU_TANH_QUANT = DEQUANT | GELU_TANH | QUANT, + + DEQUANT_GELU_TANH = DEQUANT | GELU_TANH, +} postops_types; + +struct fc_create_param { + data_type_t dt_a; + data_type_t dt_b; + data_type_t dt_c; + bool b_is_trans; + postops_types postops_type; + // for weight compression + float* scale; + float* zp; + int scale_zp_size; +}; + +struct fc_kernel; + +/// Generates a mm kernel based on param +/// +/// @param mm Output kernel +/// @param param kernel parameters, supported: +/// fc: (s8,s8,s8),dq,[bias],[gelu],q +/// fc: (s8,s8,bf16),dq,[bias],[gelu] +/// fc: (s8,s8,f32),dq,[bias],[gelu] +/// fc: (bf16,bf16,bf16),[bias],[gelu] +/// fc: (bf16,bf16,f32),[bias],[gelu] +/// fc: (bf16,u8,f32),dq,[bias],[gelu] +/// fc: (bf16,u8,bf16),dq,[bias],[gelu] +/// +status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param); +void fc_kernel_destroy(fc_kernel* mm); +// when fc_create_param.dt_b==bf16, dt_b is in [bf16, f32] +// when fc_create_param.dt_b==u8, dt_b is in [bf16, f32] +void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); +void fc_kernel_pack_weight_to_dst(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); +// ptr_b may be null if using fc_kernel_pack_weight to pack into internal buffer +// if ptr_b is not null, its layout is [N/32, 32*rndup(K,32|64)] +void fc_kernel_execute(fc_kernel* mm, + void* ptr_a, void* ptr_b, void* ptr_c, size_t stride_a, size_t stride_c, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, + float* dq=nullptr, float* q=nullptr, float* bias=nullptr); + +/// Generates a fc based on param +class fc { +public: + fc(); + ~fc(); + + bool init(const fc_create_param& param); + void pack_weight(const tensor& w); + status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias); + + struct impl { + virtual ~impl() {} + virtual bool init(const fc_create_param& param) = 0; + virtual void pack_weight(const tensor& w) = 0; + virtual status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) = 0; + }; +protected: + impl* _impl; +}; + +} // namespace llmdnn diff --git a/include/llm_mha_gpt.hpp b/include/llm_mha_gpt.hpp new file mode 100644 index 0000000..68942bb --- /dev/null +++ b/include/llm_mha_gpt.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_tensor.hpp" + +namespace llmdnn { + +class mha_gpt { +public: + mha_gpt(); + ~mha_gpt(); + + status_t exec(const tensor& q, // q shape: [batch, num_heads, query_seq_len, head_size] + const tensor& k, // k shape: [batch, num_heads, key_seq_len, head_size] + const tensor& v, // v shape: [batch, num_heads, value_seq_len, head_size] + const tensor& output, // output, compact, shape: [batch, query_seq_len, num_heads * head_size] + const tensor& attn_mask, // attention mask[opt], shape: + // [batch, 1, 1, key_seq_len], + // [batch, 1, query_seq_len, key_seq_len] + const tensor& alibi, // alibi[opt] shape: [batch, num_heads, 1, key_seq_len] + const tensor& causal_mask, // [opt] use_causal_mask must be false, u8, shape: + // [1, 1, query_seq_len, key_seq_len] + // [batch, 1, query_seq_len, key_seq_len] + bool select_nfltmax_at_0, // used when causal_mask is not null. true means causal_mask[i]==0 use -FLT_MAX + // false means causal_mask[i]==1 use -FLT_MAX + float normal_factor, + bool use_causal_mask = false);// add causal mask + + struct impl { + virtual ~impl() {} + virtual status_t exec(const tensor& q, + const tensor& k, + const tensor& v, + const tensor& output, + const tensor& attn_mask, + const tensor& alibi, + const tensor& causal_mask, + bool select_nfltmax_at_0, + float normal_factor, + bool use_causal_mask = false) = 0; + }; +protected: + impl* _impl; +}; + +} // namespace llmdnn diff --git a/include/llm_mm.hpp b/include/llm_mm.hpp new file mode 100644 index 0000000..52dc420 --- /dev/null +++ b/include/llm_mm.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "llm_types.hpp" + +namespace llmdnn { + +struct mm_create_param { + data_type_t dt_a; + data_type_t dt_b; + bool b_is_gemv; // true if matrix b is vector. Shape: a[M,K], b[K,1], c[M,1] + bool b_is_trans; +}; + +struct mm_kernel; + +/// Generates a mm kernel based on param +/// +/// @param mm Output kernel +/// @param param kernel parameters, supported: +/// matmul: (u8/s8,s8,f32) +/// gemv: (s8,s8,f32) +/// matmul: (bf16,bf16,f32) +/// gemv: (bf16,bf16,f32) +/// +status_t mm_kernel_create(mm_kernel** mm, const mm_create_param* param); +void mm_kernel_destroy(const mm_kernel* mm); + +status_t mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K); + +} // namespace llmdnn diff --git a/include/llm_tensor.hpp b/include/llm_tensor.hpp new file mode 100644 index 0000000..615d32f --- /dev/null +++ b/include/llm_tensor.hpp @@ -0,0 +1,181 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "llm_types.hpp" + +// forward declaration +namespace ov { +class bfloat16; +}; + +namespace llmdnn { + +template +struct precision_of { + static constexpr data_type_t value = llmdnn_data_type_undef; +}; + +template <> +struct precision_of { + static constexpr data_type_t value = llmdnn_f32; +}; + +template <> +struct precision_of { + static constexpr data_type_t value = llmdnn_s32; +}; + +template <> +struct precision_of { + static constexpr data_type_t value = llmdnn_bf16; +}; + +template <> +struct precision_of { + static constexpr data_type_t value = llmdnn_u8; +}; + +template <> +struct precision_of { + static constexpr data_type_t value = llmdnn_s8; +}; + + +#define TENSOR_RANK_MAX 8 +struct tensor { + size_t m_strides[TENSOR_RANK_MAX]; + size_t m_dims[TENSOR_RANK_MAX]; + size_t m_rank = 0; + + void* m_ptr = nullptr; + size_t m_capacity = 0; // 0 means not own m_ptr + size_t m_element_size = 0; + data_type_t m_dtype = llmdnn_data_type_undef; + + tensor(); + ~tensor(); + tensor(const tensor&) = delete; + tensor& operator = (const tensor&) = delete; + tensor(tensor&& t) { + memcpy(reinterpret_cast(this), &t, sizeof(*this)); + t.m_capacity = 0; + t.m_ptr = nullptr; + } + tensor& operator = (tensor&& t) { + if (m_capacity && m_ptr) + free(m_ptr); + memcpy(reinterpret_cast(this), &t, sizeof(*this)); + t.m_capacity = 0; + t.m_ptr = nullptr; + return *this; + } + operator bool() const { + return m_ptr != nullptr; + } + + size_t size(int i) const { + assert(static_cast(i) < m_rank); + return m_dims[i]; + } + size_t stride(int i) const { + assert(static_cast(i) < m_rank); + return m_strides[i]; + } + + struct tensor_index { + int start; + int end; + int step; + int count; + // select all + tensor_index() { + start = 0; + end = INT_MAX; + step = 1; + } + bool slice_with_squeeze() { + return end == INT_MIN; + } + // tensor_index(start) : select 1 element (with squeeze) + // tensor_index(start, end, step) : select a range w/o squeeze + tensor_index(int start, int end = INT_MIN, int step = 1) : start(start), end(end), step(step) {} + + void regularize(int size) { + if (start < 0) + start += size; + assert(start >= 0 && start < size); + if (end != INT_MIN) { + if (end < 0) + end += size; + if (end > size) + end = size; + assert(end >= 0 && end <= size); + count = (end - start + step - 1) / step; + } else { + count = 1; + } + } + }; + + tensor index(const std::initializer_list& indices) const; + + // slice: return a sub-view (w/o ownership/refcount to original data) + tensor slice(int axis, int start, int end) const; + + bool is_dense() const; + + tensor reshape(const std::initializer_list& target_shape) const; + + tensor permute(const std::initializer_list& order) const; + + template + void resize(const size_t* new_dims, size_t dim_num, DT* data = nullptr) { + resize(new_dims, dim_num, data, sizeof(DT), precision_of
::value); + } + + template + void resize(const std::vector& new_dims, DT* data = nullptr) { + resize(new_dims.data(), new_dims.size(), data); + } + + void resize(const size_t* new_dims, size_t dim_num, void* data, size_t element_size, data_type_t dtype); + void resize(const std::vector& new_dims, void* data, size_t element_size, data_type_t dtype) { + resize(new_dims.data(), new_dims.size(), data, element_size, dtype); + } + + template + DT* data() const { + return reinterpret_cast(m_ptr); + } + + template + DT& at(const std::initializer_list& index) const { + size_t off = 0; + auto it = index.begin(); + for (auto& stride : m_strides) { + auto coordinate = (it != index.end()) ? (*it++) : 0; + off += stride * coordinate; + } + return *reinterpret_cast(reinterpret_cast(m_ptr) + off); + } + + template + DT& operator()(const std::initializer_list& index) const { + return at
(index); + } + + void assert_dims(const std::initializer_list& expect_dims) const; +}; + +} // namespace llmdnn diff --git a/include/llm_types.hpp b/include/llm_types.hpp new file mode 100644 index 0000000..2a48041 --- /dev/null +++ b/include/llm_types.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace llmdnn { + +/// Data type specification +typedef enum { + /// Undefined data type, used for empty memory descriptors. + llmdnn_data_type_undef = 0, + /// 16-bit/half-precision floating point. + llmdnn_f16 = 1, + /// non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point. + llmdnn_bf16 = 2, + /// 32-bit/single-precision floating point. + llmdnn_f32 = 3, + /// 32-bit signed integer. + llmdnn_s32 = 4, + /// 8-bit signed integer. + llmdnn_s8 = 5, + /// 8-bit unsigned integer. + llmdnn_u8 = 6, + /// 64-bit/double-precision floating point. + llmdnn_f64 = 7, + + /// Parameter to allow internal only data_types without undefined behavior. + /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. + llmdnn_data_type_max = 0x7fff, +} data_type_t; + +typedef enum { + status_ok, + status_invalid_arguments, + status_unimplemented, + status_fail = 10 +} status_t; + +} // namespace llmdnn diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..29f4982 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,50 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) +project(cpu_extensions) + +file(GLOB_RECURSE ${PROJECT_NAME}_SOURCE_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) + +add_library(${PROJECT_NAME} STATIC ${${PROJECT_NAME}_SOURCE_FILES}) +set_target_properties(${PROJECT_NAME} PROPERTIES + POSITION_INDEPENDENT_CODE ON) +target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + PUBLIC $ + $/${CMAKE_INSTALL_INCLUDEDIR}>) +target_compile_options(${PROJECT_NAME} PRIVATE ${EXTRA_CXX_FLAGS}) +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17) +if(CPU_EXTENSIONS_ENABLE_LOG) + target_compile_definitions(${PROJECT_NAME} PRIVATE ENABLE_LOG) +endif() +target_link_libraries(${PROJECT_NAME} PUBLIC dl) + +set(CMAKE_DST lib/cmake/${PROJECT_NAME}) +# header files +include(GNUInstallDirs) +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../include/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +# library file +install(TARGETS ${PROJECT_NAME} + EXPORT ${PROJECT_NAME}Targets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include) + +# config file +install(EXPORT ${PROJECT_NAME}Targets + NAMESPACE ${PROJECT_NAME}:: + FILE ${PROJECT_NAME}Config.cmake + DESTINATION ${CMAKE_DST}) + +# version file +include(CMakePackageConfigHelpers) +write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake + VERSION 1.0.0 + COMPATIBILITY AnyNewerVersion) +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake + DESTINATION ${CMAKE_DST}) diff --git a/src/common/bf16.hpp b/src/common/bf16.hpp new file mode 100644 index 0000000..35f42cc --- /dev/null +++ b/src/common/bf16.hpp @@ -0,0 +1,249 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#define ROUND_MODE_TO_NEAREST_EVEN + +#define OPENVINO_API + +namespace ov { +class OPENVINO_API bfloat16 { +public: + constexpr bfloat16() : m_value{0} {} + bfloat16(float value) : m_value { +#if defined ROUND_MODE_TO_NEAREST + round_to_nearest(value) +#elif defined ROUND_MODE_TO_NEAREST_EVEN + round_to_nearest_even(value) +#elif defined ROUND_MODE_TRUNCATE + truncate(value) +#else +# error "ROUNDING_MODE must be one of ROUND_MODE_TO_NEAREST, ROUND_MODE_TO_NEAREST_EVEN, or ROUND_MODE_TRUNCATE" +#endif + } + {} + + template + explicit bfloat16(I value) : m_value{bfloat16{static_cast(value)}.m_value} {} + + std::string to_string() const; + size_t size() const; + template + bool operator==(const T& other) const; + template + bool operator!=(const T& other) const { + return !(*this == other); + } + template + bool operator<(const T& other) const; + template + bool operator<=(const T& other) const; + template + bool operator>(const T& other) const; + template + bool operator>=(const T& other) const; + template + bfloat16 operator+(const T& other) const; + template + bfloat16 operator+=(const T& other); + template + bfloat16 operator-(const T& other) const; + template + bfloat16 operator-=(const T& other); + template + bfloat16 operator*(const T& other) const; + template + bfloat16 operator*=(const T& other); + template + bfloat16 operator/(const T& other) const; + template + bfloat16 operator/=(const T& other); + operator float() const { + uint32_t tmp = 0; + uint32_t* ptmp = &tmp; + *ptmp = (static_cast(m_value) << 16); + const float* f = reinterpret_cast(ptmp); + return *f; + } + + static std::vector to_float_vector(const std::vector&); + static std::vector from_float_vector(const std::vector&); + static constexpr bfloat16 from_bits(uint16_t bits) { + return bfloat16(bits, true); + } + uint16_t to_bits() const; + friend std::ostream& operator<<(std::ostream& out, const bfloat16& obj) { + out << static_cast(obj); + return out; + } + +#define cu32(x) (F32(x).i) + + static uint16_t round_to_nearest_even(float x) { + return static_cast((cu32(x) + ((cu32(x) & 0x00010000) >> 1)) >> 16); + } + + static uint16_t round_to_nearest(float x) { + return static_cast((cu32(x) + 0x8000) >> 16); + } + + static uint16_t truncate(float x) { + return static_cast((cu32(x)) >> 16); + } + +private: + constexpr bfloat16(uint16_t x, bool) : m_value{x} {} + union F32 { + F32(float val) : f{val} {} + F32(uint32_t val) : i{val} {} + float f; + uint32_t i; + }; + + uint16_t m_value; +}; + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4756) +#endif +template +bool bfloat16::operator==(const T& other) const { +#if defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wfloat-equal" +#endif + return (static_cast(*this) == static_cast(other)); +#if defined(__GNUC__) +# pragma GCC diagnostic pop +#endif +} + +template +bool bfloat16::operator<(const T& other) const { + return (static_cast(*this) < static_cast(other)); +} + +template +bool bfloat16::operator<=(const T& other) const { + return (static_cast(*this) <= static_cast(other)); +} + +template +bool bfloat16::operator>(const T& other) const { + return (static_cast(*this) > static_cast(other)); +} + +template +bool bfloat16::operator>=(const T& other) const { + return (static_cast(*this) >= static_cast(other)); +} + +template +bfloat16 bfloat16::operator+(const T& other) const { + return {static_cast(*this) + static_cast(other)}; +} + +template +bfloat16 bfloat16::operator+=(const T& other) { + return *this = *this + other; +} + +template +bfloat16 bfloat16::operator-(const T& other) const { + return {static_cast(*this) - static_cast(other)}; +} + +template +bfloat16 bfloat16::operator-=(const T& other) { + return *this = *this - other; +} + +template +bfloat16 bfloat16::operator*(const T& other) const { + return {static_cast(*this) * static_cast(other)}; +} + +template +bfloat16 bfloat16::operator*=(const T& other) { + return *this = *this * other; +} + +template +bfloat16 bfloat16::operator/(const T& other) const { + return {static_cast(*this) / static_cast(other)}; +} + +template +bfloat16 bfloat16::operator/=(const T& other) { + return *this = *this / other; +} +#if defined(_MSC_VER) +# pragma warning(pop) +#endif +} // namespace ov + +namespace std { +template <> +class numeric_limits { +public: + static constexpr bool is_specialized = true; + static constexpr ov::bfloat16 min() noexcept { + return ov::bfloat16::from_bits(0x007F); + } + static constexpr ov::bfloat16 max() noexcept { + return ov::bfloat16::from_bits(0x7F7F); + } + static constexpr ov::bfloat16 lowest() noexcept { + return ov::bfloat16::from_bits(0xFF7F); + } + static constexpr int digits = 7; + static constexpr int digits10 = 2; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr int radix = 2; + static constexpr ov::bfloat16 epsilon() noexcept { + return ov::bfloat16::from_bits(0x3C00); + } + static constexpr ov::bfloat16 round_error() noexcept { + return ov::bfloat16::from_bits(0x3F00); + } + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = denorm_absent; + static constexpr bool has_denorm_loss = false; + static constexpr ov::bfloat16 infinity() noexcept { + return ov::bfloat16::from_bits(0x7F80); + } + static constexpr ov::bfloat16 quiet_NaN() noexcept { + return ov::bfloat16::from_bits(0x7FC0); + } + static constexpr ov::bfloat16 signaling_NaN() noexcept { + return ov::bfloat16::from_bits(0x7FC0); + } + static constexpr ov::bfloat16 denorm_min() noexcept { + return ov::bfloat16::from_bits(0); + } + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = false; + static constexpr bool is_modulo = false; + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; +} // namespace std diff --git a/src/common/compatible.hpp b/src/common/compatible.hpp new file mode 100644 index 0000000..f232763 --- /dev/null +++ b/src/common/compatible.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +// gcc 9 does not recognize 'std::__throw_bad_array_new_length()' which is imported by +// gcc 11. The symbol exists in std::allocator::allocate, use custom to wa. +template +class custom_allocator { +public: + using value_type = T; + custom_allocator() noexcept = default; + template + custom_allocator (const custom_allocator&) noexcept {} + inline T* allocate(std::allocator::size_type cnt, typename std::allocator::const_pointer = 0) { + return static_cast(::operator new(cnt * sizeof(T))); + } + void deallocate (T* p, std::size_t n) { + ::operator delete(p); + } +}; + +template +bool operator==(custom_allocator const&, custom_allocator const&) noexcept { + return true; +} + +template +bool operator!=(custom_allocator const& x, custom_allocator const& y) noexcept { + return !(x == y); +} + +template +using llm_vector = std::vector>; + +template > +using llm_map = std::map>>; \ No newline at end of file diff --git a/src/common/log.hpp b/src/common/log.hpp new file mode 100644 index 0000000..03d969e --- /dev/null +++ b/src/common/log.hpp @@ -0,0 +1,11 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#ifdef ENABLE_LOG + #define DEBUG_LOG std::cout +#else + #define DEBUG_LOG if (0) std::cout +#endif diff --git a/src/common/memory_alloc.cpp b/src/common/memory_alloc.cpp new file mode 100644 index 0000000..89327fa --- /dev/null +++ b/src/common/memory_alloc.cpp @@ -0,0 +1,160 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include + +#include "memory_alloc.hpp" +#include "common/simple_parallel.hpp" + +struct numa_funcs { + numa_funcs() { + _numa_handle = dlopen(libnuma_path, RTLD_NOW); + if (_numa_handle) { + _numa_available = reinterpret_cast(dlsym(_numa_handle, "numa_available")); + _numa_node_of_cpu = reinterpret_cast(dlsym(_numa_handle, "numa_node_of_cpu")); + _numa_alloc_onnode = reinterpret_cast(dlsym(_numa_handle, "numa_alloc_onnode")); + _numa_free = reinterpret_cast(dlsym(_numa_handle, "numa_free")); + } + } + + ~numa_funcs() { + if (_numa_handle) { + dlclose(_numa_handle); + } + } + + static numa_funcs& get() { + static numa_funcs funcs; + return funcs; + } + + int numa_available() { + if (_numa_available) { + return _numa_available(); + } else { + return -1; + } + } + + int numa_node_of_cpu(int cpu) { + if (_numa_node_of_cpu) { + return _numa_node_of_cpu(cpu); + } else { + return 0; + } + } + + void *numa_alloc_onnode(size_t size, int node) { + if (_numa_alloc_onnode) { + return _numa_alloc_onnode(size, node); + } else { + return aligned_alloc(64, size); + } + } + + void numa_free(void *mem, size_t size) { + if (_numa_free) { + _numa_free(mem, size); + } else { + ::free(mem); + } + } + +private: + constexpr static const char* libnuma_path = "libnuma.so.1"; + void* _numa_handle = nullptr; + int (*_numa_available)(void) = nullptr; + int (*_numa_node_of_cpu)(int cpu) = nullptr; + void *(*_numa_alloc_onnode)(size_t size, int node) = nullptr; + void (*_numa_free)(void *mem, size_t size) = nullptr; +}; + +static bool llmdnn_use_numa() { + struct init_numa_flag { + init_numa_flag() { + auto p = std::getenv("LLMDNN_USE_NUMA"); + if (p) { + use_numa = p[0] != '0'; + } + if (use_numa) { + use_numa = numa_funcs::get().numa_available() != -1; + } + } + + bool use_numa = true; + }; + + static init_numa_flag flag; + + return flag.use_numa; +} + +void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa) { + if (hint_numa && llmdnn_use_numa()) { + int cur_cpu = sched_getcpu(); + auto cur_numa_node = numa_funcs::get().numa_node_of_cpu(cur_cpu); + return numa_funcs::get().numa_alloc_onnode(size, cur_numa_node); + } else { + return aligned_alloc(aligned_size, size); + } +} + +void llmdnn_free(void* p, size_t size, bool hint_numa) { + if (hint_numa && llmdnn_use_numa()) { + numa_funcs::get().numa_free(p, size); + } else { + ::free(p); + } +} + +int llmdnn_get_numa_id_for_cur_task() { + if (llmdnn_use_numa()) { + int cur_cpu = sched_getcpu(); + return numa_funcs::get().numa_node_of_cpu(cur_cpu); + } else { + return 0; + } +} + +llm_vector llmdnn_get_numa_nodes() { + llm_vector numa_nodes; + if (llmdnn_use_numa()) { + auto thread_nums = llmdnn::get_total_threads(); + llm_vector numa_nodes_list; + numa_nodes_list.resize(thread_nums); + llmdnn::parallel_for(thread_nums, [&] (size_t id) { + int cur_cpu = sched_getcpu(); + numa_nodes_list[id] = numa_funcs::get().numa_node_of_cpu(cur_cpu); + }); + for (auto numa_node : numa_nodes_list) { + if (std::find(numa_nodes.begin(), numa_nodes.end(), numa_node) == numa_nodes.end()) { + numa_nodes.push_back(numa_node); + } + } + std::sort(numa_nodes.begin(), numa_nodes.end()); + } else { + numa_nodes.push_back(0); + } + return numa_nodes; +} + +void* llmdnn_alloc_on(size_t aligned_size, size_t size, int numa_id) { + if (llmdnn_use_numa()) { + return numa_funcs::get().numa_alloc_onnode(size, static_cast(numa_id)); + } else { + return aligned_alloc(aligned_size, size); + } +} + +void llmdnn_free_on(void* p, size_t size) { + if (llmdnn_use_numa()) { + numa_funcs::get().numa_free(p, size); + } else { + ::free(p); + } +} diff --git a/src/common/memory_alloc.hpp b/src/common/memory_alloc.hpp new file mode 100644 index 0000000..a704add --- /dev/null +++ b/src/common/memory_alloc.hpp @@ -0,0 +1,16 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "compatible.hpp" + +void* llmdnn_alloc(size_t aligned_size, size_t size, bool hint_numa = true); +void llmdnn_free(void* p, size_t size, bool hint_numa = true); + +llm_vector llmdnn_get_numa_nodes(); +void* llmdnn_alloc_on(size_t aligned_size, size_t size, int numa_id); +void llmdnn_free_on(void* p, size_t size); +int llmdnn_get_numa_id_for_cur_task(); \ No newline at end of file diff --git a/src/common/simple_parallel.hpp b/src/common/simple_parallel.hpp new file mode 100644 index 0000000..1f7de95 --- /dev/null +++ b/src/common/simple_parallel.hpp @@ -0,0 +1,181 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +namespace llmdnn { + +size_t get_total_threads(); +void simple_parallel_for(const size_t total, const std::function& fn); + +// copy from openvino/core/parallel.hpp +template +inline T parallel_it_init(T start) { + return start; +} +template +inline T parallel_it_init(T start, Q& x, const R& X, Args&&... tuple) { + start = parallel_it_init(start, static_cast(tuple)...); + x = start % X; + return start / X; +} + +inline bool parallel_it_step() { + return true; +} +template +inline bool parallel_it_step(Q& x, const R& X, Args&&... tuple) { + if (parallel_it_step(static_cast(tuple)...)) { + if (++x - X == 0) { + x = 0; + return true; + } + } + return false; +} + +template +inline void splitter(const T& n, const Q& team, const Q& tid, T& n_start, T& n_end) { + if (team <= 1 || n == 0) { + n_start = 0; + n_end = n; + } else { + T n1 = (n + (T)team - 1) / (T)team; + T n2 = n1 - 1; + T T1 = n - n2 * (T)team; + n_end = (T)tid < T1 ? n1 : n2; + n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; + } + + n_end += n_start; +} + +namespace helpers { +template +struct NumOfLambdaArgs : public NumOfLambdaArgs {}; + +template +struct NumOfLambdaArgs { + constexpr static int value = sizeof...(Args); +}; + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(g_id, iwork, arg...); +} + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(g_id, arg...); +} + +template ::value> +typename std::enable_if::type call_with_args(const ACT& body, + size_t g_id, + size_t iwork, + T... arg) { + body(arg...); +} +} // namespace helpers + +template +void for_1d(const int& ithr, const int& nthr, const T0& D0, const F& func) { + T0 d0{0}, end{0}; + splitter(D0, nthr, ithr, d0, end); + for (; d0 < end; ++d0) + helpers::call_with_args(func, ithr, d0, d0); +} + +template +void parallel_for(const T0& D0, const F& func) { + auto work_amount = static_cast(D0); + int nthr = static_cast(get_total_threads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_1d(0, 1, D0, func); + } else { + simple_parallel_for(static_cast(nthr), [&](size_t ithr) { + for_1d(static_cast(ithr), nthr, D0, func); + }); + } +} + +template +void for_2d(const int& ithr, const int& nthr, const T0& D0, const T1& D1, const F& func) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) + return; + size_t start{0}, end{0}; + splitter(work_amount, nthr, ithr, start, end); + + T0 d0{0}; + T1 d1{0}; + parallel_it_init(start, d0, D0, d1, D1); + for (size_t iwork = start; iwork < end; ++iwork) { + helpers::call_with_args(func, ithr, iwork, d0, d1); + parallel_it_step(d0, D0, d1, D1); + } +} + +template +void parallel_for2d(const T0& D0, const T1& D1, const F& func) { + auto work_amount = static_cast(D0 * D1); + int nthr = static_cast(get_total_threads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_2d(0, 1, D0, D1, func); + } else { + simple_parallel_for(static_cast(nthr), [&](size_t ithr) { + for_2d(static_cast(ithr), nthr, D0, D1, func); + }); + } +} + +template +void for_3d(const int& ithr, const int& nthr, const T0& D0, const T1& D1, const T2& D2, const F& func) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) + return; + size_t start{0}, end{0}; + splitter(work_amount, nthr, ithr, start, end); + + T0 d0{0}; + T1 d1{0}; + T2 d2{0}; + parallel_it_init(start, d0, D0, d1, D1, d2, D2); + for (size_t iwork = start; iwork < end; ++iwork) { + helpers::call_with_args(func, ithr, iwork, d0, d1, d2); + parallel_it_step(d0, D0, d1, D1, d2, D2); + } +} + +template +void parallel_for3d(const T0& D0, const T1& D1, const T2& D2, const F& func) { + auto work_amount = static_cast(D0 * D1 * D2); + int nthr = static_cast(get_total_threads()); + if (static_cast(nthr) > work_amount) + nthr = static_cast(work_amount); + if (nthr == 1) { + for_3d(0, 1, D0, D1, D2, func); + } else { + simple_parallel_for(static_cast(nthr), [&](size_t ithr) { + for_3d(static_cast(ithr), nthr, D0, D1, D2, func); + }); + } +} + +} // namespace llmdnn \ No newline at end of file diff --git a/src/common/tensor.cpp b/src/common/tensor.cpp new file mode 100644 index 0000000..0117972 --- /dev/null +++ b/src/common/tensor.cpp @@ -0,0 +1,164 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "common/log.hpp" +#include "bf16.hpp" +#include "llm_tensor.hpp" + +namespace llmdnn { + +tensor::tensor() { +} + +tensor::~tensor() { + if (m_capacity && m_ptr) + free(m_ptr); +} + +tensor tensor::index(const std::initializer_list& indices) const { + tensor sub_tensor; + assert(indices.size() <= m_rank); + int i_src = 0; + int i_dst = 0; + sub_tensor.m_capacity = 0; + size_t off = 0; + for (auto idx : indices) { + auto src_dim = m_dims[i_src]; + auto src_stride = m_strides[i_src]; + idx.regularize(src_dim); + off += idx.start * src_stride; + if (idx.slice_with_squeeze()) { + // no output dimension + i_src++; + continue; + } + sub_tensor.m_dims[i_dst] = idx.count; + sub_tensor.m_strides[i_dst] = src_stride; + i_dst++; + i_src++; + } + sub_tensor.m_rank = i_dst; // index may imply squeeze + sub_tensor.m_ptr = reinterpret_cast(m_ptr) + off; + return sub_tensor; +} + +// slice: return a sub-view (w/o ownership/refcount to original data) +tensor tensor::slice(int axis, int start, int end) const { + tensor sub_tensor; + assert(static_cast(axis) < m_rank); + + sub_tensor.m_capacity = 0; + sub_tensor.m_rank = m_rank; // slice dosen't change rank & strides + for (size_t i = 0; i < m_rank; i++) { + sub_tensor.m_strides[i] = m_strides[i]; + sub_tensor.m_dims[i] = m_dims[i]; + } + sub_tensor.m_dims[axis] = end - start; + + auto off = start * m_strides[axis]; + auto* data = reinterpret_cast(m_ptr) + off; + sub_tensor.m_ptr = reinterpret_cast(data); + + return sub_tensor; +} + +bool tensor::is_dense() const { + // check if it's dense tensor + size_t stride = m_element_size; + for (int i = m_rank - 1; i >= 0; i--) { + if (m_strides[i] != stride) + return false; + stride *= m_dims[i]; + } + return true; +} + +/* + suppose current shape is [a0,a1,...,am] + and target shape is [b0,b1,...,bn] + reshape is only valid when (a0*a1*...*am) == (b0*b1*...*bn) <======= (A) + + uniform a tensor's shape into groups from last to first, the dimension is merged + into current group if the subtensor in the group is still dense after merge. + otherwise a new group is formed. + + then reshape is performed on group basis, the check (A) is performed on group bases. + which means any reshape inside the group is OK, but not across the group boundary. + + this can be done in one-loop, while group is forming, and checks are performed. + + simplified form is when whole tensor is dense +*/ +tensor tensor::reshape(const std::initializer_list& target_shape) const { + // only valid for dense memory + tensor new_tensor_view; + assert(is_dense()); + new_tensor_view.resize(target_shape.begin(), target_shape.size(), m_ptr, m_element_size, m_dtype); + return new_tensor_view; +} + +tensor tensor::permute(const std::initializer_list& order) const { + tensor new_tensor_view; + assert(order.size() == m_rank); + new_tensor_view.m_capacity = 0; + new_tensor_view.m_ptr = m_ptr; + new_tensor_view.m_rank = m_rank; + auto it_order = order.begin(); + // also should check order has no repeat element + for (size_t i = 0; i < m_rank; i++) { + auto j = *it_order++; + assert(j >= 0 && j < m_rank); + new_tensor_view.m_dims[i] = m_dims[j]; + new_tensor_view.m_strides[i] = m_strides[j]; + } + return new_tensor_view; +} + +void tensor::resize(const size_t* new_dims, size_t dim_num, void* data, size_t element_size, data_type_t dtype) { + // initialize strides for compact/dense tensor + m_element_size = element_size; + m_dtype = dtype; + m_rank = dim_num; + assert(m_rank <= TENSOR_RANK_MAX); + size_t stride = element_size; + for (int i = m_rank - 1; i >= 0; i--) { + m_dims[i] = new_dims[i]; + m_strides[i] = stride; + stride *= new_dims[i]; + } + + if (!data) { + auto capacity_new = m_strides[0] * m_dims[0]; + if (capacity_new > m_capacity) { + m_ptr = aligned_alloc(64, capacity_new); + m_capacity = capacity_new; + } + } else { + // m_capacity is zero to indicate that we don't own the memory + m_capacity = 0; + m_ptr = reinterpret_cast(data); + } +} + +void tensor::assert_dims(const std::initializer_list& expect_dims) const { + if (m_rank != expect_dims.size()) { + DEBUG_LOG << "dims not same\n"; + } + if (!std::equal(expect_dims.begin(), expect_dims.end(), m_dims)) { + DEBUG_LOG << " m_dims=["; + for (size_t i = 0; i < m_rank; i++) + DEBUG_LOG << m_dims[i] << ","; + DEBUG_LOG << "] expect_dims=["; + for (auto& i : expect_dims) + DEBUG_LOG << i << ","; + DEBUG_LOG << "]"; + } +} + +} // namespace llmdnn diff --git a/src/common/tensor2d.hpp b/src/common/tensor2d.hpp new file mode 100644 index 0000000..8e1756d --- /dev/null +++ b/src/common/tensor2d.hpp @@ -0,0 +1,191 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include "memory_alloc.hpp" +#include "log.hpp" +#include "bf16.hpp" + +#define rndup(x, n) ((((x) + (n) - 1) / (n)) * (n)) + +template +struct tensor2D { + int dims[2] = {0}; + T* data = nullptr; + int64_t capacity = 0; + int stride = 0; + bool force_compact = false; + bool own = false; + bool use_numa_alloc = false; + int padded_dim1 = 0; + + tensor2D() = default; + tensor2D(const tensor2D&) = delete; + ~tensor2D() { + if (own && data) llmdnn_free(data, capacity, use_numa_alloc); + } + + operator bool() { + return dims[0] * dims[1] > 0; + } + + tensor2D(int d0, int d1, bool _force_compact = false) { + capacity = 0; + resize(d0, d1, _force_compact); + } + + tensor2D(int d0, int d1, T * ext, int _stride) { + capacity = 1; + data = ext; + own = false; + dims[0] = d0; + dims[1] = d1; + stride = _stride; + padded_dim1 = stride / sizeof(T); + } + + tensor2D Tr(bool _force_compact = false) { + tensor2D ret(dims[1], dims[0], _force_compact); + for(int c0=0; c0 < dims[0]; ++c0) { + for(int c1=0; c1 < dims[1]; ++c1) { + ret(c1, c0) = (*this)(c0, c1); + } + } + return ret; + } + tensor2D clone() { + tensor2D ret; + ret.resize(dims[0], dims[1], force_compact); + if (ret.stride == stride) { + memcpy(ret.data, data, dims[0] * stride); + }else{ + for(int i=0;i= dims[0] && dim1 >= dims[1]); + for(int i = 0; i < dims[0]; i++) { + memcpy(&ret(i, 0), &(*this)(i, 0), dims[1] * sizeof(T)); + memset(reinterpret_cast(&ret(i, 0) + dims[1]), 0, ret.stride - dims[1] * sizeof(T)); + } + if (dims[1] == dim1) { + memset(reinterpret_cast(ret.data + dims[0] * ret.padded_dim1), 0, (dim0 - dims[0]) * ret.stride); + } + + return ret; + } + void copyto_with_padzero(tensor2D& dst, int dim0, int dim1) { + dst.resize(dim0, dim1, force_compact); + assert(dim0 >= dims[0] && dim1 >= dims[1]); + for(int i = 0; i < dims[0]; i++) { + memcpy(&dst(i, 0), &(*this)(i, 0), dims[1] * sizeof(T)); + memset(reinterpret_cast(&dst(i, 0) + dims[1]), 0, dst.stride - dims[1] * sizeof(T)); + } + if (dims[1] == dim1) { + memset(reinterpret_cast(dst.data + dims[0] * dst.padded_dim1), 0, (dim0 - dims[0]) * dst.stride); + } + } + void resize(int d0, int d1, bool _force_compact = false, bool is_const = false) { + force_compact = _force_compact; + dims[0] = d0; + dims[1] = d1; + stride = d1 * sizeof(T); + if ((stride % 64) && (!force_compact)) { + auto stride_fix = rndup(stride, 64); + stride = stride_fix; + } + padded_dim1 = stride / sizeof(T); + + // resize method never shrink capacity, and extra T is added to put nan as test + auto need_capacity = dims[0] * stride + 4096; + if (capacity < need_capacity) { + own = true; + if (!is_const) + need_capacity *= 2; + // align begin address to cache line is vital, so tile load can + // use all bandwidth (L1D/L2 only deliver data in unit of 64-byte aligned cache-line) + if (data) llmdnn_free(data, capacity, use_numa_alloc); + use_numa_alloc = is_const; + data = reinterpret_cast(llmdnn_alloc(64, need_capacity, use_numa_alloc)); + capacity = need_capacity; + if (is_const) + memset(static_cast(data), 0, need_capacity); + if (reinterpret_cast(data) % 64) + DEBUG_LOG << "WARNING: resize(), data is not cache-line aligned!" << std::endl; + } + // put a NaN at the end to test over-read + // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format + // #define INF 0xff80 + // #define NAN1 (INF + 1) + // if (sizeof(T) == 2) { + // *reinterpret_cast(data.get() + dims[0] * padded_dim1) = NAN1; + // } + } + + T & operator[](int i) { + return data[i]; + } + + const T & operator[](int i) const { + return data[i]; + } + + //https://stackoverflow.com/questions/1936399/c-array-operator-with-multiple-arguments + T & operator()(int i0, int i1) { + return (*this)[i0 * padded_dim1 + i1]; + } + + const T & operator()(int i0, int i1) const { + return (*this)[i0 * padded_dim1 + i1]; + } + + + void operator=(const T & v) { + for(int k = 0; k < dims[0] * padded_dim1; k++) + (*this)[k] = v; + } + + tensor2D& operator=(const tensor2D& t2) = delete; + + // move semantics + tensor2D(tensor2D && t2) { + dims[0] = t2.dims[0]; + dims[1] = t2.dims[1]; + if (own && data) llmdnn_free(data, capacity, use_numa_alloc); + data = t2.data; + own = t2.own; + capacity = t2.capacity; + stride = t2.stride; + padded_dim1 = t2.padded_dim1; + force_compact = t2.force_compact; + t2.capacity = 0; + t2.data = nullptr; + } + + tensor2D& operator=(tensor2D && t2) { + dims[0] = t2.dims[0]; + dims[1] = t2.dims[1]; + if (own && data) llmdnn_free(data, capacity, use_numa_alloc); + own = t2.own; + data = t2.data; + capacity = t2.capacity; + stride = t2.stride; + padded_dim1 = t2.padded_dim1; + force_compact = t2.force_compact; + t2.capacity = 0; + t2.data = nullptr; + return *this; + } +}; diff --git a/src/common/tensor2d_helper.hpp b/src/common/tensor2d_helper.hpp new file mode 100644 index 0000000..5221280 --- /dev/null +++ b/src/common/tensor2d_helper.hpp @@ -0,0 +1,199 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" +#include "tensor2d.hpp" +#include "bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#endif +#include + +// https://stackoverflow.com/questions/570669/checking-if-a-double-or-float-is-nan-in-c/57770634#57770634 +static inline uint32_t load_ieee754_rep(float a) { + uint32_t r; + static_assert(sizeof r == sizeof a, "Unexpected sizes."); + std::memcpy(&r, &a, sizeof a); // Generates movd instruction. + return r; +} +constexpr uint32_t inf_float_shl1 = UINT32_C(0xff000000); +// The shift left removes the sign bit. The exponent moves into the topmost bits, +// so that plain unsigned comparison is enough. +static inline bool isnan2(float a) { return load_ieee754_rep(a) << 1 > inf_float_shl1; } +static inline bool isinf2(float a) { return load_ieee754_rep(a) << 1 == inf_float_shl1; } +static inline bool isfinite2(float a) { return load_ieee754_rep(a) << 1 < inf_float_shl1; } + +template +void fill_rnd(tensor2D& t) { + auto * p = t.data; + int i = 0; + int total = t.dims[0] * t.padded_dim1; + // +1 -1 for integer types + // 0.5 -0.5 for float point + float scale = std::is_integral::value ? 2:1; + for(i = 0; i + 8 <= total; i+=8) { + // lower mantissa can help to avoid small errors in accuracy comparison + auto num = rand() & 0xFF; + p[i] = scale*((num & 1) - 0.5f); num>>=1; + p[i+1] = scale*((num & 1) - 0.5f); num>>=1; + p[i+2] = scale*((num & 1) - 0.5f); num>>=1; + p[i+3] = scale*((num & 1) - 0.5f); num>>=1; + p[i+4] = scale*((num & 1) - 0.5f); num>>=1; + p[i+5] = scale*((num & 1) - 0.5f); num>>=1; + p[i+6] = scale*((num & 1) - 0.5f); num>>=1; + p[i+7] = scale*((num & 1) - 0.5f); num>>=1; + } + for(; i +bool operator==(const tensor2D& lhs, const tensor2D& rhs) { + if (lhs.dims[0] != rhs.dims[0] || lhs.dims[1] != rhs.dims[1]) + return false; + for(int i0 = 0; i0 < lhs.dims[0]; i0++) + for(int i1 = 0; i1 < lhs.dims[1]; i1++) { + // with -ffast-math, std::isnan, std::isinf, x != x always return false + // so we need special logic to test nan here + if (std::is_same::value || + std::is_same::value) { + float f0 = lhs(i0,i1); + float f1 = rhs(i0,i1); + if (isnan2(f1) || isnan2(f0)) { + DEBUG_LOG << " nan is found: f0=" << f0 << ", f1=" << f1 << std::endl; + return false; + } + if (std::abs(f0 - f1) <= 0.01) + continue; + } + + if (lhs(i0,i1) == rhs(i0,i1)) + continue; + DEBUG_LOG << " operator== failed at (" << i0 << ", " << i1 << ") value " + << lhs(i0,i1) << "!=" << rhs(i0,i1) << std::endl; + return false; + } + return true; +} + +template +bool is_normal(const tensor2D& t) { + for (int i0 = 0; i0 < t.dims[0]; i0++) + for (int i1 = 0; i1 < t.dims[1]; i1++) { + float f0 = t(i0,i1); + if (isnan2(f0)) { + DEBUG_LOG << " found nan at (" << i0 << "," << i1 << ")" << std::endl; + return false; + } + if (isinf2(f0)) { + DEBUG_LOG << " found inf at (" << i0 << "," << i1 << ")" << std::endl; + return false; + } + } + return true; +} + +template +bool compare(const tensor2D& lhs, const tensor2D& rhs, float tolerance) { + float max_abs_diff = 0; + float max_rel_diff = 0; + if (lhs.dims[0] != rhs.dims[0] || lhs.dims[1] != rhs.dims[1]) + return false; + for (int i0 = 0; i0 < lhs.dims[0]; i0++) + for (int i1 = 0; i1 < lhs.dims[1]; i1++) { + float f0 = lhs(i0, i1); + float f1 = rhs(i0, i1); + auto diff = std::fabs(f0 - f1); + auto rel_diff = diff / std::fabs(f0); + max_abs_diff = std::max(max_abs_diff, diff); + if (std::fabs(lhs(i0,i1) > 0) && diff > 0) + max_rel_diff = std::max(max_rel_diff, rel_diff); + } + DEBUG_LOG << "max_abs_diff=" << max_abs_diff << " max_rel_diff=" << max_rel_diff << "\n"; + return tolerance > max_abs_diff; +} + +template +std::ostream& operator<<(std::ostream& out, const tensor2D& obj) { + int i0; + auto showline = [&](int i) { + out << "[" << i << "," << 0 << "]: "; + int i1; + for(i1=0; i1 +inline void show(const T * data, int rows, int cols) { + DEBUG_LOG << "==============\n"; + for(int i0=0; i0 < rows; i0++) { + DEBUG_LOG << "[" << i0 << "," << 0 << "]: "; + for(int i1=0; i1 +inline void vshow(__m512i v) { + T values[512/8/sizeof(T)]; + _mm512_storeu_si512(values, v); + show(values, 1, 512/8/sizeof(T)); +} + +template +inline void vshow(__m512 v) { + T values[512/8/sizeof(T)]; + _mm512_storeu_ps(values, v); + show(values, 1, 512/8/sizeof(T)); +} + +template +inline void tshow() { + if (std::is_same::value) { + ov::bfloat16 data[16*32]; + _tile_stored(tile, data, 64); + show(data, 16, 32); + } + if (std::is_same::value) { + float data[16*16]; + _tile_stored(tile, data, 64); + show(data, 16, 16); + } + if (std::is_same::value) { + int8_t data[16*64]; + _tile_stored(tile, data, 64); + show(data, 16, 64); + } + if (std::is_same::value) { + uint8_t data[16*64]; + _tile_stored(tile, data, 64); + show(data, 16, 64); + } +} diff --git a/src/common/utility.hpp b/src/common/utility.hpp new file mode 100644 index 0000000..7012629 --- /dev/null +++ b/src/common/utility.hpp @@ -0,0 +1,66 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" + +#ifndef OV_DECL_ALIGNED +# ifdef __GNUC__ +# define OV_DECL_ALIGNED(x) __attribute__ ((aligned (x))) +# elif defined _MSC_VER +# define OV_DECL_ALIGNED(x) __declspec(align(x)) +# else +# define OV_DECL_ALIGNED(x) +# endif +#endif // OV_DECL_ALIGNED + +namespace llmdnn { + +inline size_t get_precision_size(data_type_t type) { + switch(type) { + case llmdnn_f16: + case llmdnn_bf16: + return 2; + case llmdnn_f32: + case llmdnn_s32: + return 4; + case llmdnn_s8: + case llmdnn_u8: + return 1; + case llmdnn_f64: + return 8; + default: + assert(false && "unknown data type"); + return 0; + } +} + +inline data_type_t get_dt_from_str(const std::string& name) { + static std::pair name2type[] = { + { "f16", llmdnn_f16 }, + { "bf16", llmdnn_bf16 }, + { "f32", llmdnn_f32 }, + { "s32", llmdnn_s32 }, + { "i32", llmdnn_s32 }, + { "s8", llmdnn_s8 }, + { "i8", llmdnn_s8 }, + { "u8", llmdnn_u8 }, + { "f64", llmdnn_f64 }, + }; + for (size_t i = 0; i < sizeof(name2type) / sizeof(name2type[0]); i++) { + if (name == name2type[i].first) + return name2type[i].second; + } + + return llmdnn_data_type_undef; +} + +} // namespace llmdnn diff --git a/src/emb_gpt_api.cpp b/src/emb_gpt_api.cpp new file mode 100644 index 0000000..792ff8c --- /dev/null +++ b/src/emb_gpt_api.cpp @@ -0,0 +1,26 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "emb_gpt_avx512.hpp" + +namespace llmdnn { + +status_t emb_gpt(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids) { + return emb_gpt_avx512(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); +} + +} // namespace llmdnn diff --git a/src/emb_gpt_avx512.cpp b/src/emb_gpt_avx512.cpp new file mode 100644 index 0000000..1338494 --- /dev/null +++ b/src/emb_gpt_avx512.cpp @@ -0,0 +1,191 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "common/log.hpp" +#include "common/bf16.hpp" +#include "common/simple_parallel.hpp" +#include "common/utility.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" +#include "transpose_kernel_avx512.hpp" +#include "llm_emb_gpt.hpp" +#include "emb_gpt_avx512.hpp" +#include "rotary_kernel_avx512.hpp" + +namespace llmdnn { + +static void memcpy_past_kv(const tensor& k_past, const tensor& v_past, const tensor& k_dst, const tensor& v_dst) { + auto batch = k_past.m_dims[0]; + auto head_num = k_past.m_dims[1]; + auto past_seq_len = k_past.m_dims[2]; + parallel_for3d(batch, head_num, past_seq_len, [&](size_t b, size_t h, size_t s) { + memcpy(&k_dst.at({b, h, s}), &k_past.at({b, h, s}), k_past.m_strides[2]); + memcpy(&v_dst.at({b, h, s}), &v_past.at({b, h, s}), v_past.m_strides[2]); + }); +} + +// q_src shape: [batch, q_seq_len, head_hum, head_size] +// q_dst shape: [batch, head_hum, q_seq_len, head_size] +// kv_src shape: [batch, q_seq_len, head_hum, head_size] +// kv_past shape: [batch, head_hum, past_seq_len, head_size] +// kv_dst shape: [batch, head_hum, q_seq_len+past_seq_len, head_size] +// position2d_ids: [batch, 2, q_seq_len] +// cos/sin: [max_seq_len, rotary_dims] +static void rotary_emb_position2d(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids) { + auto batch = k_past.m_dims[0]; + auto head_num = k_past.m_dims[1]; + auto past_seq_len = k_past.m_dims[2]; + auto head_size = k_past.m_dims[3]; + auto query_seq_len = q_src.m_dims[1]; + auto rotary_ndim = cos.m_dims[3]; + + parallel_for3d(batch, head_num, query_seq_len, [&](size_t b, size_t h, size_t s) { + auto kv_dst_s = s + past_seq_len; + // q, k rotary encoding + if (position2d_ids) { + auto pos = position2d_ids.at({b, 0, s}); + rotary_avx512(rotary_ndim, &cos.at({0, 0, pos}), &sin.at({0, 0, pos}), + &q_src.at({b, s, h}), + &k_src.at({b, s, h}), + &q_dst.at({b, h, s}), + &k_dst.at({b, h, kv_dst_s})); + pos = position2d_ids.at({b, 1, s}); + rotary_avx512(rotary_ndim, &cos.at({0, 0, pos}), &sin.at({0, 0, pos}), + &q_src.at({b, s, h, rotary_ndim}), + &k_src.at({b, s, h, rotary_ndim}), + &q_dst.at({b, h, s, rotary_ndim}), + &k_dst.at({b, h, kv_dst_s, rotary_ndim})); + } else { + rotary_avx512(rotary_ndim, &cos.at({0, 0, s + past_seq_len}), &sin.at({0, 0, s + past_seq_len}), + &q_src.at({b, s, h}), + &k_src.at({b, s, h}), + &q_dst.at({b, h, s}), + &k_dst.at({b, h, kv_dst_s})); + memcpy(&q_dst.at({b, h, s, rotary_ndim}), &q_src.at({b, s, h, rotary_ndim}), (head_size - rotary_ndim) * sizeof(ov::bfloat16)); + memcpy(&k_dst.at({b, h, kv_dst_s, rotary_ndim}), &k_src.at({b, s, h, rotary_ndim}), (head_size - rotary_ndim) * sizeof(ov::bfloat16)); + } + + // v concat + memcpy(&v_dst.at({b, h, kv_dst_s}), &v_src.at({b, s, h}), head_size * sizeof(ov::bfloat16)); + }); +} + + +// q_src shape: [batch, q_seq_len, num_kv_heads, head_num/num_kv_heads, head_size] +// q_dst shape: [batch, head_hum, q_seq_len, head_size] +// kv_src shape: [batch, q_seq_len, num_kv_heads, 1, head_size] +// kv_past shape: [batch, head_hum, past_seq_len, head_size] +// kv_dst shape: [batch, head_hum, q_seq_len+past_seq_len, head_size] +// position2d_ids: [batch, 2, q_seq_len] +// cos/sin: [max_seq_len, rotary_dims] +static void rotary_emb_falcon(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin) { + auto batch = k_past.m_dims[0]; + auto head_num = k_past.m_dims[1]; + auto past_seq_len = k_past.m_dims[2]; + auto head_size = k_past.m_dims[3]; + auto query_seq_len = q_src.m_dims[1]; + auto rotary_ndim = cos.m_dims[3]; + auto num_kv_heads_in_group = q_src.m_dims[3]; + + parallel_for3d(batch, head_num, query_seq_len, [&](size_t b, size_t h, size_t s) { + auto kv_dst_s = s + past_seq_len; + auto cur_num_kv_heads = h / num_kv_heads_in_group; + auto cur_sub_head_num = h % num_kv_heads_in_group; + + // q, k rotary encoding + rotary_avx512(rotary_ndim, &cos.at({0, 0, s + past_seq_len}), &sin.at({0, 0, s + past_seq_len}), + &q_src.at({b, s, cur_num_kv_heads, cur_sub_head_num}), + &k_src.at({b, s, cur_num_kv_heads, 0}), + &q_dst.at({b, h, s}), + &k_dst.at({b, h, kv_dst_s})); + if (head_size > rotary_ndim) { + memcpy(&q_dst.at({b, h, s, rotary_ndim}), &q_src.at({b, s, h, rotary_ndim}), (head_size - rotary_ndim) * sizeof(ov::bfloat16)); + memcpy(&k_dst.at({b, h, kv_dst_s, rotary_ndim}), &k_src.at({b, s, h, rotary_ndim}), (head_size - rotary_ndim) * sizeof(ov::bfloat16)); + } + + // v concat + memcpy(&v_dst.at({b, h, kv_dst_s}), &v_src.at({b, s, cur_num_kv_heads, 0}), head_size * sizeof(ov::bfloat16)); + }); +} + +status_t emb_gpt_avx512(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids) { + if ((q_src.m_rank != 4 && q_src.m_rank != 5) || (k_src.m_rank != 4 && k_src.m_rank != 5) || (v_src.m_rank != 4 && v_src.m_rank != 5) || + k_past.m_rank != 4 || v_past.m_rank != 4 || q_dst.m_rank != 4 || + k_dst.m_rank != 4 || v_dst.m_rank != 4 || cos.m_rank != 4 || sin.m_rank != 4) { + DEBUG_LOG << "emb_gpt_avx512: rank is not correct: should be 4/5\n"; + return status_t::status_invalid_arguments; + } + if (position2d_ids) { + if (position2d_ids.m_rank != 3) { + DEBUG_LOG << "emb_gpt_avx512: position2d_ids rank should be 3\n"; + return status_t::status_invalid_arguments; + } + if (position2d_ids.m_dims[0] != q_src.m_dims[0] || position2d_ids.m_dims[1] != 2 || position2d_ids.m_dims[2] != q_src.m_dims[1]) { + DEBUG_LOG << "emb_gpt_avx512: position2d_ids dims should be [batch, 2, seq_len]\n"; + return status_t::status_invalid_arguments; + } + } + + // [batch, seq_len, (num_heads * 3 * head_size)] + // --> [batch, seq_len, num_heads, 3 * head_size] + auto past_seq_len = k_past.m_dims[2]; + + // past kv src != dst, copy src to dst first + if (k_past.m_ptr != k_dst.m_ptr && past_seq_len) + memcpy_past_kv(k_past, v_past, k_dst, v_dst); + + // transpose + rotary embbeding: + // transpose: [batch, seq_len, head_hum, 3 * head_size] --> + // 3 [batch, head_hum, seq_len, head_size] + // rotary embbeding: part of key will write to past_key, part of query will write to tempory buffer + if (q_src.m_dtype == llmdnn_s8) { + assert(false); + } else { + // query pass part(temp buffer): query = torch.cat((query, query_pass), dim=-1) + // key pass part(past_key): key = torch.cat((key, key_pass), dim=-1) + // value(pastKeys): value = torch.cat((past_value, value), dim=-2) + if (q_src.m_rank == 4) + rotary_emb_position2d(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin, position2d_ids); + else + rotary_emb_falcon(q_src, k_src, v_src, k_past, v_past, q_dst, k_dst, v_dst, cos, sin); + } + + return status_t::status_ok; +} + +} // namespace llmdnn diff --git a/src/emb_gpt_avx512.hpp b/src/emb_gpt_avx512.hpp new file mode 100644 index 0000000..c5bd134 --- /dev/null +++ b/src/emb_gpt_avx512.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_emb_gpt.hpp" + +namespace llmdnn { + +status_t emb_gpt_avx512(const tensor& q_src, + const tensor& k_src, + const tensor& v_src, + const tensor& k_past, + const tensor& v_past, + const tensor& q_dst, + const tensor& k_dst, + const tensor& v_dst, + const tensor& cos, + const tensor& sin, + const tensor& position2d_ids); +} // namespace llmdnn diff --git a/src/fc_amx.cpp b/src/fc_amx.cpp new file mode 100644 index 0000000..8255851 --- /dev/null +++ b/src/fc_amx.cpp @@ -0,0 +1,333 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "common/log.hpp" +#include "common/simple_parallel.hpp" +#include "common/tensor2d.hpp" +#include "common/utility.hpp" +#include "common/compatible.hpp" +#include "common/memory_alloc.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" +#include "mm_kernel_common_amx.hpp" +#include "softmax_kernel_avx512.hpp" +#include "transpose_kernel_avx512.hpp" +#include "llm_fc.hpp" +#include "fc_amx.hpp" + +namespace llmdnn { + +struct fc_impl_amx : public fc::impl { + fc_impl_amx() = default; + ~fc_impl_amx(); + + bool init(const fc_create_param& param) override; + void pack_weight(const tensor& w) override; + status_t exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) override; + void associate_thread_numa(const llm_vector& numa_nodes); + void init_m_block(); + void init_weight_compress_param(); + + fc_create_param _create_param; + llm_vector _kernel; // one kernel for each numa node + llm_vector _weights; // one weight for each numa node + llm_vector _weight_sizes; // one weight size for each numa node + llm_vector _numa_nodes; // numa nodes + size_t _thread_nums; // thread numbers + size_t _N_in_one_numa; // N on each numa node + size_t _m_block_num_idea; // idea m block number for best thread balance + size_t _n_block_num_idea; // idea n block number for best thread balance + llm_vector _thread_nums_in_one_numa; // thread numbers in one numa node + size_t _K_align; + struct work_info { + int numa_id = 0; // numa node id, use to index in _weights + size_t thread_no_in_one_numa = 0; // sequence no in one numa node + }; + llm_vector _thread_infos; // map thread id to numa node id and thread no in one numa node + tensor2D _descale; + tensor2D _zp; +}; + +fc_impl_amx::~fc_impl_amx() { + for (size_t i = 0; i < _kernel.size(); i++) { + if (_kernel[i]) + fc_kernel_destroy(_kernel[i]); + } + for (size_t i = 0; i < _weight_sizes.size(); i++) { + llmdnn_free_on(_weights[i], _weight_sizes[i]); + } +} + +void fc_impl_amx::init_weight_compress_param() { + fc_create_param& param = _create_param; + if (param.scale) { + auto size = rndup(param.scale_zp_size, 64 / sizeof(float)); + _descale.resize(1, size, false, false); + memcpy(_descale.data, param.scale, param.scale_zp_size * sizeof(float)); + memset(_descale.data + param.scale_zp_size, 0, (size - param.scale_zp_size) * sizeof(float)); + auto zp_size = rndup(param.scale_zp_size * 2, 64 / sizeof(float)); + _zp.resize(1, zp_size, false, false); + if (param.zp) { + for (int i = 0; i < param.scale_zp_size; i++) { + _zp(0, 2 * i) = param.zp[i]; + _zp(0, 2 * i + 1) = param.zp[i]; + } + memset(_zp.data + param.scale_zp_size * 2, 0, (zp_size - param.scale_zp_size * 2) * sizeof(float)); + } else { + memset(_zp.data, 0, zp_size * sizeof(float)); + } + param.scale = _descale.data; + param.zp = _zp.data; + } +} + +bool fc_impl_amx::init(const fc_create_param& param) { + _create_param = param; + _thread_nums = get_total_threads(); + _kernel.resize(_thread_nums, nullptr); + init_weight_compress_param(); + bool ret = true; + for (size_t i = 0; i < _thread_nums; i++) { + if (fc_kernel_create(&_kernel[i], &_create_param) != llmdnn::status_ok) { + ret = false; + break; + } + } + if (ret) { + _numa_nodes = llmdnn_get_numa_nodes(); + associate_thread_numa(_numa_nodes); + } + + return ret; +} + +void fc_impl_amx::init_m_block() { + size_t div = 2; + auto work_amount = _N_in_one_numa / 32; + size_t work = work_amount; + // worse case: M block number is _thread_num + size_t threads = std::max(1ul, _thread_nums / _numa_nodes.size()); + while (div <= work_amount) { + // if work and threads can be divided by div, M block number can be dived by div + if (work % div == 0 && threads % div == 0) { + threads /= div; + work /= div; + } else { + div++; + } + if (work < div) + break; + } + _m_block_num_idea = threads; + _n_block_num_idea = _thread_nums / _numa_nodes.size() / threads; +} + +void fc_impl_amx::associate_thread_numa(const llm_vector& numa_nodes) { + _thread_infos.resize(_thread_nums); + struct int_atomic { + std::atomic_int v{0}; + }; + llm_vector thread_id_in_one_numa(numa_nodes.size()); + // the real numa id may not be continuous, but we need a number to index _numa_nodes + parallel_for(_thread_nums, [&] (size_t id) { + auto cur_numa_id = llmdnn_get_numa_id_for_cur_task(); + for (int i = 0; i < static_cast(numa_nodes.size()); i++) { + if (numa_nodes[i] == cur_numa_id) { + _thread_infos[id].numa_id = i; + _thread_infos[id].thread_no_in_one_numa = thread_id_in_one_numa[i].v.fetch_add(1); + break; + } + } + }); + + // check: the index is stable in another loop + std::mutex m; + parallel_for(_thread_nums, [&] (size_t id) { + auto cur_numa_id = llmdnn_get_numa_id_for_cur_task(); + for (int i = 0; i < static_cast(numa_nodes.size()); i++) { + if (numa_nodes[i] == cur_numa_id) { + if (_thread_infos[id].numa_id != i) { + std::lock_guard l(m); + DEBUG_LOG << "index test warning: cur numa index of thread no " << id << " is " << i << ", prev index " << _thread_infos[id].numa_id << "\n"; + } + break; + } + } + }); + + // check: each numa should have same thread numbers + _thread_nums_in_one_numa.resize(numa_nodes.size()); + int actual_threads = thread_id_in_one_numa[0].v; + _thread_nums_in_one_numa[0] = thread_id_in_one_numa[0].v; + bool zero_threads_in_one_numa = _thread_nums_in_one_numa[0] == 0; + for (size_t i = 1; i < thread_id_in_one_numa.size(); i++) { + if (thread_id_in_one_numa[0].v != thread_id_in_one_numa[i].v) { + DEBUG_LOG << "numa test warning: thread number of numa " << i << " is " << thread_id_in_one_numa[i].v << ", not equal to numa 0 thread numbers: " << thread_id_in_one_numa[0].v << "\n"; + } + actual_threads += thread_id_in_one_numa[i].v; + _thread_nums_in_one_numa[i] = thread_id_in_one_numa[i].v; + zero_threads_in_one_numa |= _thread_nums_in_one_numa[i] == 0; + } + if (zero_threads_in_one_numa) { + // no threads in one numa, the result will be wrong + DEBUG_LOG << "zero threads warning: there is no threads in some numa. Will assign threads statically.\n"; + } + + // check: actual threads number should equal to _thread_nums + if (static_cast(_thread_nums) != actual_threads) { + DEBUG_LOG << "thread number test warning: actual threads number: " << actual_threads << ", not equal to _thread_nums " << _thread_nums << "\n"; + } + + // fix thread numbers in one numa to get correct result regardless of performance + if (zero_threads_in_one_numa || static_cast(_thread_nums) != actual_threads) { + auto thread_num_in_one_numa = (_thread_nums + numa_nodes.size() - 1) / numa_nodes.size(); + for (size_t i = 0; i < numa_nodes.size(); i++) { + _thread_nums_in_one_numa[i] = std::min(thread_num_in_one_numa, _thread_nums - i * thread_num_in_one_numa); + } + for (int i = 0; i < static_cast(_thread_infos.size()); i++) { + _thread_infos[i].numa_id = i / thread_num_in_one_numa; + _thread_infos[i].thread_no_in_one_numa = i % thread_num_in_one_numa; + } + } +} + +void fc_impl_amx::pack_weight(const tensor& w) { + auto N = w.m_dims[_create_param.b_is_trans ? 0 : 1]; + auto K = w.m_dims[_create_param.b_is_trans ? 1 : 0]; + // will allocate memory on different numa nodes: + // 1, get numa nodes number, allocate memory on each numa node + // 2, get cores number, compute each cores area and pack each area simultaneously + auto numa_nodes_nums = _numa_nodes.size(); + auto N_blocks = rndup(N, 32) / 32; + // NOTE: assuming memory/thread is evenly distributed across mutiple numas. Need to support unbalanced numa? + _N_in_one_numa = (N_blocks + numa_nodes_nums - 1) / numa_nodes_nums * 32; + if (_create_param.dt_a == data_type_t::llmdnn_bf16) { + _K_align = rndup(K, 32); + } else { + _K_align = rndup(K, 64); + } + _weights.resize(numa_nodes_nums); + _weight_sizes.resize(numa_nodes_nums); + // allocate memory + for (size_t i = 0; i < numa_nodes_nums; i++) { + auto size = _K_align * _N_in_one_numa * get_precision_size(_create_param.dt_b); + _weights[i] = reinterpret_cast(llmdnn_alloc_on(64, size + 4096, _numa_nodes[i])); + _weight_sizes[i] = size + 4096; + memset(_weights[i] + size, 0, 4096); + } + auto work_amount_in_one_numa = _N_in_one_numa / 32; + parallel_for(_thread_nums, [&] (size_t id) { + auto numa_id = _thread_infos[id].numa_id; + auto thread_no_in_one_numa = _thread_infos[id].thread_no_in_one_numa; + size_t start, end; + splitter(work_amount_in_one_numa, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); + size_t n0_in_one_numa = start * 32; + size_t n1_in_one_numa = std::min(end * 32, _N_in_one_numa); + if (n0_in_one_numa >= _N_in_one_numa) return; + auto n0 = n0_in_one_numa + _N_in_one_numa * numa_id; + auto n1 = n1_in_one_numa + _N_in_one_numa * numa_id; + n1 = std::min(n1, N); + if (n0 >= n1) return; + + auto dst = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); + fc_kernel_pack_weight_to_dst(_kernel[id], w.data(), dst, w.m_dtype, N, K, w.stride(0), n0, n1); + }); + init_m_block(); +} + +status_t fc_impl_amx::exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) { + if (input.m_rank != 2 || output.m_rank != 2 || bias.m_rank != 2) { + DEBUG_LOG << "input,output,bias rank should be 2.\n"; + return status_t::status_invalid_arguments; + } + + auto M = input.size(0); + auto N = output.size(1); + auto K = input.size(1); + auto work_amount_n_in_one_numa = _N_in_one_numa / 32; + if (M < 32) { + parallel_for(_thread_nums, [&](size_t id) { + auto numa_id = _thread_infos[id].numa_id; + auto thread_no_in_one_numa = _thread_infos[id].thread_no_in_one_numa; + size_t start, end; + splitter(work_amount_n_in_one_numa, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); + size_t n0_in_one_numa = start * 32; + size_t n1_in_one_numa = std::min(end * 32, _N_in_one_numa); + if (n0_in_one_numa >= _N_in_one_numa) return; + auto n0 = n0_in_one_numa + _N_in_one_numa * numa_id; + auto n1 = n1_in_one_numa + _N_in_one_numa * numa_id; + n1 = std::min(n1, N); + if (n0 >= n1) return; + + auto weight = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); + fc_kernel_execute(_kernel[id], input.data(), weight, output.data(), input.stride(0), + output.stride(0), M, N, K, n0, n1, dq.data(), q.data(), bias.data()); + }); + } else { + // row number of each block + auto m_row = rndup(M, _m_block_num_idea) / _m_block_num_idea; + auto work_amount_n_block = _n_block_num_idea; + // at least 32 rows + if (m_row < 32) { + m_row = 32; + work_amount_n_block = work_amount_n_in_one_numa; + } + auto work_amount_m = rndup(M, m_row) / m_row; + auto work_amount = work_amount_n_block * work_amount_m; + auto n_block_in_one_numa = work_amount_n_in_one_numa / work_amount_n_block; + parallel_for(_thread_nums, [&](size_t id) { + auto numa_id = _thread_infos[id].numa_id; + auto thread_no_in_one_numa = _thread_infos[id].thread_no_in_one_numa; + size_t start, end; + splitter(work_amount, static_cast(_thread_nums_in_one_numa[numa_id]), thread_no_in_one_numa, start, end); + size_t m_start{0}, n_start{0}; + if (M > N) + parallel_it_init(start, m_start, work_amount_m, n_start, work_amount_n_block); + else + parallel_it_init(start, n_start, work_amount_n_block, m_start, work_amount_m); + for (auto work = start; work < end; work++) { + size_t n0_in_one_numa = n_start * n_block_in_one_numa * 32; + size_t n1_in_one_numa = n0_in_one_numa + n_block_in_one_numa * 32; //std::min(n_end * 32, _N_in_one_numa); + //if (n0_in_one_numa >= _N_in_one_numa) return; + auto n0 = n0_in_one_numa + _N_in_one_numa * numa_id; + auto n1 = n1_in_one_numa + _N_in_one_numa * numa_id; + n1 = std::min(n1, N); + if (n0 >= n1) continue; + + size_t m0 = m_start * m_row; + size_t m1 = std::min(m0 + m_row, M); + size_t m = m1 - m0; + + auto weight = _weights[numa_id] + n0_in_one_numa * _K_align * get_precision_size(_create_param.dt_b); + fc_kernel_execute(_kernel[id], + input.data() + m0 * input.stride(0), + weight, + output.data() + m0 * output.stride(0), + input.stride(0), + output.stride(0), + m, N, K, n0, n1, + dq.data(), + q.data(), + bias.data()); + if (M > N) + parallel_it_step(m_start, work_amount_m, n_start, work_amount_n_block); + else + parallel_it_step(n_start, work_amount_n_block, m_start, work_amount_m); + } + }); + } + + return status_t::status_ok; +} + +fc::impl* new_fc_impl_amx() { + return new fc_impl_amx(); +} + +} // namespace llmdnn diff --git a/src/fc_amx.hpp b/src/fc_amx.hpp new file mode 100644 index 0000000..53b23ad --- /dev/null +++ b/src/fc_amx.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_fc.hpp" + +namespace llmdnn { + +fc::impl* new_fc_impl_amx(); + +} // namespace llmdnn diff --git a/src/fc_api.cpp b/src/fc_api.cpp new file mode 100644 index 0000000..3abdf96 --- /dev/null +++ b/src/fc_api.cpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "fc_amx.hpp" + +namespace llmdnn { + +// interface +fc::fc(): _impl(new_fc_impl_amx()) { +} + +fc::~fc() { + delete _impl; +} + +bool fc::init(const fc_create_param& param) { + return _impl->init(param); +} + +void fc::pack_weight(const tensor& w) { + return _impl->pack_weight(w); +} + +status_t fc::exec(const tensor& input, const tensor& output, const tensor& dq, const tensor& q, const tensor& bias) { + return _impl->exec(input, output, dq, q, bias); +} + +} // namespace llmdnn diff --git a/src/fc_kernel_amx.cpp b/src/fc_kernel_amx.cpp new file mode 100644 index 0000000..0292a4e --- /dev/null +++ b/src/fc_kernel_amx.cpp @@ -0,0 +1,445 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_fc.hpp" +#include "llm_types.hpp" +#include "mm_kernel_common_amx.hpp" +#include "utility_kernel_avx512.hpp" +#include "fc_kernel_amx.hpp" +#include "common/compatible.hpp" + +namespace llmdnn { + +using ov::bfloat16; +struct fc_kernel { + std::unique_ptr> bf16xbf16; + std::unique_ptr> bf16xi8; + std::unique_ptr> i8xi8; + std::unique_ptr> u8xi8; + + data_type_t dt_a; + data_type_t dt_b; + data_type_t dt_c; + size_t stride_b; + postops_types postops_type; + bool b_is_transpose; +}; + +using supported_key = std::tuple; +using supported_value = std::pair; +static bool check_valid_postops(size_t value, data_type_t dt_a, data_type_t dt_b, data_type_t dt_c) { + llm_map supported_postops = { + { { llmdnn_s8, llmdnn_s8, llmdnn_s8 }, { DEQUANT | QUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_s8, llmdnn_s8, llmdnn_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_s8, llmdnn_s8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16 }, { 0, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_bf16, llmdnn_f32 }, { 0, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_u8, llmdnn_f32 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + { { llmdnn_bf16, llmdnn_u8, llmdnn_bf16 }, { DEQUANT, BIAS | GELU | GELU_TANH } }, + }; + + auto it = supported_postops.find(std::make_tuple(dt_a, dt_b, dt_c)); + if (it == supported_postops.end()) { + return false; + } + + size_t must_have; + size_t opt_have; + must_have = (*it).second.first; + opt_have = (*it).second.second; + + if ((value & must_have) != must_have) + return false; + // value must in must_have and opt_have + if ((value & ~(must_have | opt_have)) != 0) + return false; + + return true; +} + +// interface +status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param) { + fc_kernel* m = nullptr; + if (param == nullptr || mm == nullptr) { + DEBUG_LOG << "fc_kernel_create: invalid input parameter.\n"; + goto ERR; + } + + if (!check_valid_postops(static_cast(param->postops_type), param->dt_a, param->dt_b, param->dt_c)) { + DEBUG_LOG << "fc_kernel_create: unsupported data type, a: " << param->dt_a <<", b: " << param->dt_b << ", c: " << param->dt_c << + ", postops type: " << param->postops_type << ".\n"; + goto ERR; + } + + m = new fc_kernel; + if (param->dt_a == llmdnn_s8 && param->dt_b == llmdnn_s8) { + m->i8xi8 = std::make_unique>(true, param->b_is_trans); + } else if (param->dt_a == llmdnn_u8 && param->dt_b == llmdnn_s8) { + m->u8xi8 = std::make_unique>(true, param->b_is_trans); + } else if (param->dt_a == llmdnn_bf16 && (param->dt_b == llmdnn_bf16 || param->dt_b == llmdnn_f32)) { + m->bf16xbf16 = std::make_unique>(true, param->b_is_trans); + } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_u8) { + m->bf16xi8 = std::make_unique>(true, param->b_is_trans); + m->bf16xi8->dequant_scale_B = param->scale; + m->bf16xi8->zp = param->zp; + } else { + DEBUG_LOG << "fc_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + + m->dt_a = param->dt_a; + m->dt_b = param->dt_b; + m->dt_c = param->dt_c; + m->b_is_transpose = param->b_is_trans; + m->postops_type = param->postops_type; + + *mm = m; + return status_t::status_ok; +ERR: + delete m; + return status_t::status_invalid_arguments; +} + +void fc_kernel_destroy_amx(fc_kernel* mm) { + if (mm) { + delete mm; + } +} + +void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + mm->stride_b = stride_b; + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8) { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->i8xi8->internalB, true); + } else if (mm->u8xi8) { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->u8xi8->internalB, true); + } else if (mm->bf16xbf16) { + if (dt_b == llmdnn_bf16) { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + } else { + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + } + } else { + assert(dt_b == llmdnn_u8); + tensor2D b(b_d0, b_d1, static_cast(ptr_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2_compressed(matB, mm->b_is_transpose, mm->bf16xi8->internalBI8, true); + } +} + +void fc_kernel_pack_weight_to_dst_amx(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + mm->stride_b = stride_b; + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8) { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + // do not care about the real dimension, only ensure .capacity big enough + mm->i8xi8->internalB = tensor2D(1, 1, static_cast(dst_b), 1); + mm->i8xi8->internalB.capacity = INT_MAX; + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->i8xi8->internalB, true); + } else if (mm->u8xi8) { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + mm->u8xi8->internalB = tensor2D(1, 1, static_cast(dst_b), 1); + mm->u8xi8->internalB.capacity = INT_MAX; + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->u8xi8->internalB, true); + } else if (mm->bf16xbf16) { + mm->bf16xbf16->internalB = tensor2D(1, 1, static_cast(dst_b), 1); + mm->bf16xbf16->internalB.capacity = INT_MAX; + if (dt_b == llmdnn_bf16) { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + } else { + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2(matB, mm->b_is_transpose, mm->bf16xbf16->internalB, true); + } + } else { + assert(dt_b == llmdnn_u8); + mm->bf16xi8->internalBI8 = tensor2D(1, 1, static_cast(dst_b), 1); + mm->bf16xi8->internalBI8.capacity = INT_MAX; + tensor2D b(b_d0, b_d1, static_cast(src_b), mm->stride_b); + auto matB = amx_kernel::getSubMatB(b, n_start, n_end, mm->b_is_transpose); + amx_kernel::repackB_1x2_compressed(matB, mm->b_is_transpose, mm->bf16xi8->internalBI8, true); + } +} + +void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t stride_a, size_t stride_c, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); + tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); + if (ptr_b) { + auto K_padded = rndup(K, 64); + mm->i8xi8->internalB = tensor2D(N / 32, 32 * K_padded, static_cast(ptr_b), 32 * K_padded); + } + + if (mm->dt_c == llmdnn_s8) { + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + ppkernel.set_q_scale(q); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == llmdnn_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); + if (!bias) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == llmdnn_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); + if (!bias) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + ppkernel.set_deq_scale(dq); + (*mm->i8xi8)(a, b, n_start, n_end, ppkernel); + } + } + } + } else if (mm->u8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); + tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->u8xi8)(a, b, n_start, n_end, pp); + } else if (mm->bf16xbf16) { + tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); + tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); + if (ptr_b) { + auto K_padded = rndup(K, 32); + mm->bf16xbf16->internalB = tensor2D(N / 32, 32 * K_padded, static_cast(ptr_b), 32 * K_padded * sizeof(bfloat16)); + } + if (mm->dt_c == llmdnn_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == llmdnn_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xbf16)(a, b, n_start, n_end, ppkernel); + } + } + } + } else { + tensor2D a(M, K, reinterpret_cast(ptr_a), stride_a); + tensor2D b(b_d0, b_d1, nullptr, mm->stride_b); + + if (ptr_b) { + auto K_padded = rndup(K, 32); + mm->bf16xi8->internalBI8 = tensor2D(N / 32, 32 * K_padded, static_cast(ptr_b), 32 * K_padded); + } + + if (mm->dt_c == llmdnn_bf16) { + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } + } else if (mm->dt_c == llmdnn_f32) { + tensor2D c(M, N, reinterpret_cast(ptr_c), stride_c); + if (!(mm->postops_type & BIAS)) { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } else { + if (mm->postops_type & GELU) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else if (mm->postops_type & GELU_TANH) { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } else { + amx_kernel::PP::BiasGeluStore ppkernel(c, bias); + (*mm->bf16xi8)(a, b, n_start, n_end, ppkernel); + } + } + } + } +} + +void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq) { + float min, max; + tensor2D B(K, N, reinterpret_cast(ptr), stride); + amx_kernel::functional::get_min_max(B, min, max); + max = std::max(std::abs(max), std::abs(min)); + *q = 127 / max; + *dq = max / 127; +} + +} // namespace llmdnn diff --git a/src/fc_kernel_amx.hpp b/src/fc_kernel_amx.hpp new file mode 100644 index 0000000..f293f0c --- /dev/null +++ b/src/fc_kernel_amx.hpp @@ -0,0 +1,22 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "llm_fc.hpp" + +namespace llmdnn { + +status_t fc_kernel_create_amx(fc_kernel** mm, const fc_create_param* param); + +void fc_kernel_destroy_amx(fc_kernel* mm); + +void fc_kernel_pack_weight_amx(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); + +void fc_kernel_pack_weight_to_dst_amx(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end); + +void fc_kernel_execute_amx(fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t stride_a, size_t stride_c, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias); + +void fc_kernel_bf16w8_get_q_dq_amx(size_t K, size_t N, size_t stride, void* ptr, float* q, float* dq); + +} // namespace llmdnn diff --git a/src/fc_kernel_api.cpp b/src/fc_kernel_api.cpp new file mode 100644 index 0000000..956b1a8 --- /dev/null +++ b/src/fc_kernel_api.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_fc.hpp" +#include "fc_kernel_amx.hpp" +#include "mm_kernel_common_amx.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + +static decltype(&fc_kernel_create) fc_kernel_create_ptr = fc_kernel_create_amx; +static decltype(&fc_kernel_destroy) fc_kernel_destroy_ptr = fc_kernel_destroy_amx; +static decltype(&fc_kernel_pack_weight) fc_kernel_pack_weight_ptr = fc_kernel_pack_weight_amx; +static decltype(&fc_kernel_pack_weight_to_dst) fc_kernel_pack_weight_to_dst_ptr = fc_kernel_pack_weight_to_dst_amx; +static decltype(&fc_kernel_execute) fc_kernel_execute_ptr = fc_kernel_execute_amx; + +// interface +status_t fc_kernel_create(fc_kernel** mm, const fc_create_param* param) { + return fc_kernel_create_ptr(mm, param); +} + +void fc_kernel_destroy(fc_kernel* mm) { + fc_kernel_destroy_ptr(mm); +} + +void fc_kernel_pack_weight(fc_kernel* mm, void* ptr_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + fc_kernel_pack_weight_ptr(mm, ptr_b, dt_b, N, K, stride_b, n_start, n_end); +} + +void fc_kernel_pack_weight_to_dst(fc_kernel* mm, void* src_b, void* dst_b, data_type_t dt_b, size_t N, size_t K, size_t stride_b, size_t n_start, size_t n_end) { + fc_kernel_pack_weight_to_dst_ptr(mm, src_b, dst_b, dt_b, N, K, stride_b, n_start, n_end); +} + +void fc_kernel_execute(fc_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t stride_a, size_t stride_c, + size_t M, size_t N, size_t K, size_t n_start, size_t n_end, float* dq, float* q, float* bias) { + fc_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, stride_a, stride_c, M, N, K, n_start, n_end, dq, q, bias); +} + +} // namespace llmdnn diff --git a/src/gelu_kernel_avx512.hpp b/src/gelu_kernel_avx512.hpp new file mode 100644 index 0000000..de635b1 --- /dev/null +++ b/src/gelu_kernel_avx512.hpp @@ -0,0 +1,386 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "common/utility.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + + // gelu_erf_minimax_approx_compute_vector_fwd in oneDNN + // x*0.5*(1+erf(x/sqrt(2))) = x*0.5*(1 + x*Polynomial(x^2)) + inline __m512 gelu_erf_minmax_approx_avx512(__m512 & x) { + auto x2 = _mm512_mul_ps(x, x); // x^2 + + auto x_positive = _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(x), _mm512_set1_epi32(0x7FFFFFFF))); // clear sign mask + auto x_half = _mm512_mul_ps(x, _mm512_set1_ps(0.5f)); + + auto poly = _mm512_castsi512_ps(_mm512_set1_epi32(0x1f1c83fd)); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa3198977))); // poly * x^2 + xxx + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x268a7927))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xa998c963))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x2c67ddb2))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xaf013b2c))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x315d4a4f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb3969b11))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x35a776e9))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xb79b0914))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3970b255))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbb1b7399))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3ca3621f))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0xbe082bc7))); + poly = _mm512_fmadd_ps(poly, x2, _mm512_castsi512_ps(_mm512_set1_epi32(0x3f4c4228))); + + // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2) + poly = _mm512_fmadd_ps(poly, x, _mm512_set1_ps(1.0f)); + // x*0.5*(1 + x*Polynomial(x^2)) + poly = _mm512_mul_ps(poly, x_half); + + // combine: + // zone_id + // 1 -inf; -saturation_lbound : 0.0f + // 2 -saturation_lbound; -linear_ubound : x*0.5*(1 + x*Polynomial(x^2)) + // 3 -linear_ubound, linear_ubound : x*0.5 + // 4 linear_ubound : saturation_lbound : x*0.5*(1 + x*Polynomial(x^2)) + // 5 saturation_lbound: +inf : x + constexpr int neg_saturation_lbound = 0xc0a00000; + constexpr int linear_ubound = 0x33800000; + constexpr int saturation_lbound = 0x40a00000; + + auto mask_x_not_zone1 = _mm512_cmpnlt_ps_mask(x, _mm512_castsi512_ps(_mm512_set1_epi32(neg_saturation_lbound))); + x = _mm512_maskz_mov_ps(mask_x_not_zone1, x); + + auto mask_x_in_zone5 = _mm512_cmpnle_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(saturation_lbound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone5, x); + + auto mask_x_in_zone3 = _mm512_cmple_ps_mask(x_positive, _mm512_castsi512_ps(_mm512_set1_epi32(linear_ubound))); + poly = _mm512_mask_mov_ps(poly, mask_x_in_zone3, x_half); + return poly; + } + + // gelu_tanh_compute_vector_fwd in oneDNN + inline __m512 gelu_tanh_avx512(__m512& x) { + // compute G(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x * x) + auto x2 = _mm512_mul_ps(x, x); + auto y = _mm512_fmadd_ps(x2, (__m512)_mm512_set1_epi32(0x3d372713), _mm512_set1_ps(1.0f)); + y = _mm512_mul_ps(y, x); + y = _mm512_mul_ps(y, (__m512)_mm512_set1_epi32(0x3f4c422a)); + + // compute tanh(G(x)) + // We split the positive domain in 33 intervals: + // a) [0; linear_ubound]: in this interval tanh(x) = x + // b) [linear_ubound; 0x1.8p-12]: This interval spans part of a + // half binade + // c) [0x1.8p-12; 0x1.0p-11], ..., [0x1.8p2; 0x1.0p3]: + // one interval for each half binade, there are 29 of those + // d) [0x1.0p3; saturation_ubound]: + // This interval spans part of a half binade + // e) [0x1.205966p3; saturation_ubound]: in this interval, tanh(x) = 1 + // For b-d, we need 31 polynomials and will do a table lookup for those. + // To simplify the logic, we will also put a) in the table. + + // The polynomials are of degree 6, so we need to gather 7 coefficients. + // - sse4.1: we do it the naive way using vextract/vinsert. + // Here we will extract the indices in gpr only once and + // reuse them as there are only 4 of them. + // - avx: we do the same as for sse4.1 but use half of the 64-bits + // registers to store the idx of second half of YMM and half for + // responding XMM. Halfway through the copy we exchange Xmm and + // higher half of Ymm and we get the expected result. + // - avx2: we use vpermps and blend for each coefficient. + // This needs an extra vmm to store the mask + // - avx512: because the table fits in 2 registers, we can use vpermi2d. + + // because tanh(x) = -tanh(-x), we extract sign to make x postive + // and reapply sign at the end + auto y_positive = _mm512_and_ps(y, (__m512)(_mm512_set1_epi32(0x7fffffff))); + + // We compute the indices for the table lookup + auto indices = _mm512_sub_epi32((__m512i)y_positive, _mm512_set1_epi32(0x39800000)); + + indices = _mm512_and_epi32(indices, _mm512_set1_epi32(0xffc00000)); + indices = _mm512_srli_epi32(indices, 22); + // we do the argument reduction + auto y_shift = _mm512_and_ps(y_positive, (__m512)_mm512_set1_epi32(0xffc00000)); + + y_shift = _mm512_sub_ps(y_positive, y_shift); + + static uint32_t OV_DECL_ALIGNED(64) tanh_pol_table[] = { + // coefficients of degree 0 + 0x00000000, + 0x39bfffff, + 0x39ffffff, + 0x3a3ffffe, + 0x3a7ffffb, + 0x3abffff7, + 0x3affffeb, + 0x3b3fffdc, + 0x3b7fffab, + 0x3bbfff70, + 0x3bfffeab, + 0x3c3ffdc0, + 0x3c7ffaab, + 0x3cbff701, + 0x3cffeaad, + 0x3d3fdc08, + 0x3d7faacd, + 0x3dbf7081, + 0x3dfeacc9, + 0x3e3dc7fd, + 0x3e7acbf5, + 0x3eb77a9f, + 0x3eec9a9f, + 0x3f22991f, + 0x3f42f7d6, + 0x3f67b7cc, + 0x3f76ca83, + 0x3f7ebbe9, + 0x3f7fd40c, + 0x3f7fff32, + 0x3f7ffffc, + 0x3f800000, + // coefficients of degree 1 + 0x3f800000, + 0x3f800018, + 0x3f7fffe8, + 0x3f7fffda, + 0x3f7fffdc, + 0x3f7fffdc, + 0x3f7fffac, + 0x3f7fff70, + 0x3f7ffeec, + 0x3f7ffdc0, + 0x3f7ffbed, + 0x3f7ff704, + 0x3f7feff5, + 0x3f7fdbca, + 0x3f7fbfff, + 0x3f7f7041, + 0x3f7f009b, + 0x3f7dc36c, + 0x3f7c0aa8, + 0x3f7734b8, + 0x3f70a4de, + 0x3f5f1fd8, + 0x3f495493, + 0x3f18b9ec, + 0x3ed706cb, + 0x3e390b06, + 0x3d90b11f, + 0x3c21a053, + 0x3aaf7fdb, + 0x37ccc1a3, + 0x355c6733, + 0x00000000, + // coefficients of degree 2 + 0x00000000, + 0xbe4e0ff1, + 0x3d25b1b1, + 0x3d6b6dab, + 0x3c9fb1d5, + 0xbabff06f, + 0x3c07b3f6, + 0xbb3fc1bc, + 0x3a9f5921, + 0xbbbf06f2, + 0xbbb0f402, + 0xbc47db9e, + 0xbc73d5e7, + 0xbca25bda, + 0xbcfca780, + 0xbd40e07c, + 0xbd7dab03, + 0xbdbe4a0f, + 0xbdfb14a5, + 0xbe36cc8d, + 0xbe6bd102, + 0xbe9fe7c5, + 0xbeba0f10, + 0xbec206a8, + 0xbea3c388, + 0xbe277d62, + 0xbd8b7960, + 0xbc209f49, + 0xbaad44ca, + 0xb7c6eeac, + 0xb663aa41, + 0x00000000, + // coefficients of degree 3 + 0x00000000, + 0x45b3ae96, + 0xc414eb20, + 0xc450e02e, + 0xc3152b4e, + 0xbead2f56, + 0xc2162e02, + 0xbeb4bd5a, + 0xc11a59a4, + 0xbed2f507, + 0xc020d32c, + 0x3dd0f506, + 0xbf2a75e2, + 0xbff950e3, + 0xbed47334, + 0xbe809b8c, + 0xbeb64532, + 0xbe961a5b, + 0xbe9b63ac, + 0xbea0d4b2, + 0xbe828a77, + 0xbe378612, + 0xbdc20908, + 0x3d2d3957, + 0x3dd46e89, + 0x3db3f629, + 0x3d2c5e7b, + 0x3bd20403, + 0x3a59dfae, + 0x3770af45, + 0x372cc014, + 0x00000000, + // coefficients of degree 4 + 0x00000000, + 0xcc981a1b, + 0x4a7edd3d, + 0x4ab1007c, + 0x48fedd9c, + 0x41a557b5, + 0x477ee32a, + 0x422557f5, + 0x45ff3ce4, + 0x42a55641, + 0x446e0867, + 0xc33dc19a, + 0x42915214, + 0x43af4fad, + 0x4110fe88, + 0xc1099b75, + 0x3fc8a8dc, + 0xbfbeaef5, + 0xbe365aad, + 0x3f4d9652, + 0x3ddfa08f, + 0x3e34e9b8, + 0x3e2d07a6, + 0x3dc63567, + 0x3cdaeb78, + 0xbcd17537, + 0xbc92829c, + 0xbb43ab99, + 0xb9b471dd, + 0xb6baad5a, + 0xb78bafc7, + 0x00000000, + // coefficients of degree 5 + 0x00000000, + 0x52f688d5, + 0xd0505c72, + 0xd08f98e3, + 0xce505cc9, + 0xc7162b8a, + 0xcc5061d6, + 0xc7162bdf, + 0xca50b37f, + 0xc7162a3a, + 0xc8422086, + 0x471a714e, + 0xc5ece1f1, + 0xc70e3d90, + 0xc3eba94a, + 0x43e0c424, + 0xc21f4552, + 0x42217cc8, + 0x405e7dc4, + 0xc10dd401, + 0x3e96b602, + 0xbd1a6d2f, + 0xbd393883, + 0xbd674682, + 0xbd310016, + 0xb961e269, + 0x3ba32495, + 0x3a7680d5, + 0x38b3173c, + 0x35a9deea, + 0x375c3f2a, + 0x00000000, + // coefficients of degree 6 + 0x00000000, + 0xd8995ed1, + 0x558285ea, + 0x55b2cd69, + 0x53028625, + 0x4bc9991f, + 0x5082898a, + 0x4b4999b3, + 0x4e02c07c, + 0x4ac99764, + 0x4b72c822, + 0xca40c0e1, + 0x489413e4, + 0x49b12224, + 0x46134c4e, + 0xc60c2d57, + 0x43c83910, + 0xc3c872d1, + 0xc186bc9e, + 0x42325bc3, + 0xbf2ffa4a, + 0x3d9a203c, + 0xbc545a43, + 0xbae08fee, + 0x3c80225d, + 0x3b1fd1df, + 0xba36b9d1, + 0xb91de544, + 0xb71f100f, + 0xb408e2ed, + 0xb685fec8, + 0x00000000, + }; + auto pol = _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 6), indices, _mm512_load_ps(tanh_pol_table + 32 * 6 + 16)); + + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 5), indices, _mm512_load_ps(tanh_pol_table + 32 * 5 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 4), indices, _mm512_load_ps(tanh_pol_table + 32 * 4 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 3), indices, _mm512_load_ps(tanh_pol_table + 32 * 3 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 2), indices, _mm512_load_ps(tanh_pol_table + 32 * 2 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 1), indices, _mm512_load_ps(tanh_pol_table + 32 * 1 + 16))); + pol = _mm512_fmadd_ps(pol, y_shift, _mm512_permutex2var_ps(_mm512_load_ps(tanh_pol_table + 32 * 0), indices, _mm512_load_ps(tanh_pol_table + 32 * 0 + 16))); + + // we restore src with cleared sign, and keep sign + auto sign = _mm512_and_ps(y, (__m512)_mm512_set1_epi32(0x80000000)); + + // Now we blend the results + // [saturation_ubound; +inf[ : we return +/- 1 + auto dst = (__m512)_mm512_set1_epi32(0x3f800000); + // [linear_ubound; saturation_lbound] : we return +/- P(x) + auto mask = (__m512)_mm512_set1_epi32(0x41102cb3); + auto mask16 = _mm512_cmp_ps_mask(mask, y_positive, _CMP_GT_OS); + dst = _mm512_mask_blend_ps(mask16, dst, pol); + // [0; linear_ubound] : we return x + mask = (__m512)_mm512_set1_epi32(0x39ddb3d7); + mask16 = _mm512_cmp_ps_mask(mask, y_positive, _CMP_GT_OS); + dst = _mm512_mask_blend_ps(mask16, dst, y_positive); + + // We reapply the sign and return + dst = _mm512_xor_ps(dst, sign); + + // compute 0.5 * x * (1 + tanh(G(x))) + dst = _mm512_add_ps(dst, (__m512)_mm512_set1_epi32(0x3f800000)); + dst = _mm512_mul_ps(dst, (__m512)_mm512_set1_epi32(0x3f000000)); + dst = _mm512_mul_ps(dst, x); + return dst; + } +} // namespace llmdnn diff --git a/src/mha_gpt_amx.cpp b/src/mha_gpt_amx.cpp new file mode 100644 index 0000000..e3f27fb --- /dev/null +++ b/src/mha_gpt_amx.cpp @@ -0,0 +1,311 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "common/log.hpp" +#include "common/simple_parallel.hpp" +#include "common/tensor2d.hpp" +#include "common/utility.hpp" +#include "common/compatible.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" +#include "mm_kernel_common_amx.hpp" +#include "softmax_kernel_avx512.hpp" +#include "transpose_kernel_avx512.hpp" +#include "llm_mha_gpt.hpp" +#include "mha_gpt_amx.hpp" + +namespace llmdnn { + +struct mha_gpt_impl_amx : public mha_gpt::impl { + mha_gpt_impl_amx() = default; + ~mha_gpt_impl_amx(); + void create(data_type_t in_type, size_t seq_len, size_t head_size, bool is_bloom); + status_t exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, + const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) override; + + void mha_bf16(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, + const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask); + + size_t _head_size_aligned = 0; + size_t _buffer_mat0_out_size = 0; + size_t _buffer_mat1_out_size = 0; + size_t _num_threads = 0; + + uint8_t* _buffer_mat0_out = nullptr; + uint8_t* _buffer_mat1_out = nullptr; + + bool _is_bloom; + + llm_vector*> gemAvB_BF16xBF16; + llm_vector*> qKtrGemm_BF16xBF16; + llm_vector*> qKVGemm_BF16xBF16; +}; + +mha_gpt_impl_amx::~mha_gpt_impl_amx() { + for (size_t i = 0; i < gemAvB_BF16xBF16.size(); i++) { + delete gemAvB_BF16xBF16[i]; + } + for (size_t i = 0; i < qKtrGemm_BF16xBF16.size(); i++) { + delete qKtrGemm_BF16xBF16[i]; + } + for (size_t i = 0; i < qKVGemm_BF16xBF16.size(); i++) { + delete qKVGemm_BF16xBF16[i]; + } + + if (_buffer_mat0_out) + free(_buffer_mat0_out); + if (_buffer_mat1_out) + free(_buffer_mat1_out); +} + +void mha_gpt_impl_amx::create(data_type_t in_type, size_t seq_len, size_t head_size, bool is_bloom) { + // q: [batch, head_num, query_seq_len, head_size] + // k: [batch, head_num, maxSeqLen(valid: key_seq_len), head_size] + // v: [batch, head_num, maxSeqLen(valid: value_seq_len), head_size] + // attention_mask: [batch, 1, 1, maxSeqLen(valid: key_seq_len)] + // matmul1: [batch, head_num, query_seq_len, head_size] + // attn_output: [batch, query_seq_len, head_num * head_size] + if (_num_threads == 0) { + _num_threads = get_total_threads(); + gemAvB_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + gemAvB_BF16xBF16[i] = new amx_kernel::MatmulVector(); + } + qKtrGemm_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + qKtrGemm_BF16xBF16[i] = new amx_kernel::Matmul(false, !is_bloom); + } + _is_bloom = is_bloom; + qKVGemm_BF16xBF16.resize(_num_threads); + for (size_t i = 0; i < _num_threads; i++) { + qKVGemm_BF16xBF16[i] = new amx_kernel::Matmul(false, false); + } + } + + // correct transposeB + if (_is_bloom != is_bloom) { + for (auto& mm : qKtrGemm_BF16xBF16) { + mm->transposeB = !is_bloom; + } + _is_bloom = is_bloom; + } + + auto buffer_mat0_out_size = seq_len * rndup(seq_len * sizeof(float), 64); + if (buffer_mat0_out_size > _buffer_mat0_out_size) { + _head_size_aligned = rndup(head_size, 32); + _buffer_mat0_out_size = seq_len * rndup(seq_len * sizeof(float), 64) * 3 / 2; + _buffer_mat1_out_size = seq_len * _head_size_aligned * sizeof(float) * 3 / 2; + if (_buffer_mat0_out) + free(_buffer_mat0_out); + if (_buffer_mat1_out) + free(_buffer_mat1_out); + _buffer_mat0_out = reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat0_out_size)); + _buffer_mat1_out = reinterpret_cast(aligned_alloc(64, _num_threads * _buffer_mat1_out_size)); + } +} + +void mha_gpt_impl_amx::mha_bf16(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, + const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { + auto batch = q.m_dims[0]; + auto head_num = q.m_dims[1]; + auto query_seq_len = q.m_dims[2]; + auto head_size = q.m_dims[3]; + auto key_seq_len = k.m_dims[2]; + bool is_bloom = k.m_strides[3] > k.m_strides[2]; + auto h_group_num = k.m_dims[1]; + size_t h_each_group_len = head_num / h_group_num; + + uint8_t* out = output.data(); + + auto& gemAvB_ops = gemAvB_BF16xBF16; + auto& qKtrGemm_ops = qKtrGemm_BF16xBF16; + auto& qKVGemm_ops = qKVGemm_BF16xBF16; + bool use_gemv = query_seq_len == 1 && head_size >= 32 && head_size <= 32 * 6 && !is_bloom && !alibi && attn_mask && !causal_mask; + size_t head_stride_in_attn = head_size; + size_t batch_stride_in_attn = head_size * head_num * query_seq_len; + size_t causal_mask_offset_start = use_causal_mask ? key_seq_len - query_seq_len : key_seq_len; + + if (use_gemv) { + parallel_for2d(batch, head_num, [&](size_t thread_id, size_t i0, size_t i1) { + auto q_sub = &q.at({i0, i1}); + auto k_sub = &k.at({i0, i1 / h_each_group_len}); + auto v_sub = &v.at({i0, i1 / h_each_group_len}); + + auto mat0_out = reinterpret_cast(_buffer_mat0_out + thread_id * _buffer_mat0_out_size); + auto mat1_out = reinterpret_cast(_buffer_mat1_out + thread_id * _buffer_mat1_out_size); + + tensor2D matK(key_seq_len, head_size, reinterpret_cast(k_sub), k.m_strides[2]); + // N: key_seq_len, K: head_size + // q[1, K] * transpose(k[N, K]) ==> + // k[N, K] * transpose(q[1, K]) ==> + // k[N, K] * q[K, 1] + (*gemAvB_ops[thread_id])(matK, reinterpret_cast(q_sub), reinterpret_cast(mat0_out)); + + float* pMatMul0Out = reinterpret_cast(mat0_out); + mul_add2_select_f32_avx512(pMatMul0Out, pMatMul0Out, normal_factor, nullptr, &attn_mask.at({i0}), nullptr, false, key_seq_len); + softmax_avx512(reinterpret_cast(pMatMul0Out), pMatMul0Out, key_seq_len, nullptr); + auto out_sub = out + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn) * sizeof(ov::bfloat16); + tensor2D matQK(query_seq_len, key_seq_len, reinterpret_cast(mat0_out), rndup(key_seq_len * sizeof(ov::bfloat16), 64)); + tensor2D matV(key_seq_len, head_size, reinterpret_cast(v_sub), v.m_strides[2]); + tensor2D matQKV(query_seq_len, head_size, reinterpret_cast(mat1_out), _head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp(matQKV); + (*qKVGemm_ops[thread_id])(matQK, matV, 0, head_size, pp); + memcpy2d_stride_avx512(reinterpret_cast(out_sub), reinterpret_cast(mat1_out), query_seq_len, + head_size, _head_size_aligned * sizeof(float), head_num * head_size * sizeof(ov::bfloat16), nullptr); + }); + } else { + size_t seq_count_all = rndup(query_seq_len, 32) / 32; + auto work_amount = batch * head_num * seq_count_all; + parallel_for(_num_threads, [&](size_t thread_id) { + size_t i0; + size_t i1; + size_t seq; + size_t start {0}, end {0}; + splitter(work_amount, _num_threads, thread_id, start, end); + if (start >= work_amount) return; + + parallel_it_init(start, i0, batch, i1, head_num, seq, seq_count_all); + ov::bfloat16* prev_k = nullptr; + ov::bfloat16* prev_v = nullptr; + for (size_t iwork = start; iwork < end; ++iwork) { + auto seq_start = seq * 32; + auto seq_end = std::min(seq_start + 32, query_seq_len); + auto seq_count = seq_end - seq_start; + // q: [batch, head_num, query_seq_len, head_size] + // k: [batch, head_num, key_seq_len, head_size] + // v: [batch, head_num, value_seq_len, head_size] + auto q_sub = &q.at({i0, i1, seq_start}); + auto k_sub = &k.at({i0, i1 / h_each_group_len}); + auto v_sub = &v.at({i0, i1 / h_each_group_len}); + + auto mat0_out = reinterpret_cast(_buffer_mat0_out + thread_id * _buffer_mat0_out_size); + auto mat1_out = reinterpret_cast(_buffer_mat1_out + thread_id * _buffer_mat1_out_size); + + tensor2D matQ(seq_count, head_size, q_sub, q.m_strides[2]); + tensor2D matQK(seq_count, key_seq_len, mat0_out, rndup(key_seq_len * sizeof(float), 64)); + amx_kernel::PP::BiasGeluStore pp(matQK); + if (!is_bloom) { + tensor2D matK(key_seq_len, head_size, k_sub, k.m_strides[2]); + (*qKtrGemm_ops[thread_id])(matQ, matK, 0, key_seq_len, pp, k_sub == prev_k); + } else { + tensor2D matK(head_size, key_seq_len, k_sub, k.m_strides[3]); + (*qKtrGemm_ops[thread_id])(matQ, matK, 0, key_seq_len, pp, k_sub == prev_k); + } + prev_k = k_sub; + tensor2D softmax_dst(seq_count, key_seq_len, reinterpret_cast(mat0_out), rndup(key_seq_len * sizeof(ov::bfloat16), 64)); + size_t valid_softmax_items = std::min(causal_mask_offset_start + seq_start + 1, key_seq_len); + // attn: [batch, 1, 1, key_seq_len] or [batch, 1, query_seq_len, key_seq_len] + // alibi: [batch, num_heads, 1, key_seq_len] + // causal: [batch/1, 1, query_seq_len, key_seq_len] + for (uint32_t m = 0; m < seq_count; m++) { + auto attn_sub = attn_mask ? &attn_mask.at({i0, 0, attn_mask.m_dims[2] == 1 ? 0 : m + seq_start}) : nullptr; + auto alibi_sub = alibi ? &alibi.at({i0, i1}) : nullptr; + auto causal_mask_sub = causal_mask ? &causal_mask.at({causal_mask.m_dims[0] == 1 ? 0 : i0, 0, m + seq_start}) : nullptr; + float* src = &matQK(m, 0); + ov::bfloat16* dst = &softmax_dst(m, 0); + mul_add2_select_f32_avx512(src, src, normal_factor, alibi_sub, attn_sub, causal_mask_sub, select_nfltmax_at_0, valid_softmax_items); + softmax_avx512(dst, src, valid_softmax_items, nullptr); + if (key_seq_len > valid_softmax_items) { + auto *invalidPtr = dst + valid_softmax_items; + memset(static_cast(invalidPtr), 0, (key_seq_len - valid_softmax_items) * sizeof(ov::bfloat16)); + valid_softmax_items = std::min(valid_softmax_items + 1, key_seq_len); + } + } + + auto out_sub = out + (i0 * batch_stride_in_attn + i1 * head_stride_in_attn + + seq_start * head_stride_in_attn * head_num) * sizeof(ov::bfloat16); + tensor2D matQKBF16(seq_count, key_seq_len, softmax_dst.data, softmax_dst.stride); + tensor2D matV(key_seq_len, head_size, v_sub, v.m_strides[2]); + tensor2D matQKV(seq_count, head_size, mat1_out, _head_size_aligned * sizeof(float)); + amx_kernel::PP::BiasGeluStore pp2(matQKV); + (*qKVGemm_ops[thread_id])(matQKBF16, matV, 0, head_size, pp2, prev_v == v_sub); + prev_v = v_sub; + memcpy2d_stride_avx512(reinterpret_cast(out_sub), mat1_out, seq_count, + head_size, _head_size_aligned * sizeof(float), head_num * head_size * sizeof(ov::bfloat16), nullptr); + parallel_it_step(i0, batch, i1, head_num, seq, seq_count_all); + } + }); + } +} + +status_t mha_gpt_impl_amx::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { + if (q.m_rank != 4 || k.m_rank != 4 || v.m_rank != 4) { + DEBUG_LOG << "q,k,v rank does not equal 4.\n"; + return status_t::status_invalid_arguments; + } + if (output.m_rank != 3) { + DEBUG_LOG << "output rank should be 3.\n"; + return status_t::status_invalid_arguments; + } + if (attn_mask) { + if (attn_mask.m_rank != 4) { + DEBUG_LOG << "attn_mask rank should be 4.\n"; + return status_t::status_invalid_arguments; + } + if (attn_mask.m_dims[1] != 1) { + DEBUG_LOG << "attn_mask dim 1 should be 1.\n"; + return status_t::status_invalid_arguments; + } + } + if (alibi) { + if (alibi.m_rank != 4) { + DEBUG_LOG << "alibi rank should be 4.\n"; + return status_t::status_invalid_arguments; + } + if (alibi.m_dims[1] != k.m_dims[1]) { + DEBUG_LOG << "alibi dim 1 should be equal to k dim 1.\n"; + return status_t::status_invalid_arguments; + } + if (alibi.m_dims[2] != 1) { + DEBUG_LOG << "alibi dim 2 should be 1.\n"; + return status_t::status_invalid_arguments; + } + } + if (causal_mask) { + if (causal_mask.m_rank != 4) { + DEBUG_LOG << "causal_mask rank should be 4.\n"; + return status_t::status_invalid_arguments; + } + if (use_causal_mask) { + DEBUG_LOG << "use_causal_mask must be false to disable builtin causal mask.\n"; + return status_t::status_invalid_arguments; + } + } + auto batch = q.m_dims[0]; + auto head_num = q.m_dims[1]; + auto head_size = q.m_dims[3]; + auto key_seq_len = k.m_dims[2]; + + if (!(batch == k.m_dims[0] && batch == v.m_dims[0] && + k.m_dims[1] == v.m_dims[1] && + key_seq_len == v.m_dims[2] && + head_size == k.m_dims[3] && head_size == v.m_dims[3])) { + DEBUG_LOG << "dim of q,k,v is error.\n"; + return status_t::status_invalid_arguments; + } + + bool is_bloom = k.m_strides[3] > k.m_strides[2]; + + auto in_dtype = q.m_dtype; + auto out_dtype = output.m_dtype; + + if (in_dtype == llmdnn_bf16 && out_dtype == llmdnn_bf16) { + create(in_dtype, key_seq_len, head_size, is_bloom); + mha_bf16(q, k, v, output, attn_mask, alibi, causal_mask, select_nfltmax_at_0, normal_factor, use_causal_mask); + } else { + DEBUG_LOG << "doesn't support provided input precisions.\n"; + return status_t::status_invalid_arguments; + } + + return status_t::status_ok; +} + +mha_gpt::impl* new_impl_amx() { + return new mha_gpt_impl_amx(); +} + +} // namespace llmdnn diff --git a/src/mha_gpt_amx.hpp b/src/mha_gpt_amx.hpp new file mode 100644 index 0000000..f75df98 --- /dev/null +++ b/src/mha_gpt_amx.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "llm_types.hpp" +#include "llm_mha_gpt.hpp" + +namespace llmdnn { + +mha_gpt::impl* new_impl_amx(); + +} // namespace llmdnn diff --git a/src/mha_gpt_api.cpp b/src/mha_gpt_api.cpp new file mode 100644 index 0000000..0e35085 --- /dev/null +++ b/src/mha_gpt_api.cpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "mha_gpt_amx.hpp" + +namespace llmdnn { + +// interface +mha_gpt::mha_gpt(): _impl(new_impl_amx()) { +} + +mha_gpt::~mha_gpt() { + delete _impl; +} + +status_t mha_gpt::exec(const tensor& q, const tensor& k, const tensor& v, const tensor& output, const tensor& attn_mask, const tensor& alibi, const tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal_mask) { + return _impl->exec(q, k, v, output, attn_mask, alibi, causal_mask, select_nfltmax_at_0, normal_factor, use_causal_mask); +} + +} // namespace llmdnn diff --git a/src/mm_kernel_amx.cpp b/src/mm_kernel_amx.cpp new file mode 100644 index 0000000..8fbc5ed --- /dev/null +++ b/src/mm_kernel_amx.cpp @@ -0,0 +1,121 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include + +#include "llm_mm.hpp" +#include "llm_types.hpp" +#include "mm_kernel_common_amx.hpp" +#include "utility_kernel_avx512.hpp" +#include "mm_kernel_amx.hpp" + +namespace llmdnn { + +using ov::bfloat16; +struct mm_kernel { + std::unique_ptr> bf16xbf16; + std::unique_ptr> i8xi8; + std::unique_ptr> u8xi8; + + std::unique_ptr> i8xi8_gemv; + std::unique_ptr> bf16xbf16_gemv; + + data_type_t dt_a; + data_type_t dt_b; + bool b_is_transpose; +}; + +// interface +status_t mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param) { + mm_kernel* m = nullptr; + if (param == nullptr || mm == nullptr) { + DEBUG_LOG << "mm_kernel_create: invalid input parameter.\n"; + goto ERR; + } + + m = new mm_kernel; + if (param->b_is_gemv) { + if (param->dt_a == llmdnn_s8 && param->dt_b == llmdnn_s8) { + m->i8xi8_gemv = std::make_unique>(); + } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_bf16) { + m->bf16xbf16_gemv = std::make_unique>(); + } else { + DEBUG_LOG << "mm_kernel_create: unsupport gemv input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + } else { + if (param->dt_a == llmdnn_s8 && param->dt_b == llmdnn_s8) { + m->i8xi8 = std::make_unique>(false, param->b_is_trans); + } else if (param->dt_a == llmdnn_u8 && param->dt_b == llmdnn_s8) { + m->u8xi8 = std::make_unique>(false, param->b_is_trans); + } else if (param->dt_a == llmdnn_bf16 && param->dt_b == llmdnn_bf16) { + m->bf16xbf16 = std::make_unique>(false, param->b_is_trans); + } else { + DEBUG_LOG << "mm_kernel_create: unsupport input type, a: " << param->dt_a << ", b: " << param->dt_b << ".\n"; + goto ERR; + } + } + m->dt_a = param->dt_a; + m->dt_b = param->dt_b; + m->b_is_transpose = param->b_is_trans; + + *mm = m; + return status_t::status_ok; +ERR: + delete m; + return status_t::status_invalid_arguments; +} + +void mm_kernel_destroy_amx(const mm_kernel* mm) { + if (mm) { + delete mm; + } +} + +status_t mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K) { + size_t b_d0 = K, b_d1 = N; + if (mm->b_is_transpose) { + b_d0 = N; + b_d1 = K; + } + if (mm->i8xi8_gemv) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + (*mm->i8xi8_gemv)(a, reinterpret_cast(ptr_b), reinterpret_cast(ptr_c)); + cvt_i32_f32_avx512(reinterpret_cast(ptr_c), reinterpret_cast(ptr_c), M); + } else if (mm->i8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->i8xi8)(a, b, 0, N, pp); + } else if (mm->u8xi8) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->u8xi8)(a, b, 0, N, pp); + } else if (mm->bf16xbf16_gemv) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + (*mm->bf16xbf16_gemv)(a, reinterpret_cast(ptr_b), reinterpret_cast(ptr_c)); + } else if (mm->bf16xbf16) { + tensor2D a(M, K, reinterpret_cast(ptr_a), lda); + tensor2D b(b_d0, b_d1, reinterpret_cast(ptr_b), ldb); + tensor2D c(M, N, reinterpret_cast(ptr_c), ldc); + amx_kernel::PP::BiasGeluStore pp(c); + (*mm->bf16xbf16)(a, b, 0, N, pp); + } else { + DEBUG_LOG << "mm_kernel_execute: no valid kernel created, call create first.\n"; + return status_t::status_invalid_arguments; + } + + return status_t::status_ok; +} + +} // namespace llmdnn diff --git a/src/mm_kernel_amx.hpp b/src/mm_kernel_amx.hpp new file mode 100644 index 0000000..1950de2 --- /dev/null +++ b/src/mm_kernel_amx.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llm_mm.hpp" +#include "llm_types.hpp" + +namespace llmdnn { + +status_t mm_kernel_create_amx(mm_kernel** mm, const mm_create_param* param); + +void mm_kernel_destroy_amx(const mm_kernel* mm); + +status_t mm_kernel_execute_amx(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K); + +} // namespace llmdnn diff --git a/src/mm_kernel_api.cpp b/src/mm_kernel_api.cpp new file mode 100644 index 0000000..ca52fe6 --- /dev/null +++ b/src/mm_kernel_api.cpp @@ -0,0 +1,29 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "llm_mm.hpp" +#include "llm_types.hpp" +#include "mm_kernel_amx.hpp" + +namespace llmdnn { + +static decltype(&mm_kernel_create) mm_kernel_create_ptr = mm_kernel_create_amx; +static decltype(&mm_kernel_destroy) mm_kernel_destroy_ptr = mm_kernel_destroy_amx; +static decltype(&mm_kernel_execute) mm_kernel_execute_ptr = mm_kernel_execute_amx; + +// interface +status_t mm_kernel_create(mm_kernel** mm, const mm_create_param* param) { + return mm_kernel_create_ptr(mm, param); +} + +void mm_kernel_destroy(const mm_kernel* mm) { + mm_kernel_destroy_ptr(mm); +} + +status_t mm_kernel_execute(const mm_kernel* mm, void* ptr_a, void* ptr_b, void* ptr_c, size_t lda, size_t ldb, size_t ldc, + size_t M, size_t N, size_t K) { + return mm_kernel_execute_ptr(mm, ptr_a, ptr_b, ptr_c, lda, ldb, ldc, M, N, K); +} + +} // namespace llmdnn diff --git a/src/mm_kernel_common_amx.hpp b/src/mm_kernel_common_amx.hpp new file mode 100644 index 0000000..4d97802 --- /dev/null +++ b/src/mm_kernel_common_amx.hpp @@ -0,0 +1,2818 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "common/bf16.hpp" +#include "common/tensor2d.hpp" +#include "utility_kernel_amx.hpp" +#include "gelu_kernel_avx512.hpp" +#include "llm_fc.hpp" + +#ifdef _WIN32 +#include +#else +#include +#endif + +using namespace llmdnn; + +namespace amx_kernel { + +namespace functional { + + inline void transpose_m512i_16x16(__m512i &r0, __m512i &r1, __m512i &r2, __m512i &r3, + __m512i &r4, __m512i &r5, __m512i &r6, __m512i &r7, + __m512i &r8, __m512i &r9, __m512i &ra, __m512i &rb, + __m512i &rc, __m512i &rd, __m512i &re, __m512i &rf) { + __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; + + t0 = _mm512_unpacklo_epi32(r0,r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 + t1 = _mm512_unpackhi_epi32(r0,r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 + t2 = _mm512_unpacklo_epi32(r2,r3); // 32 48 33 49 ... + t3 = _mm512_unpackhi_epi32(r2,r3); // 34 50 35 51 ... + t4 = _mm512_unpacklo_epi32(r4,r5); // 64 80 65 81 ... + t5 = _mm512_unpackhi_epi32(r4,r5); // 66 82 67 83 ... + t6 = _mm512_unpacklo_epi32(r6,r7); // 96 112 97 113 ... + t7 = _mm512_unpackhi_epi32(r6,r7); // 98 114 99 115 ... + t8 = _mm512_unpacklo_epi32(r8,r9); // 128 ... + t9 = _mm512_unpackhi_epi32(r8,r9); // 130 ... + ta = _mm512_unpacklo_epi32(ra,rb); // 160 ... + tb = _mm512_unpackhi_epi32(ra,rb); // 162 ... + tc = _mm512_unpacklo_epi32(rc,rd); // 196 ... + td = _mm512_unpackhi_epi32(rc,rd); // 198 ... + te = _mm512_unpacklo_epi32(re,rf); // 228 ... + tf = _mm512_unpackhi_epi32(re,rf); // 230 ... + + r0 = _mm512_unpacklo_epi64(t0,t2); // 0 16 32 48 ... + r1 = _mm512_unpackhi_epi64(t0,t2); // 1 17 33 49 ... + r2 = _mm512_unpacklo_epi64(t1,t3); // 2 18 34 49 ... + r3 = _mm512_unpackhi_epi64(t1,t3); // 3 19 35 51 ... + r4 = _mm512_unpacklo_epi64(t4,t6); // 64 80 96 112 ... + r5 = _mm512_unpackhi_epi64(t4,t6); // 65 81 97 114 ... + r6 = _mm512_unpacklo_epi64(t5,t7); // 66 82 98 113 ... + r7 = _mm512_unpackhi_epi64(t5,t7); // 67 83 99 115 ... + r8 = _mm512_unpacklo_epi64(t8,ta); // 128 144 160 176 ... + r9 = _mm512_unpackhi_epi64(t8,ta); // 129 145 161 178 ... + ra = _mm512_unpacklo_epi64(t9,tb); // 130 146 162 177 ... + rb = _mm512_unpackhi_epi64(t9,tb); // 131 147 163 179 ... + rc = _mm512_unpacklo_epi64(tc,te); // 192 208 228 240 ... + rd = _mm512_unpackhi_epi64(tc,te); // 193 209 229 241 ... + re = _mm512_unpacklo_epi64(td,tf); // 194 210 230 242 ... + rf = _mm512_unpackhi_epi64(td,tf); // 195 211 231 243 ... + + t0 = _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... + t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... + t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... + t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... + t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... + t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... + t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... + t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... + t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... + t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... + ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... + tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... + tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... + td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... + te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... + tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... + + r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 + r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 + r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 + r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 + r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... + r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... + r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... + r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... + r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... + r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... + ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... + rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... + rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... + rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... + re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... + rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 + } + + inline void transpose_epi32_16x16(void * _dst, const void * src, int stride) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_loadu_epi32(pA + 15*stride); + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // 16xN, N<=16, non-valid part is filled with zeros + inline void transpose_epi32_16xN(void * _dst, const void * src, int stride, int valid_bytes) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull >> (64 - valid_bytes); + __mmask64 mask = _cvtu64_mask64(mask_value); + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + inline void transpose_epi32_Mx16(void * _dst, const void * src, int stride, int valid_m) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto * pA = reinterpret_cast(src); + switch (valid_m) { + case 15: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_setzero_epi32(); + break; + case 14: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 13: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 12: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 11: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 10: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 9: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 8: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 7: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 6: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 5: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 4: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 3: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 2: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 1: + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_setzero_epi32(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + default: + assert(false); + return; + } + + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + // 16xN, N<=16, non-valid part is on the left, filled with zeros + inline void transpose_epi32_16xN_right_align(void * _dst, const void * src, int stride, int valid_bytes) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + int invalid_bytes = 64 - valid_bytes; + auto * pA = reinterpret_cast(src) - invalid_bytes; + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; + __mmask64 mask = _cvtu64_mask64(mask_value); + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_maskz_loadu_epi8 (mask, pA + 15*stride); + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + inline void transpose_epi32_MxN_right_align(void * _dst, const void * src, int stride, int valid_bytes, int valid_m) { + auto * dst = reinterpret_cast(_dst); + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + int invalid_bytes = 64 - valid_bytes; + auto * pA = reinterpret_cast(src) - invalid_bytes; + uint64_t mask_value = 0xFFFFFFFFFFFFFFFFull << invalid_bytes; + __mmask64 mask = _cvtu64_mask64(mask_value); + switch (valid_m) { + case 15: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_maskz_loadu_epi8 (mask, pA + 14*stride); + rf = _mm512_setzero_epi32(); + break; + case 14: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_maskz_loadu_epi8 (mask, pA + 13*stride); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 13: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_maskz_loadu_epi8 (mask, pA + 12*stride); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 12: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_maskz_loadu_epi8 (mask, pA + 11*stride); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 11: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_maskz_loadu_epi8 (mask, pA + 10*stride); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 10: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_maskz_loadu_epi8 (mask, pA + 9*stride); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 9: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_maskz_loadu_epi8 (mask, pA + 8*stride); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 8: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_maskz_loadu_epi8 (mask, pA + 7*stride); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 7: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_maskz_loadu_epi8 (mask, pA + 6*stride); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 6: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_maskz_loadu_epi8 (mask, pA + 5*stride); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 5: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_maskz_loadu_epi8 (mask, pA + 4*stride); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 4: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_maskz_loadu_epi8 (mask, pA + 3*stride); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 3: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_maskz_loadu_epi8 (mask, pA + 2*stride); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 2: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_maskz_loadu_epi8 (mask, pA + stride); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + case 1: + r0 = _mm512_maskz_loadu_epi8 (mask, pA); + r1 = _mm512_setzero_epi32(); + r2 = _mm512_setzero_epi32(); + r3 = _mm512_setzero_epi32(); + r4 = _mm512_setzero_epi32(); + r5 = _mm512_setzero_epi32(); + r6 = _mm512_setzero_epi32(); + r7 = _mm512_setzero_epi32(); + r8 = _mm512_setzero_epi32(); + r9 = _mm512_setzero_epi32(); + ra = _mm512_setzero_epi32(); + rb = _mm512_setzero_epi32(); + rc = _mm512_setzero_epi32(); + rd = _mm512_setzero_epi32(); + re = _mm512_setzero_epi32(); + rf = _mm512_setzero_epi32(); + break; + default: + assert(false); + return; + } + transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + _mm512_storeu_epi32(dst, r0); + _mm512_storeu_epi32(dst + 16, r1); + _mm512_storeu_epi32(dst + 2*16, r2); + _mm512_storeu_epi32(dst + 3*16, r3); + _mm512_storeu_epi32(dst + 4*16, r4); + _mm512_storeu_epi32(dst + 5*16, r5); + _mm512_storeu_epi32(dst + 6*16, r6); + _mm512_storeu_epi32(dst + 7*16, r7); + _mm512_storeu_epi32(dst + 8*16, r8); + _mm512_storeu_epi32(dst + 9*16, r9); + _mm512_storeu_epi32(dst + 10*16, ra); + _mm512_storeu_epi32(dst + 11*16, rb); + _mm512_storeu_epi32(dst + 12*16, rc); + _mm512_storeu_epi32(dst + 13*16, rd); + _mm512_storeu_epi32(dst + 14*16, re); + _mm512_storeu_epi32(dst + 15*16, rf); + } + + inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows) { + #define FROM_B(i) ((1<<4)|(i)) + static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), + 1,5,FROM_B(1),FROM_B(5), + 2,6,FROM_B(2),FROM_B(6), + 3,7,FROM_B(3),FROM_B(7)}; + auto midx = _mm512_loadu_epi64(idx); + __mmask16 mask = _cvtu32_mask16(0xFFFFu); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + if (src_rows == 64) { + for (int row = 0; row < 16; row++) { + // each element (a? or b?) is 32-bits, two lanes in each ymm register + auto a256 = _mm256_loadu_epi8(src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 + auto b256 = _mm256_loadu_epi8(src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 + auto c256 = _mm256_loadu_epi8(src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 + auto d256 = _mm256_loadu_epi8(src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } else { + // less than 64 source lines, + int allzero_dst_rows = (64-src_rows)/4; + int allnonzero_dst_rows = src_rows/4; + // padding zeros at the top + auto rowB0 = _mm512_setzero_si512(); + auto rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + int tails_nz = (src_rows & 3); + if (tails_nz) { + auto a256 = _mm256_setzero_si256(); // must be zero + auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 + auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 + auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) + if (tails_nz > 2) { + b256 = _mm256_loadu_epi8 (src); src += stride; + } + if (tails_nz > 1) { + c256 = _mm256_loadu_epi8 (src); src += stride; + } + d256 = _mm256_loadu_epi8 (src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non zeros + for (int i = 0; i < allnonzero_dst_rows; i++) { + auto a256 = _mm256_loadu_epi8 (src); src += stride; + auto b256 = _mm256_loadu_epi8 (src); src += stride; + auto c256 = _mm256_loadu_epi8 (src); src += stride; + auto d256 = _mm256_loadu_epi8 (src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows) { + static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; + auto midx = _mm512_loadu_epi64(idx); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __m512i a,b,rowB0, rowB1; + if (src_rows == 32) { + for (int row = 0; row < 16; row++) { + a = _mm512_loadu_epi16(src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + b = _mm512_loadu_epi16(src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } else { + int allzero_dst_rows = (32-src_rows)/2; + int allnonzero_dst_rows = src_rows/2; + + rowB0 = _mm512_setzero_si512(); + rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + if (src_rows & 1) { + a = _mm512_setzero_si512(); + b = _mm512_loadu_epi16(src); src += stride; + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + auto rowB0 = _mm512_unpacklo_epi16(a, b); + auto rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non-zero rows + for (int i = 0; i < allnonzero_dst_rows; i++) { + a = _mm512_loadu_epi16(src); + b = _mm512_loadu_epi16(src + stride); + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + rowB0 = _mm512_unpacklo_epi16(a, b); + rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const int8_t * _src, int stride, int src_rows, int valid_n) { + #define FROM_B(i) ((1<<4)|(i)) + static const uint32_t idx[16] = { 0,4,FROM_B(0),FROM_B(4), + 1,5,FROM_B(1),FROM_B(5), + 2,6,FROM_B(2),FROM_B(6), + 3,7,FROM_B(3),FROM_B(7)}; + auto midx = _mm512_loadu_epi64(idx); + __mmask16 mask = _cvtu32_mask16(0xFFFFu); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __mmask32 mask_n = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); + if (src_rows == 64) { + for (int row = 0; row < 16; row++) { + // each element (a? or b?) is 32-bits, two lanes in each ymm register + auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [a0 a1 a2 a3 | a4 a5 a6 a7] 256-bits ymm0 B0: a0-a3 B1: a4:a7 + auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [b0 b1 b2 b3 | b4 b5 b6 b7] 256-bits ymm1 B0: b0-b3 B1: b4:b7 + auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [c0 c1 c2 c3 | c4 c5 c6 c7] 256-bits ymm2 B0: c0-c3 B1: c4:c7 + auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; // [d0 d1 d2 d3 | d4 d5 d6 d7] 256-bits ymm3 B0: d0-d3 B1: d4:d7 + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } else { + // less than 64 source lines, + int allzero_dst_rows = (64-src_rows)/4; + int allnonzero_dst_rows = src_rows/4; + // padding zeros at the top + auto rowB0 = _mm512_setzero_si512(); + auto rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + int tails_nz = (src_rows & 3); + if (tails_nz) { + auto a256 = _mm256_setzero_si256(); // must be zero + auto b256 = _mm256_setzero_si256(); // when tails_nz > 2 + auto c256 = _mm256_setzero_si256(); // when tails_nz > 1 + auto d256 = _mm256_setzero_si256(); // when tails_nz > 0(always load) + if (tails_nz > 2) { + b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + } + if (tails_nz > 1) { + c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + } + d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non zeros + for (int i = 0; i < allnonzero_dst_rows; i++) { + auto a256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto b256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto c256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto d256 = _mm256_maskz_loadu_epi8(mask_n, src); src += stride; + auto a = _mm512_castsi256_si512(a256); + auto b = _mm512_castsi256_si512(b256); + auto c = _mm512_castsi256_si512(c256); + auto d = _mm512_castsi256_si512(d256); + auto ac = _mm512_mask_permutex2var_epi32(a, mask, midx, c); // [a0 a4 c0 c4 | a1 a5 c1 c5 | a2 a6 c2 c6 | a3 a7 c3 c7] + auto bd = _mm512_mask_permutex2var_epi32(b, mask, midx, d); // [b0 b4 d0 d4 | b1 b5 d1 d5 | b2 b6 d2 d6 | b3 b7 d3 d7] + auto aib = _mm512_unpacklo_epi8(ac, bd); // [a0&b0 a4&b4 | a1&b1 a5&b5 | a2&b2 a6&b6 | a3&b3 a7&b7] + auto cid = _mm512_unpackhi_epi8(ac, bd); // [c0&d0 c4&d4 | c1&d1 c5&d5 | c2&d2 c6&d6 | c3&d3 c7&d7] + auto rowB0 = _mm512_unpacklo_epi16(aib, cid); // [a0&b0&c0&d0 | a1&b1&c1&d1 | a2&b2&c2&d2 | a3&b3&c3&d3] 512-bit (64bytes) line in B0 + auto rowB1 = _mm512_unpackhi_epi16(aib, cid); // [a4&b4&c4&d4 | a5&b5&c5&d5 | a6&b6&c6&d6 | a7&b7&c7&d7] 512-bit (64bytes) line in B1 + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + } + } + + inline void kpack_tile_B0B1_ntail(void * _dst0, void * _dst1, const ov::bfloat16 * _src, int stride, int src_rows, int valid_n) { + static const uint64_t idx[8] = {0,4,1,5,2,6,3,7}; + auto midx = _mm512_loadu_epi64(idx); + const auto * src = reinterpret_cast(_src); + auto * dst0 = reinterpret_cast(_dst0); + auto * dst1 = reinterpret_cast(_dst1); + __mmask32 mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - valid_n)); + __m512i a,b,rowB0, rowB1; + if (src_rows == 32) { + for (int row = 0; row < 16; row++) { + a = _mm512_maskz_loadu_epi16(mask, src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + b = _mm512_maskz_loadu_epi16(mask, src + stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits + a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] + b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] + rowB0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + rowB1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } else { + int allzero_dst_rows = (32-src_rows)/2; + int allnonzero_dst_rows = src_rows/2; + + rowB0 = _mm512_setzero_si512(); + rowB1 = _mm512_setzero_si512(); + for(int i = 0; i < allzero_dst_rows ; i++) { + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // mixed row + if (src_rows & 1) { + a = _mm512_setzero_si512(); + b = _mm512_maskz_loadu_epi16(mask, src); src += stride; + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + auto rowB0 = _mm512_unpacklo_epi16(a, b); + auto rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + dst0 += 64; + dst1 += 64; + } + // all non-zero rows + for (int i = 0; i < allnonzero_dst_rows; i++) { + a = _mm512_maskz_loadu_epi16(mask, src); + b = _mm512_maskz_loadu_epi16(mask, src + stride); + a = _mm512_permutexvar_epi64(midx, a); + b = _mm512_permutexvar_epi64(midx, b); + rowB0 = _mm512_unpacklo_epi16(a, b); + rowB1 = _mm512_unpackhi_epi16(a, b); + _mm512_storeu_epi16(dst0, rowB0); + _mm512_storeu_epi16(dst1, rowB1); + src += 2*stride; + dst0 += 64; + dst1 += 64; + } + } + } + + // prepare B matrix for C matrix 2x2 blocking (B matrix + // will be accessed in 1x2) + // given 2x2 blocking scheme, Kx32 blocks are always + // accessed sequentially: + // transpose/repack each 32xK ov::bfloat16 submatrix + // into Kx32 slices (each number is a 16x32 bf16-block): + // 0 2 4 6 ... ... + // 1 3 5 7 ... ... + + inline void get_min_max(tensor2D & matB, float& min, float& max) { + int K = matB.dims[0]; + int N = matB.dims[1]; + auto m_max = _mm512_set1_ps(-__FLT_MAX__); + auto m_min = _mm512_set1_ps(__FLT_MAX__); + for (int k = 0; k < K; k++) { + int n = 0; + for (; n < N / 16 * 16; n += 16) { + auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(&matB(k, n))); + a = _mm512_slli_epi32(a, 16); + m_max = _mm512_max_ps((__m512)a, m_max); + m_min = _mm512_min_ps((__m512)a, m_min); + } + if (n != N) { + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - (N - n))); + auto a = _mm512_cvtepi16_epi32(_mm256_maskz_loadu_epi16(msk, &matB(k, n))); + a = _mm512_slli_epi32(a, 16); + m_max = _mm512_mask_max_ps(m_max, msk, (__m512)a, m_max); + m_min = _mm512_mask_min_ps(m_min, msk, (__m512)a, m_min); + } + } + max = _mm512_reduce_max_ps(m_max); + min = _mm512_reduce_min_ps(m_min); + } + + template + void i8_to_bf16_Kx32(int8_t *&src, ov::bfloat16 *dst) + { + for (int k = 0; k < K; k++) + { + auto a = _mm_load_si128((__m128i *)src); // 16 int8 + auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 + auto a_512 = _mm512_cvtepi8_epi32(a); // 16 int32 + auto b_512 = _mm512_cvtepi8_epi32(b); // 16 int32 + auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps + auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps + auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 + _mm512_store_epi32(dst, (__m512i)reg_out); // + src += 32; // 32 int8_t dequantized into 32 bf16 + dst += 32; + } + } + + template + void i8_to_bf16_Kx32(int8_t *&src, ov::bfloat16 *dst, float* zp) + { + auto zp0 = _mm512_loadu_ps(zp); + auto zp1 = _mm512_loadu_ps(zp + 16); + for (int k = 0; k < K; k++) + { + auto a = _mm_load_si128((__m128i *)src); // 16 int8 + auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 + auto a_512 = _mm512_cvtepu8_epi32(a); // 16 int32 + auto b_512 = _mm512_cvtepu8_epi32(b); // 16 int32 + auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps + auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps + a_f = _mm512_sub_ps(a_f, zp0); + b_f = _mm512_sub_ps(b_f, zp1); + auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 + _mm512_store_epi32(dst, (__m512i)reg_out); // + src += 32; // 32 int8_t dequantized into 32 bf16 + dst += 32; + } + } + + // K tail, because right align need to fill zero for padding, real valid k is from invalid_k_num to the end + template + void i8_to_bf16_Kx32_tail(int8_t *&src, ov::bfloat16 *dst, float* zp, int k_start, int invalid_k_num) + { + auto zp0 = _mm512_loadu_ps(zp); + auto zp1 = _mm512_loadu_ps(zp + 16); + auto zero = _mm512_setzero_epi32(); + for (int k = 0; k < K; k++) { + auto k_cur = k + k_start; + if (k_cur < invalid_k_num) { + _mm512_store_epi32(dst, zero); + } else { + auto a = _mm_load_si128((__m128i *)src); // 16 int8 + auto b = _mm_load_si128((__m128i *)(src + 16)); // 16 int8 + auto a_512 = _mm512_cvtepu8_epi32(a); // 16 int32 + auto b_512 = _mm512_cvtepu8_epi32(b); // 16 int32 + auto a_f = _mm512_cvtepi32_ps(a_512); // 16 ps + auto b_f = _mm512_cvtepi32_ps(b_512); // 16 ps + a_f = _mm512_sub_ps(a_f, zp0); + b_f = _mm512_sub_ps(b_f, zp1); + auto reg_out = _mm512_cvtne2ps_pbh(b_f, a_f); // 32 packed bf16 + _mm512_store_epi32(dst, (__m512i)reg_out); + } + src += 32; + dst += 32; + } + } + + inline void bf16_to_i8_tensor(tensor2D& dst, tensor2D& src, float quant_scale) { + dst.resize(src.dims[0], src.dims[1]); + auto scale = _mm512_set1_ps(quant_scale); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + for (int n = 0; n < src.dims[1]; n += 16, p_src += 16, p_dst += 16) { + auto a = _mm512_cvtepi16_epi32(_mm256_loadu_epi16(p_src)); // load packed 16 x bf16 + a = _mm512_slli_epi32(a, 16); // bf16 zero-extend to f32 + auto a_f = _mm512_mul_ps((__m512)a, scale); // scale + a = _mm512_cvtps_epi32(a_f); // ps to dw + auto a_128 = _mm512_cvtsepi32_epi8(a); // saturate convert into int8 + _mm_store_si128((__m128i*)(p_dst), a_128); + } + } + } + + inline void f32_to_bf16_tensor(tensor2D& dst, const tensor2D& src) { + dst.resize(src.dims[0], src.dims[1]); + auto tail = src.dims[1] % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + int i; + for(i = 0; i < src.dims[1] / 32 * 32; i += 32) { + auto x0 = _mm512_loadu_ps(p_src + i); + auto x1 = _mm512_loadu_ps(p_src + i + 16); + auto out = _mm512_cvtne2ps_pbh(x1, x0); + _mm512_storeu_epi32(reinterpret_cast(p_dst) + i, (__m512i)out); + } + if (i < src.dims[1] - tail) { + auto x = _mm512_loadu_ps(p_src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(reinterpret_cast(p_dst) + i), + _mm512_extracti64x4_epi64(out, 0)); + i += 16; + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, p_src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_mask_storeu_epi16(reinterpret_cast<__m256i*>(reinterpret_cast(p_dst) + i), + x_mask, _mm512_extracti64x4_epi64(out, 0)); + } + } + } + + inline void u8_to_u16_tensor(tensor2D& dst, const tensor2D& src) { + dst.resize(src.dims[0], src.dims[1]); + auto tail = src.dims[1] % 32; + __mmask32 x_mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - tail)); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + int i; + for(i = 0; i < src.dims[1] / 32 * 32; i += 32) { + auto x = _mm256_loadu_epi8(p_src + i); + auto y = _mm512_cvtepu8_epi16(x); + _mm512_storeu_epi16(p_dst + i, y); + } + // handle tails + if (tail) { + auto x = _mm256_maskz_loadu_epi8(x_mask, p_src + i); + auto y = _mm512_cvtepu8_epi16(x); + _mm512_mask_storeu_epi16(p_dst + i, x_mask, y); + } + } + } + + inline void u16_to_u8_tensor(tensor2D&& dst, const tensor2D& src) { + auto tail = src.dims[1] % 32; + __mmask32 x_mask = _cvtu32_mask32(0xFFFFFFFF >> (32 - tail)); + for (int k = 0; k < src.dims[0]; k++) { + auto p_src = &src(k, 0); + auto p_dst = &dst(k, 0); + int i; + for(i = 0; i < src.dims[1] / 32 * 32; i += 32) { + auto x = _mm512_loadu_epi16(p_src + i); + auto y = _mm512_cvtusepi16_epi8(x); + _mm256_storeu_epi8(p_dst + i, y); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_epi16(x_mask, p_src + i); + auto y = _mm512_cvtusepi16_epi8(x); + _mm256_mask_storeu_epi8(p_dst + i, x_mask, y); + } + } + } +}; + +// 2x2 tiles post process kernels + +// 4 tiles located at C matrix (m,n) of size (valid_m, valid_n) +// tC00/tC01 +// tC10/tC11 + +namespace PP { + template + struct is_f32i32 : std::false_type {}; + template<> + struct is_f32i32 : std::true_type {}; + template<> + struct is_f32i32 : std::true_type {}; + + using Steps = postops_types; + + template + struct BiasGeluStore { + static_assert(std::is_same::value || std::is_same::value || std::is_same::value, + "BiasGeluStore only support output data types ov::bfloat16/int8_t/float"); + + BiasGeluStore(tensor2D & C, float * bias = nullptr) : C(C), bias(bias) {} + + tensor2D & C; + float * bias; + void set_bias(float * _bias) { + assert (steps & BIAS); + bias = _bias; + } + + float deq_scale_common = 1.0f; + float * deq_scale_per_oc = nullptr; + void set_deq_scale(float scale = 1.0f) { + assert (steps & DEQUANT); + deq_scale_common = scale; + deq_scale_per_oc = nullptr; + } + void set_deq_scale(float * scale_per_oc) { + assert (steps & DEQUANT); + deq_scale_common = 0; + deq_scale_per_oc = scale_per_oc; + } + + float q_scale_common = 0.0f; + float * q_scale_per_oc = nullptr; + void set_q_scale(float scale) { + assert (steps & QUANT); + q_scale_common = scale; + q_scale_per_oc = nullptr; + } + void set_q_scale(float * scale_per_oc) { + assert (steps & QUANT); + q_scale_common = 0; + q_scale_per_oc = scale_per_oc; + } + + // source buffC can be i32 or f32, buffC size is [32, 32], valid_m/valid_n is in [1, 32] + template::value, bool>::type = true> + void operator()(tensor2D & buffC, int m, int n, int valid_m, int valid_n) { + auto * psrc = &buffC(0,0); + int8_t * pdst = reinterpret_cast(&(C(m, n))); + int stride = C.stride; + + __m512 bias0, bias1; + if (steps & BIAS) { + bias0 = _mm512_loadu_ps(bias + n); + bias1 = _mm512_loadu_ps(bias + n + 16); + } + + __m512 m512_q_scale0; + __m512 m512_q_scale1; + __m512 m512_deq_scale0; + __m512 m512_deq_scale1; + if (steps & DEQUANT) { + if (deq_scale_per_oc) { + m512_deq_scale0 = _mm512_loadu_ps(deq_scale_per_oc + n); + m512_deq_scale1 = _mm512_loadu_ps(deq_scale_per_oc + n + 16); + } else { + m512_deq_scale0 = _mm512_set1_ps(deq_scale_common); + m512_deq_scale1 = _mm512_set1_ps(deq_scale_common); + } + } + if (steps & QUANT) { + if (q_scale_per_oc) { + m512_q_scale0 = _mm512_loadu_ps(q_scale_per_oc + n); + m512_q_scale1 = _mm512_loadu_ps(q_scale_per_oc + n + 16); + } else { + m512_q_scale0 = _mm512_set1_ps(q_scale_common); + m512_q_scale1 = _mm512_set1_ps(q_scale_common); + } + } + + __mmask32 kall; + __mmask16 k0, k1; + if (std::is_same::value) { + if (valid_n >= 16) { + k0 = _cvtu32_mask16(0xFFFF); + k1 = _cvtu32_mask16(0xFFFF >> (32-valid_n)); + } else { + k0 = _cvtu32_mask16(0xFFFF >> (16-valid_n)); + k1 = _cvtu32_mask16(0); + } + } else { + kall = _cvtu32_mask32(0xFFFFFFFF >> (32-valid_n)); + } + + for(int i = 0; i < valid_m; i ++) { + auto r0 = _mm512_loadu_ps(psrc); + auto r1 = _mm512_loadu_ps(psrc + 16); + if (std::is_same::value) { + r0 = _mm512_cvtepi32_ps(_mm512_castps_si512(r0)); // cvt i32=>f32 + r1 = _mm512_cvtepi32_ps(_mm512_castps_si512(r1)); // cvt i32=>f32 + } + if (steps & DEQUANT) { + r0 = _mm512_mul_ps(r0, m512_deq_scale0); // dequantize + r1 = _mm512_mul_ps(r1, m512_deq_scale1); // dequantize + } + if (steps & BIAS) { + r0 = _mm512_add_ps(r0, bias0); + r1 = _mm512_add_ps(r1, bias1); + } + if (steps & GELU) { + r0 = gelu_erf_minmax_approx_avx512(r0); + r1 = gelu_erf_minmax_approx_avx512(r1); + } + if (steps & GELU_TANH) { + r0 = gelu_tanh_avx512(r0); + r1 = gelu_tanh_avx512(r1); + } + + // quantize & store + if (steps & QUANT) { + r0 = _mm512_mul_ps(r0, m512_q_scale0); + r1 = _mm512_mul_ps(r1, m512_q_scale1); + } + if (std::is_same::value) { + auto c = _mm512_cvtne2ps_pbh(r1, r0); // convert to bf16 + _mm512_mask_storeu_epi16(pdst, kall, (__m512i)c); // store bf16 + } + if (std::is_same::value) { + auto d0 = _mm512_cvtps_epi32(r0); // convert to dword(i32) + auto d1 = _mm512_cvtps_epi32(r1); // convert to dword(i32) + auto b0 = _mm512_cvtsepi32_epi8 (d0); // dword => int8 with Saturate8 + auto b1 = _mm512_cvtsepi32_epi8 (d1); // dword => int8 with Saturate8 + auto b0b1 = _mm256_inserti32x4(_mm256_castsi128_si256(b0), b1, 1); // combine two int8 xmm into a ymm + _mm256_mask_storeu_epi8(pdst, kall, b0b1); // masked store + } + if (std::is_same::value) { + _mm512_mask_storeu_ps(pdst, k0, r0); // store float + _mm512_mask_storeu_ps(pdst + 64, k1, r1); // store float + } + pdst += stride; + psrc += 32; + } + } + }; +} + +// WA: clang could not find _mm_hint definition +#define prefetch_bytes(bytes, sel, advance, src) { \ + int8_t *p = reinterpret_cast(src); \ + for (int i = 0; i < bytes; i+=64) \ + _mm_prefetch(p + i + advance, sel); \ +} + +// matmul (FC) +// +// constB constrols whether it's FC or not +// store precision for weight compression, only for BF16 AMX + +template +tensor2D getSubMatB(tensor2D & _matB, int n0, int n1, bool transposeB) { + int Bd0 = transposeB ? (n1-n0) : _matB.dims[0]; + int Bd1 = transposeB ? _matB.dims[1] : (n1-n0); + T * pbase = transposeB ? (&_matB(n0, 0)):(&_matB(0, n0)); + return tensor2D(Bd0, Bd1, pbase, _matB.stride); +} + +template +void loop2D_no_bM(int M, int N, F f) { + for(int n=0; n +void loop2D(int M, int N, int mc, F f) { + for(int m0=0; m0= bM) +template +void loop2D_opt_Mtail(int M, int N, int mc, F f) { + assert(M > bM); + for(int m0=0; m0 +void repackB_1x2(const tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { + int K = Bi.dims[transpose ? 1 : 0]; + int N = Bi.dims[transpose ? 0 : 1]; + + // K_padded : round up to multiple of 32/64 + int kStep = 64 / sizeof(T); + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); + + int n = 0; + int n_tail = N % N_unit; + if (transpose) { + for(; n < N - n_tail; n += N_unit) { + // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_16x16(dst, &Bi(n, k), Bi.stride); + dst += 1024; + functional::transpose_epi32_16x16(dst, &Bi(n + 16, k), Bi.stride); + dst += 1024; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_16xN_right_align(dst, &Bi(n, k), Bi.stride, (K - k) * sizeof(T)); + dst += 1024; + functional::transpose_epi32_16xN_right_align(dst, &Bi(n + 16, k), Bi.stride, (K - k) * sizeof(T)); + dst += 1024; + } + } + // n_tail: [16, 32) + if (N - n >= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_16x16(dst, &Bi(n, k), Bi.stride); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_16xN_right_align(dst, &Bi(n, k), Bi.stride, (K - k) * sizeof(T)); + } + n += 16; + } + // n_tail: (0, 16) + if (N - n > 0) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 1024 : 0); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::transpose_epi32_Mx16(dst, &Bi(n, k), Bi.stride, N - n); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::transpose_epi32_MxN_right_align(dst, &Bi(n, k), Bi.stride, (K - k) * sizeof(T), N - n); + } + n = N; + } + // second B tile is untouched, need to set to zero + if (n_tail > 0 && n_tail <= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for (int k = 0; k < K_padded; k += kStep) { + memset(dst + 1024, 0, 1024); + dst += 1024 * 2; + } + } + } else { + // pack & layout sequentially + int n = 0; + int n_tail = N % N_unit; + for(; n < N - n_tail; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + // int8: B0 B1 64x(16+16) => repack as two 16x16x4 + int src_rows = std::min(K - k, kStep); + functional::kpack_tile_B0B1(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows); + dst += 2048; + } + } + // n_tail: (0, 32) + if (N - n > 0) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + // int8: B0 B1 64x(16+16) => repack as two 16x16x4 + int src_rows = std::min(K - k, kStep); + functional::kpack_tile_B0B1_ntail(dst, dst + 1024, &Bi(k, n), Bi.stride, src_rows, N - n); + dst += 2048; + } + n += 16; + } + } +} + +inline void repackB_1x2(tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { + int K = Bi.dims[transpose ? 1 : 0]; + int N = Bi.dims[transpose ? 0 : 1]; + + // K_padded : round up to multiple of 32/64 + int kStep = 64 / sizeof(ov::bfloat16); + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); + + int n = 0; + int n_tail = N % N_unit; + if (transpose) { + tensor2D Btmp(16, 32); + for(; n < N - n_tail; n += N_unit) { + // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16x16(dst, &Btmp(0, 0), Btmp.stride); + dst += 1024; + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, 32, &Bi(n + 16, k), Bi.stride)); + functional::transpose_epi32_16x16(dst, &Btmp(0, 0), Btmp.stride); + dst += 1024; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(dst, &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + dst += 1024; + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, K - k, &Bi(n + 16, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(dst, &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + dst += 1024; + } + } + // n_tail: [16, 32) + if (N - n >= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16x16(dst, &Btmp(0, 0), Btmp.stride); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::f32_to_bf16_tensor(Btmp, tensor2D(16, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(dst, &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + } + n += 16; + } + // n_tail: (0, 16) + if (N - n > 0) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 1024 : 0); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::f32_to_bf16_tensor(Btmp, tensor2D(N - n, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_Mx16(dst, &Btmp(0, 0), Btmp.stride, N - n); + dst += 1024 * 2; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::f32_to_bf16_tensor(Btmp, tensor2D(N - n, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_MxN_right_align(dst, &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16), N - n); + } + n = N; + } + // second B tile is untouched, need to set to zero + if (n_tail > 0 && n_tail <= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for (int k = 0; k < K_padded; k += kStep) { + memset(dst + 1024, 0, 1024); + dst += 1024 * 2; + } + } + } else { + // pack & layout sequentially + int n = 0; + int n_tail = N % N_unit; + tensor2D Btmp(32, 32); + for(; n < N - n_tail; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + int src_rows = std::min(K - k, kStep); + functional::f32_to_bf16_tensor(Btmp, tensor2D(src_rows, 32, &Bi(k, n), Bi.stride)); + functional::kpack_tile_B0B1(dst, dst + 1024, &Btmp(0, 0), Btmp.stride, src_rows); + dst += 2048; + } + } + // n_tail: (0, 32) + if (N - n > 0) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + int src_rows = std::min(K - k, kStep); + functional::f32_to_bf16_tensor(Btmp, tensor2D(src_rows, N - n, &Bi(k, n), Bi.stride)); + functional::kpack_tile_B0B1_ntail(dst, dst + 1024, &Btmp(0, 0), Btmp.stride, src_rows, N - n); + dst += 2048; + } + n += 16; + } + } +} + +inline void repackB_1x2_compressed(tensor2D &Bi, bool transpose, tensor2D& Bo, bool is_const) { + int K = Bi.dims[transpose ? 1 : 0]; + int N = Bi.dims[transpose ? 0 : 1]; + + // K_padded : round up to multiple of 32/64 + int kStep = 32; + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + int Kbody = K - Ktails; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + Bo.resize(N_padded / N_unit, K_padded * N_unit, false, is_const); + + int n = 0; + int n_tail = N % N_unit; + if (transpose) { + tensor2D Btmp(16, 32), transTmp(16, 32); + for(; n < N - n_tail; n += N_unit) { + // a K_padded x N_unit submatrix layouted in B0/B1... and put sequentially + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::u8_to_u16_tensor(Btmp, tensor2D(16, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16x16(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 512; + functional::u8_to_u16_tensor(Btmp, tensor2D(16, 32, &Bi(n + 16, k), Bi.stride)); + functional::transpose_epi32_16x16(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 512; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::u8_to_u16_tensor(Btmp, tensor2D(16, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 512; + functional::u8_to_u16_tensor(Btmp, tensor2D(16, K - k, &Bi(n + 16, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 512; + } + } + // n_tail: [16, 32) + if (N - n >= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::u8_to_u16_tensor(Btmp, tensor2D(16, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16x16(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 1024 * 1; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::u8_to_u16_tensor(Btmp, tensor2D(16, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_16xN_right_align(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16)); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + } + n += 16; + } + // n_tail: (0, 16) + if (N - n > 0) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)) + (n_tail > 16 ? 512 : 0); + int k; + for(k = 0; k < Kbody; k += kStep) { + // B0 (16x32) => transpose+repack as 32x16(16x16x2) or 64x16(16x16x4) + functional::u8_to_u16_tensor(Btmp, tensor2D(N - n, 32, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_Mx16(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, N - n); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + dst += 1024 * 1; + } + if (Ktails) { + // Ktails part is loaded into A tile right-aligned, so B tile must also load + // Ktails part to bottom-aligned, and fill upper padding with zero + functional::u8_to_u16_tensor(Btmp, tensor2D(N - n, K - k, &Bi(n, k), Bi.stride)); + functional::transpose_epi32_MxN_right_align(&transTmp(0, 0), &Btmp(0, 0), Btmp.stride, (K - k) * sizeof(ov::bfloat16), N - n); + functional::u16_to_u8_tensor(tensor2D(16, 32, dst, 32), transTmp); + } + n = N; + } + // second B tile is untouched, need to set to zero + if (n_tail > 0 && n_tail <= 16) { + auto* dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for (int k = 0; k < K_padded; k += kStep) { + memset(dst + 512, 0, 512); + dst += 1024 * 1; + } + } + } else { + // pack & layout sequentially + int n = 0; + int n_tail = N % N_unit; + tensor2D Btmp(32, 32), transTmp(32, 32); + for(; n < N - n_tail; n += N_unit) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + int src_rows = std::min(K - k, kStep); + functional::u8_to_u16_tensor(Btmp, tensor2D(src_rows, 32, &Bi(k, n), Bi.stride)); + functional::kpack_tile_B0B1(&transTmp(0, 0), &transTmp(16, 0), reinterpret_cast(&Btmp(0, 0)), Btmp.stride, src_rows); + functional::u16_to_u8_tensor(tensor2D(32, 32, dst, 32), transTmp); + dst += 1024; + } + } + // n_tail: (0, 32) + if (N - n > 0) { + auto * dst = reinterpret_cast(&Bo(n / N_unit, 0)); + for(int k = 0; k < K; k += kStep) { + // bf16: B0 B1 32x(16+16) => repack as two 16x16x2 + int src_rows = std::min(K - k, kStep); + functional::u8_to_u16_tensor(Btmp, tensor2D(src_rows, N - n, &Bi(k, n), Bi.stride)); + functional::kpack_tile_B0B1_ntail(&transTmp(0, 0), &transTmp(16, 0), reinterpret_cast(&Btmp(0, 0)), Btmp.stride, src_rows, N - n); + functional::u16_to_u8_tensor(tensor2D(32, 32, dst, 32), transTmp); + dst += 1024; + } + n += 16; + } + } +} + +template +struct acc_type {}; +template<> +struct acc_type { typedef float type; }; +template<> +struct acc_type { typedef int32_t type; }; +template<> +struct acc_type { typedef int32_t type; }; + +template +using acc_type_t = typename acc_type::type; + +// matrix multiply with vector +// C_Nx1 = A_MxK * b_Kx1 + +template> +struct MatmulVector { + MatmulVector() {} + constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; + constexpr static int kStep = is_i8_mode ? 64 : 32; + +#define TILE_DP(dst, a, b) \ + if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ + if (is_s8s8) _tile_dpbssd(dst, a, b); \ + if (is_s8u8) _tile_dpbsud(dst, a, b); \ + if (is_u8s8) _tile_dpbusd(dst, a, b); \ + if (is_u8u8) _tile_dpbuud(dst, a, b); + + alignas(64) int8_t KtailBuff[64]; + + template + void kernel(int M, int K, const void * pA, int strideA, const void * vB, void * vC) { + static_assert(tmmN >= 1 && tmmN <= 6, "tmmN must be within [1-6] range"); + const auto * pA0 = reinterpret_cast(pA); + int KLastOffBytes = (K - kStep) * sizeof(TA); + const auto * pB0 = reinterpret_cast(vB); + auto * pC0 = reinterpret_cast(vC); + + const auto * pBLast = pB0 + 64*(tmmN - 1); + int Ktail = K & (kStep - 1); + if (Ktail) { + if (bFallbackKtails) { + // if bContainMtails, the last submatrix needs to use special to prevent A matrix read overflow + // K tails is handled by: + // - zero-padding the last tile of vector B, at the top + // - right-align last tile load from matA + __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull << (kStep - Ktail)*sizeof(TB)); + auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); + _mm512_storeu_epi8(KtailBuff, r); + } else { + // each row of A can be read overflow w/o worrying NaN numbers + // zero-padding the last tile of vector B as bottom is enough + __mmask64 kmask = _cvtu64_mask64(0xFFFFFFFFFFFFFFFFull >> (kStep - Ktail)*sizeof(TB)); + KLastOffBytes = (K - Ktail)*sizeof(TA); + auto r = _mm512_maskz_loadu_epi8(kmask, pB0 + KLastOffBytes); + _mm512_storeu_epi8(KtailBuff, r); + } + pBLast = KtailBuff; + } + + // load B tiles outside of loop + if (tmmN == 1) { + _tile_loadd(2, pB0, 4); + } + if (tmmN == 2) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pBLast, 4); + } + if (tmmN == 3) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pBLast, 4); + } + if (tmmN == 4) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pBLast, 4); + } + if (tmmN == 5) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pB0 + 64*3, 4); + _tile_loadd(6, pBLast, 4); + } + if (tmmN == 6) { + _tile_loadd(2, pB0, 4); + _tile_loadd(3, pB0 + 64, 4); + _tile_loadd(4, pB0 + 64*2, 4); + _tile_loadd(5, pB0 + 64*3, 4); + _tile_loadd(6, pB0 + 64*4, 4); + _tile_loadd(7, pBLast, 4); + } + //asm("int3"); + for(int m = 0; m < M; m+=16) { + _tile_zero(0); + if (tmmN == 1) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + } + if (tmmN == 2) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 3); + } + if (tmmN == 3) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 4); + } + if (tmmN == 4) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 5); + } + if (tmmN == 5) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 6); + } + if (tmmN == 6) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); + _tile_loadd(1, pA0 + KLastOffBytes, strideA); TILE_DP(0, 1, 7); + } + _tile_stored(0, pC0, 4); pC0 += 16*4; // C is single column, always take 4 bytes + pA0 += 16 * strideA; + } + } + + void operator()(tensor2D & matA, const TB * vB, TC * vC) { + int M = matA.dims[0]; + int K = matA.dims[1]; + TA * pA = &matA[0]; + int strideA = matA.stride; + + // M tails is handled + assert(K >= kStep && K <= 6*kStep); + + int Ktail = K & (kStep - 1); + int Mtail = M & (16 - 1); + int Mbody = M - Mtail; + int numBtiles = (K + kStep - 1)/kStep; + + // if we have Ktails, then it will always be handled in Mtail, so we split + // Mtail out even if it's zero + if (Ktail) { + if (Mtail == 0) { + Mtail = 16; + Mbody -= 16; + } + } + + if (Mbody) { + tileconfig_t tfg(1, 0, { + {16, 4}, // C:0 M x 1 (4b) + {16, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); + // Ktail fallback will always be done at Mtails loop + switch(numBtiles) { + case 1: kernel<1, false>(Mbody, K, pA, strideA, vB, vC); break; + case 2: kernel<2, false>(Mbody, K, pA, strideA, vB, vC); break; + case 3: kernel<3, false>(Mbody, K, pA, strideA, vB, vC); break; + case 4: kernel<4, false>(Mbody, K, pA, strideA, vB, vC); break; + case 5: kernel<5, false>(Mbody, K, pA, strideA, vB, vC); break; + case 6: kernel<6, false>(Mbody, K, pA, strideA, vB, vC); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } + + if (Mtail) { + pA = &matA(Mbody, 0); + tileconfig_t tfg(1, 0, { + {Mtail, 4}, // C:0 M x 1 (4b) + {Mtail, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 + }); + if (Ktail) { + switch(numBtiles) { + case 1: kernel<1, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 2: kernel<2, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 3: kernel<3, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 4: kernel<4, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 5: kernel<5, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 6: kernel<6, true>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } else { + switch(numBtiles) { + case 1: kernel<1, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 2: kernel<2, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 3: kernel<3, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 4: kernel<4, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 5: kernel<5, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + case 6: kernel<6, false>(Mtail, K, pA, strideA, vB, vC + Mbody); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + } + } + } +}; + +template::type> +struct Matmul { + // B matrix is orgnized as tensor2D of shape axb where a=round_up_div(N, 32), b=round_up(K,32/64)*32 + // so b is size of submatrix of Kx32 composed of two columns of B0/B1 tiles. + tensor2D internalB; + tensor2D A_PaddedK; // pad to 32(bf16)/64(int8) buffer + tensor2D B_PaddedK; // pad to 32(bf16)/64(int8) buffer + + bool constB; + bool transposeB; + + constexpr static bool is_bf16s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_bf16bf16 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_s8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8s8 = std::is_same::value && std::is_same::value; + constexpr static bool is_u8u8 = std::is_same::value && std::is_same::value; + constexpr static bool is_i8_mode = is_s8s8 || is_s8u8 || is_u8s8 || is_u8u8; + + // AMX bf16 & int8 has same M(=16) in A,C tile and same N(=16) in B tile + // but only different K(32 vs 64) in A,C & B tiles + constexpr static int kStep = is_i8_mode ? 64 : 32; + + + Matmul(bool constB = false, bool transposeB = false) : + constB(constB), transposeB(transposeB) {} + + // ppkernel is a callable which captures the runtime args + // by itself, so no need to pass in any post-process related + // runtime args through this API + // + // n0/n1 allows us for calculating only partial results, so it + // can be used to run on multi-cores in parallel + // + // ppkernel will be invoked with true (m,n) with n0-offset added + // so ppkernel don't need to know which sub-matrix it's working on. + // + // for most ppkernels w/o runtime state, a single ppkernel can be + // shared among all threads. + // + // but for ppkernels doing reductions, it needs separate instance + // for each thread, also a post-merging process to combine the results. + // + // ppkernels are simple to write, further wrapping or structurelize only + // makes the design more complex, so we stop doing that. + + // I cannot find a way to call TDP intrinsic polymophically using overload or template. + // have to use old-macro-tricks, hopefully these compile-time checks can be optimized + // by compiler. +#define TILE_DP(dst, a, b) \ + if (is_bf16bf16) _tile_dpbf16ps(dst, a, b); \ + if (is_s8s8) _tile_dpbssd(dst, a, b); \ + if (is_s8u8) _tile_dpbsud(dst, a, b); \ + if (is_u8s8) _tile_dpbusd(dst, a, b); \ + if (is_u8u8) _tile_dpbuud(dst, a, b); + + template + void kernel_slimB(int M, int N, int K, int n0, + tensor2D & A, + void * B, + tensor2D & buffC, + PP ppkernel) { + auto * pB0 = reinterpret_cast(B); + auto * pC0 = &buffC[0]; + int8_t * pA0 = reinterpret_cast(&A[0]); + int strideA = A.stride; + int KlastOffBytes = (K - kStep)* sizeof(TA); + // load B tiles outside of loop + if (tmmN > 0) { + _tile_loadd(2, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 1) { + _tile_loadd(3, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 2) { + _tile_loadd(4, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 3) { + _tile_loadd(5, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 4) { + _tile_loadd(6, pB0, 64); pB0 += 1024*2; + } + if (tmmN > 5) { + _tile_loadd(7, pB0, 64); pB0 += 1024*2; + } + //asm("int3"); + for(int m0 = 0; m0 < M; m0+=16) { + int m = m0; + if (M - m0 < 16) { + // shift up to prevent M-tails + pA0 -= (16 - (M - m0))*A.stride; + m = M - 16; + } + _tile_zero(0); + if (tmmN == 1) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + } + if (tmmN == 2) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 3); + } + if (tmmN == 3) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 4); + } + if (tmmN == 4) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 5); + } + if (tmmN == 5) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 6); + } + if (tmmN == 6) { + _tile_loadd(1, pA0, strideA); TILE_DP(0, 1, 2); + _tile_loadd(1, pA0 + 64, strideA); TILE_DP(0, 1, 3); + _tile_loadd(1, pA0 + 128, strideA); TILE_DP(0, 1, 4); + _tile_loadd(1, pA0 + 192, strideA); TILE_DP(0, 1, 5); + _tile_loadd(1, pA0 + 256, strideA); TILE_DP(0, 1, 6); + _tile_loadd(1, pA0 + KlastOffBytes, strideA); TILE_DP(0, 1, 7); + } + _tile_stored(0, pC0, buffC.stride); + (ppkernel)(buffC, m, n0, 16, N); + pA0 += 16*A.stride; + } + } + + template + void operator()(tensor2D & matA, + tensor2D & _matB, + int n0, int n1, + PP ppkernel, + bool skip_repack = false) { + int M = matA.dims[0]; + int K = matA.dims[1]; + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + alignas(64) TC buff[32 * 32]; + tensor2D buffC(32, 32, buff, 32 * sizeof(TC)); + + tensor2D* pA = &matA; + tensor2D* pB = &_matB; + if (K < kStep) { + int B0, B1; + if (transposeB) { + B0 = _matB.dims[0]; + B1 = kStep; + } else { + B0 = kStep; + B1 = _matB.dims[1]; + } + matA.copyto_with_padzero(A_PaddedK, M, kStep); + pA = &A_PaddedK; + _matB.copyto_with_padzero(B_PaddedK, B0, B1); + pB = &B_PaddedK; + K = kStep; + } + auto matB = getSubMatB(*pB, n0, n1, transposeB); + int N = matB.dims[transposeB ? 0 : 1]; + assert(K == matB.dims[transposeB ? 1 : 0]); + // Due to the fact that we load a full tile at tails of K dimension + // we may access memory address beyond the limit of A matrix + // to avoid read in nan values, we backoff to the left to ensure A tile + // contain valid numbers and no overflow access happens, but it requires K>=kStep; + assert(K >= kStep); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int KbackoffBytes = (kStep - Ktails)*sizeof(TA); + + // for non-constB, internalB is updated every time + // for constB, internalB is updated once + if ((!constB && !skip_repack) || (internalB.capacity == 0)) { + repackB_1x2(matB, transposeB, internalB, constB); + } + + // special case when whole B matrix can fit in 6 tiles + // we can load B only once + if (M >= 16 && N <= 16 && K <= 6*kStep) { + // B is zero-padded + // C:0 + // A:1 + // B:2,3,4,5,6,7 + auto * pB0 = reinterpret_cast(&internalB[0]); + tileconfig_t tfg(1, 0, 8, 16, 64); + switch((K + kStep - 1)/kStep) { + case 1: kernel_slimB<1>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 2: kernel_slimB<2>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 3: kernel_slimB<3>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 4: kernel_slimB<4>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 5: kernel_slimB<5>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + case 6: kernel_slimB<6>(M, N, K, n0, *pA, pB0, buffC, ppkernel); break; + default: + assert(false); // impossible since (K <= 6*kStep) + } + return; + } + + if (M <= 16) { + // register/cache blocking scheme is simplified when M <= 16 + // C_MxN: 0,1 + // A_MxK: 2, + // B_KxN: 3, 4 + tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); + auto * pB0 = reinterpret_cast(&internalB[0]); + auto * const pC0 = &buffC[0]; + int k; + const auto strideA = (*pA).stride; + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + _tile_zero(0); + _tile_zero(1); + int8_t * pA0 = reinterpret_cast(&(*pA)[0]); + for(k=0; k(&(*pA)(m, 0)); + auto * pA1 = reinterpret_cast(&(*pA)(m + 16, 0)); + auto strideA = (*pA).stride; + auto * pB = reinterpret_cast(&internalB(n>>5, 0)); + _tile_zero(0); + _tile_zero(1); + _tile_zero(2); + _tile_zero(3); + // 2x2 + for (int k = 0; k < Kbody; k += kStep) { + _tile_loadd(4, pA0, strideA); pA0 += 64; + _tile_loadd(6, pB, 64); pB += 1024; + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); + TILE_DP(0, 4, 6); + + _tile_loadd(5, pA1, strideA); pA1 += 64; + TILE_DP(2, 5, 6); + _tile_loadd(7, pB, 64); pB += 1024; + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); + TILE_DP(1, 4, 7); + + TILE_DP(3, 5, 7); + } + if (Ktails) { + _tile_loadd(4, pA0 - KbackoffBytes, strideA); + _tile_loadd(6, pB, 64); pB += 1024; + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); + TILE_DP(0, 4, 6); + + _tile_loadd(5, pA1 - KbackoffBytes, strideA); + TILE_DP(2, 5, 6); + _tile_loadd(7, pB, 64); pB += 1024; + // prefetch_bytes(1024, _MM_HINT_T0, 4096, pB); + TILE_DP(1, 4, 7); + + TILE_DP(3, 5, 7); + } + _tile_stored(0, &buffC(0,0), buffC.stride); + _tile_stored(1, &buffC(0,16), buffC.stride); + _tile_stored(2, &buffC(16,0), buffC.stride); + _tile_stored(3, &buffC(16,16), buffC.stride); + (ppkernel)(buffC, m, n + n0, valid_m, valid_n); + }; + + if (M <= 32 && M > 16) { + // 2x2 tile, C:0/1/2/3 A:4/5 B:6/7 no blocking along M dimension + tileconfig_t tfg(1, 0, {16,16,M-16,M-16,16,M-16,16,16}, 64); + loop2D_no_bM<32>(M, N, kernel_2x2); + return; + } + + // generic input shapes with M > 32 + // determine cache blocking scheme + int elesz = sizeof(TA); + int L2 = 2048*1024; // 2MB + int slice_size = 32*rndup(K, 32)*elesz; + int mc = std::max(1, L2/slice_size - 1); + + // M > bM + tileconfig_t tfg(1, 0, 8, 16, 64); + loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); + } +}; + +// specialization: +// TA is ov::bfloat16 and TB is int8_t, decompressed on the fly into ov::bfloat16 by simply convert +template<> +struct Matmul { + tensor2D internalBI8; + + bool constB; + bool transposeB; + + constexpr static int kStep = 32; + + Matmul(bool constB = false, bool transposeB = false) : + constB(constB), transposeB(transposeB) {} + + float* dequant_scale_B; + float* zp; + + template + void operator()(tensor2D & matA, + tensor2D & _matB, + int n0, int n1, + PP ppkernel) { + alignas(64) float buff[32 * 32]; + // wei_buff is ping-pong buffer containing ov::bfloat16 weights decompressed on the fly. + alignas(64) ov::bfloat16 weiBuff[32 * 2 * 32]; + // 2x2 C tiles buffer + // most usecase requires post-processing with AVX, thus buffC + // is used to transfer data to AVX register + tensor2D buffC(32, 32, buff, 32 * sizeof(float)); + + auto matB = getSubMatB(_matB, n0, n1, transposeB); + int M = matA.dims[0]; + int K = matA.dims[1]; + int N = matB.dims[transposeB ? 0 : 1]; + assert(K == matB.dims[transposeB ? 1 : 0]); + // Due to the fact that we load a full tile at tails of K dimension + // we may access memory address beyond the limit of A matrix + // to avoid read in nan values, we backoff to the left to ensure A tile + // contain valid numbers and no overflow access happens, but it requires K>=kStep; + assert(K >= kStep); + int Ktails = K % kStep; + int Kbody = K - Ktails; + int Kbackoff = (kStep - Ktails); + + ppkernel.set_deq_scale(dequant_scale_B); + auto zp_start = zp + n0 * 2; + + if (M <= 16) { + // C:0/1 A:2 B:3/4 + // dequantize scale is moved into ppkernel + //constexpr int prefetch_ahead = 64*1024; + tileconfig_t tfg(1, 0, {M,M,M,16,16}, 64); + auto * pBint = reinterpret_cast(&internalBI8[0]); + auto * const pB = weiBuff; + auto * pBsrc = pB + (32*32) * 0; + auto * pBdst = pB + (32*32) * 1; + auto * const pC0 = &buffC[0]; + const auto strideA = matA.stride; + if (Ktails) { + // with tails, will decompress current block's weights: if ahead we need to know the (tails - 1) which is in the + // tight loop, then special handle the decompress process - skip subtract zeropoint step + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + // C:Mx32 = A:Mx32 x B:32x32 + _tile_zero(0); + _tile_zero(1); + auto * pA0 = &matA[0]; + auto cur_zp = zp_start + n * 2; + for(int k=0; k(pBint, pBsrc, cur_zp); + _tile_loadd(3, pBsrc, 64); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + functional::i8_to_bf16_Kx32<16>(pBint, pBsrc + 16 * 32, cur_zp + 32); + + // prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + _tile_loadd(4, pBsrc + 16*32, 64); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + // tails + { + _tile_loadd(2, pA0 - Kbackoff, strideA); // backoff to prevent access beyond the end of A + //prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + + functional::i8_to_bf16_Kx32_tail<16>(pBint, pBsrc, cur_zp, 0, Kbackoff / 2); + _tile_loadd(3, pBsrc, 64); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + //prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + functional::i8_to_bf16_Kx32_tail<16>(pBint, pBsrc + 16*32, cur_zp + 32, 0, Kbackoff / 2); + _tile_loadd(4, pBsrc + 16*32, 64); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint); + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint + 2048); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + } else { + // no tails, will decompress next block's weights ahead + functional::i8_to_bf16_Kx32<16>(pBint, pBsrc, zp_start); + functional::i8_to_bf16_Kx32<16>(pBint, pBsrc + 16 * 32, zp_start + 32); + + loop2D_no_bM<32>(M, N, [&](int m, int n, int valid_m, int valid_n) { + // C:Mx32 = A:Mx32 x B:32x32 + _tile_zero(0); + _tile_zero(1); + auto * pA0 = &matA[0]; + auto cur_zp = zp_start + n * 2; + for(int k=0; k(pBint, pBdst, cur_zp); + _tile_loadd(3, pBsrc, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 8*32, cur_zp); + _tile_dpbf16ps(0, 2, 3); // C0 += A*B0 + + //prefetch_bytes(512, _MM_HINT_T1, prefetch_ahead, pBint); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 16*32, cur_zp + 32); + _tile_loadd(4, pBsrc + 16*32, 64); + functional::i8_to_bf16_Kx32<8>(pBint, pBdst + 24*32, cur_zp + 32); + _tile_dpbf16ps(1, 2, 4); // C1 += A*B1 + std::swap(pBsrc, pBdst); + } + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint); + _tile_stored(0, pC0, buffC.stride); + _tile_stored(1, pC0 + 16, buffC.stride); + //prefetch_bytes(2048, _MM_HINT_T1, prefetch_ahead, pBint + 2048); + //int valid_n = std::min(N - n, 32); + (ppkernel)(buffC, 0, n + n0, M, valid_n); + }); + } + return; + } + + // 4 tiles buffC is reused as decompressed bf16 weights + ov::bfloat16 * pBa = reinterpret_cast(&buffC(0,0)); + ov::bfloat16 * pBb = pBa + (16*32)*2; + auto kernel_2x2 = [&](int m, int n, int valid_m, int valid_n) { + auto strideA = matA.stride; + auto * pA0 = &matA(m, 0); + auto * pA1 = &matA(m + 16, 0); + auto * pBint = reinterpret_cast(&internalBI8(n>>5, 0)); + _tile_zero(0); + _tile_zero(1); + _tile_zero(2); + _tile_zero(3); + auto cur_zp = zp_start + n * 2; + if (Ktails) { + int k; + for (k = 0; k < Kbody; k += kStep) { + functional::i8_to_bf16_Kx32<16>(pBint, pBa, cur_zp); + + _tile_loadd(4, pA0 + k, strideA); + _tile_loadd(6, pBa, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32, cur_zp + 32); + + _tile_loadd(7, pBa + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + // tails + { + functional::i8_to_bf16_Kx32_tail<16>(pBint, pBa, cur_zp, 0, Kbackoff / 2); + + _tile_loadd(4, pA0 + k - Kbackoff, strideA); + _tile_loadd(6, pBa, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k - Kbackoff, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32_tail<16>(pBint, pBa + 16*32, cur_zp + 32, 0, Kbackoff / 2); + + _tile_loadd(7, pBa + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + } else { + functional::i8_to_bf16_Kx32<16>(pBint, pBb, cur_zp); + functional::i8_to_bf16_Kx32<16>(pBint, pBb + 16 * 32, cur_zp + 32); + + for (int k = 0; k < Kbody; k += kStep) { + // weights are ahead of A, if reach the last K block, need to change to next N block + cur_zp += (k == K - kStep) * 64; + functional::i8_to_bf16_Kx32<16>(pBint, pBa, cur_zp); + + _tile_loadd(4, pA0 + k, strideA); + _tile_loadd(6, pBb, 64); + _tile_dpbf16ps(0, 4, 6); + + _tile_loadd(5, pA1 + k, strideA); + _tile_dpbf16ps(2, 5, 6); + + functional::i8_to_bf16_Kx32<16>(pBint, pBa + 16*32, cur_zp + 32); + + _tile_loadd(7, pBb + 16*32, 64); + _tile_dpbf16ps(1, 4, 7); + _tile_dpbf16ps(3, 5, 7); + + std::swap(pBa, pBb); + } + } + _tile_stored(0, &buffC(0,0), buffC.stride); + _tile_stored(1, &buffC(0,16), buffC.stride); + _tile_stored(2, &buffC(16,0), buffC.stride); + _tile_stored(3, &buffC(16,16), buffC.stride); + (ppkernel)(buffC, m, n + n0, valid_m, valid_n); + }; + + if (M <= 32 && M > 16) { + // 2x2 C:0/1/2/3 A:4/5 B:6/7 + tileconfig_t tfg(1, 0, {16, 16, M-16, M-16, 16, M-16, 16, 16}, 64); + loop2D_no_bM<32>(M, N, kernel_2x2); + return; + } + + // determine blocking scheme + int elesz = sizeof(uint16_t); + int L2 = 2048*1024; // 2MB + int slice_size = 32*rndup(K, 32)*elesz; + int mc = std::max(1, L2/slice_size - 1); // if 1 32xK slice cannot fit L2, use 1 slice at least + + // main loop + tileconfig_t tfg(1, 0, 8, 16, 64); + loop2D_opt_Mtail<32, 32>(M, N, mc, kernel_2x2); + } +}; + +//https://stackoverflow.com/questions/29519222/how-to-transpose-a-16x16-matrix-using-simd-instructions +// vector multiply with matrix: +// mAvB: A(M, K) * B(K, 1) => C(M, 1) +// vAmB: A(1, K) * B(K, N) => C(1, N) +// +// in mAvB form, block of A (16x32) is transposed in register +// in unit of 2 packed bf16, and then vdpbf16ps was used +// to multiply with broadcasted B (2x1) and accumulate into C (16x1) +// +// B is pre-broadcasted in unit of 2 +// +struct GemAvB { + tensor2D Bpadded; + GemAvB() { + } + + void operator()(tensor2D & matA, + ov::bfloat16 * vecB, + float * vecC) { + int M = matA.dims[0]; + int K = matA.dims[1]; + + assert(K >= 32); + + if (K % 32) { + if (K > Bpadded.dims[1]) + Bpadded.resize(1, rndup(K, 32)); + auto newB = &Bpadded(0, 0); + memset(static_cast(newB), 0, Bpadded.stride); + memcpy(newB, vecB, K * sizeof(ov::bfloat16)); + vecB = newB; + } + + for(int m = 0; m < M; m += 16) { + auto * pA = reinterpret_cast(&matA(m, 0)); + auto * pBi32 = reinterpret_cast(vecB); + __m512 regC0 = _mm512_setzero(); + __m512 regC1 = _mm512_setzero(); + __mmask16 kmask = _cvtu32_mask16(0xFFFF); + if (M-m < 16) { + kmask = _cvtu32_mask16(0xFFFF >> (16-(M-m))); + } + for(int k = 0; k < K; k += 32, pA += 64, pBi32 += 16) { + // handle Ab: 16x32 + // transposed in register as 16x16x2 + // r0: (a0,a1)(b0,b1).... + // r1: (a2,a3)(b2,b3).... + // ... + // rf: (a30,a31),(b30,b31).... + // + __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; + auto stride = matA.stride; + r0 = _mm512_loadu_epi32(pA); + r1 = _mm512_loadu_epi32(pA + stride); + r2 = _mm512_loadu_epi32(pA + 2*stride); + r3 = _mm512_loadu_epi32(pA + 3*stride); + r4 = _mm512_loadu_epi32(pA + 4*stride); + r5 = _mm512_loadu_epi32(pA + 5*stride); + r6 = _mm512_loadu_epi32(pA + 6*stride); + r7 = _mm512_loadu_epi32(pA + 7*stride); + r8 = _mm512_loadu_epi32(pA + 8*stride); + r9 = _mm512_loadu_epi32(pA + 9*stride); + ra = _mm512_loadu_epi32(pA + 10*stride); + rb = _mm512_loadu_epi32(pA + 11*stride); + rc = _mm512_loadu_epi32(pA + 12*stride); + rd = _mm512_loadu_epi32(pA + 13*stride); + re = _mm512_loadu_epi32(pA + 14*stride); + rf = _mm512_loadu_epi32(pA + 15*stride); + + functional::transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); + + // vdpbf16ps + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r0), (__m512bh)(_mm512_set1_epi32(pBi32[0]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r1), (__m512bh)(_mm512_set1_epi32(pBi32[1]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r2), (__m512bh)(_mm512_set1_epi32(pBi32[2]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r3), (__m512bh)(_mm512_set1_epi32(pBi32[3]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r4), (__m512bh)(_mm512_set1_epi32(pBi32[4]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r5), (__m512bh)(_mm512_set1_epi32(pBi32[5]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r6), (__m512bh)(_mm512_set1_epi32(pBi32[6]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r7), (__m512bh)(_mm512_set1_epi32(pBi32[7]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(r8), (__m512bh)(_mm512_set1_epi32(pBi32[8]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(r9), (__m512bh)(_mm512_set1_epi32(pBi32[9]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(ra), (__m512bh)(_mm512_set1_epi32(pBi32[10]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rb), (__m512bh)(_mm512_set1_epi32(pBi32[11]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(rc), (__m512bh)(_mm512_set1_epi32(pBi32[12]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rd), (__m512bh)(_mm512_set1_epi32(pBi32[13]))); + regC0 = _mm512_dpbf16_ps(regC0, (__m512bh)(re), (__m512bh)(_mm512_set1_epi32(pBi32[14]))); + regC1 = _mm512_dpbf16_ps(regC1, (__m512bh)(rf), (__m512bh)(_mm512_set1_epi32(pBi32[15]))); + } + regC0 = _mm512_add_ps(regC0, regC1); + _mm512_mask_storeu_ps (vecC + m, kmask, regC0); + //auto regOut = _mm512_cvtne2ps_pbh(regC0, regC0); // only 16 ov::bfloat16 results in lower 256bits + //_mm256_storeu_si256(reinterpret_cast<__m256i_u *>(vecC + m), _mm512_extracti64x4_epi64(regOut, 0)); + } + } +}; + +} // namespace amx diff --git a/src/rotary_kernel_avx2.hpp b/src/rotary_kernel_avx2.hpp new file mode 100644 index 0000000..5bebd50 --- /dev/null +++ b/src/rotary_kernel_avx2.hpp @@ -0,0 +1,108 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx2.hpp" + +namespace llmdnn { + inline void rotary_avx2(size_t N, float* cos, float* sin, float* q_src, float* k_src, float* q_dst, float* k_dst) { + auto half = N / 2; + // for (size_t i = 0; i < half; i++) { + // q_dst[i] = q_src[i] * cos[i] - q_src[i + half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] - k_src[i + half] * sin[i]; + // } + // for (size_t i = half; i < N; i++) { + // q_dst[i] = q_src[i] * cos[i] + q_src[i - half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] + k_src[i - half] * sin[i]; + // } + size_t tail = half % 8; + auto x_mask = get_mask(tail); + size_t i; + for (i = 0; i < half - tail; i += 8) { + auto q_f = _mm256_loadu_ps(q_src + i + half); + auto k_f = _mm256_loadu_ps(k_src + i + half); + auto cos_f = _mm256_loadu_ps(cos + i); + auto sin_f = _mm256_loadu_ps(sin + i); + auto q_dst_f = _mm256_mul_ps(q_f, sin_f); + auto k_dst_f = _mm256_mul_ps(k_f, sin_f); + + q_f = _mm256_loadu_ps(q_src + i); + k_f = _mm256_loadu_ps(k_src + i); + + q_dst_f = _mm256_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm256_fmsub_ps(k_f, cos_f, k_dst_f); + + _mm256_storeu_ps(q_dst + i, q_dst_f); + _mm256_storeu_ps(k_dst + i, k_dst_f); + } + if (tail) { + auto q_f = _mm256_maskload_ps(q_src + i + half, x_mask); + auto k_f = _mm256_maskload_ps(k_src + i + half, x_mask); + auto cos_f = _mm256_maskload_ps(cos + i, x_mask); + auto sin_f = _mm256_maskload_ps(sin + i, x_mask); + auto q_dst_f = _mm256_mul_ps(q_f, sin_f); + auto k_dst_f = _mm256_mul_ps(k_f, sin_f); + + q_f = _mm256_maskload_ps(q_src + i, x_mask); + k_f = _mm256_maskload_ps(k_src + i, x_mask); + + q_dst_f = _mm256_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm256_fmsub_ps(k_f, cos_f, k_dst_f); + + _mm256_maskstore_ps(q_dst + i, x_mask, q_dst_f); + _mm256_maskstore_ps(k_dst + i, x_mask, k_dst_f); + } + // second half + q_src += half; + k_src += half; + cos += half; + sin += half; + q_dst += half; + k_dst += half; + for (i = 0; i < half - tail; i += 8) { + auto q_f = _mm256_loadu_ps(q_src + i - half); + auto k_f = _mm256_loadu_ps(k_src + i - half); + auto cos_f = _mm256_loadu_ps(cos + i); + auto sin_f = _mm256_loadu_ps(sin + i); + auto q_dst_f = _mm256_mul_ps(q_f, sin_f); + auto k_dst_f = _mm256_mul_ps(k_f, sin_f); + + q_f = _mm256_loadu_ps(q_src + i); + k_f = _mm256_loadu_ps(k_src + i); + + q_dst_f = _mm256_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm256_fmadd_ps(k_f, cos_f, k_dst_f); + + _mm256_storeu_ps(q_dst + i, q_dst_f); + _mm256_storeu_ps(k_dst + i, k_dst_f); + } + if (tail) { + auto q_f = _mm256_maskload_ps(q_src + i - half, x_mask); + auto k_f = _mm256_maskload_ps(k_src + i - half, x_mask); + auto cos_f = _mm256_maskload_ps(cos + i, x_mask); + auto sin_f = _mm256_maskload_ps(sin + i, x_mask); + auto q_dst_f = _mm256_mul_ps(q_f, sin_f); + auto k_dst_f = _mm256_mul_ps(k_f, sin_f); + + q_f = _mm256_maskload_ps(q_src + i, x_mask); + k_f = _mm256_maskload_ps(k_src + i, x_mask); + + q_dst_f = _mm256_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm256_fmadd_ps(k_f, cos_f, k_dst_f); + + _mm256_maskstore_ps(q_dst + i, x_mask, q_dst_f); + _mm256_maskstore_ps(k_dst + i, x_mask, k_dst_f); + } + } +} // namespace llmdnn diff --git a/src/rotary_kernel_avx512.hpp b/src/rotary_kernel_avx512.hpp new file mode 100644 index 0000000..0a4cb48 --- /dev/null +++ b/src/rotary_kernel_avx512.hpp @@ -0,0 +1,135 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + template + void rotary_avx512(size_t N, float* cos, float* sin, T* q_src, T* k_src, T* q_dst, T* k_dst) { + static_assert(std::is_same::value, + "rotary_avx512 only support output data types ov::bfloat16/int8_t"); + auto half = N / 2; + // for (size_t i = 0; i < half; i++) { + // q_dst[i] = q_src[i] * cos[i] - q_src[i + half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] - k_src[i + half] * sin[i]; + // } + // for (size_t i = half; i < N; i++) { + // q_dst[i] = q_src[i] * cos[i] + q_src[i - half] * sin[i]; + // k_dst[i] = k_src[i] * cos[i] + k_src[i - half] * sin[i]; + // } + size_t tail = half % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + size_t i; + for (i = 0; i < half - tail; i += 16) { + auto q = _mm256_loadu_epi16(q_src + i + half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_loadu_epi16(k_src + i + half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_loadu_ps(cos + i); + auto sin_f = _mm512_loadu_ps(sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_loadu_epi16(q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_loadu_epi16(k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmsub_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(q_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(k_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + if (tail) { + auto q = _mm256_maskz_loadu_epi16(x_mask, q_src + i + half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_maskz_loadu_epi16(x_mask, k_src + i + half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_maskz_loadu_ps(x_mask, cos + i); + auto sin_f = _mm512_maskz_loadu_ps(x_mask, sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_maskz_loadu_epi16(x_mask, q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_maskz_loadu_epi16(x_mask, k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmsub_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmsub_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_mask_storeu_epi16(q_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_mask_storeu_epi16(k_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + // second half + q_src += half; + k_src += half; + cos += half; + sin += half; + q_dst += half; + k_dst += half; + for (i = 0; i < half - tail; i += 16) { + auto q = _mm256_loadu_epi16(q_src + i - half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_loadu_epi16(k_src + i - half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_loadu_ps(cos + i); + auto sin_f = _mm512_loadu_ps(sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_loadu_epi16(q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_loadu_epi16(k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmadd_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(q_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(k_dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + if (tail) { + auto q = _mm256_maskz_loadu_epi16(x_mask, q_src + i - half); + auto q_f = _mm512_cvtpbh_ps((__m256bh)q); + auto k = _mm256_maskz_loadu_epi16(x_mask, k_src + i - half); + auto k_f = _mm512_cvtpbh_ps((__m256bh)k); + auto cos_f = _mm512_maskz_loadu_ps(x_mask, cos + i); + auto sin_f = _mm512_maskz_loadu_ps(x_mask, sin + i); + auto q_dst_f = _mm512_mul_ps(q_f, sin_f); + auto k_dst_f = _mm512_mul_ps(k_f, sin_f); + + q = _mm256_maskz_loadu_epi16(x_mask, q_src + i); + q_f = _mm512_cvtpbh_ps((__m256bh)q); + k = _mm256_maskz_loadu_epi16(x_mask, k_src + i); + k_f = _mm512_cvtpbh_ps((__m256bh)k); + + q_dst_f = _mm512_fmadd_ps(q_f, cos_f, q_dst_f); + k_dst_f = _mm512_fmadd_ps(k_f, cos_f, k_dst_f); + + auto out = _mm512_cvtne2ps_pbh(q_dst_f, q_dst_f); + _mm256_mask_storeu_epi16(q_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + out = _mm512_cvtne2ps_pbh(k_dst_f, k_dst_f); + _mm256_mask_storeu_epi16(k_dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + } +} // namespace llmdnn diff --git a/src/softmax_kernel_avx512.hpp b/src/softmax_kernel_avx512.hpp new file mode 100644 index 0000000..38b4719 --- /dev/null +++ b/src/softmax_kernel_avx512.hpp @@ -0,0 +1,226 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + inline void exp_ps_avx512(__m512 & src) { + static __m512 exp_ln_flt_min_f = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); // log(FLT_MIN) + static __m512 exp_ln_flt_max_f = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); // log(FLT_MAX) + static __m512 exp_log2ef = _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) + static __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000)); // 0.5f + static __m512 ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) + static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f + static __m512i exponent_bias = _mm512_set1_epi32(0x0000007f); // 127 + static constexpr int n_mantissa_bits = 23; + static __m512 exp_pol1 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f7ffffb)); // p1 = 0.999999701f + static __m512 exp_pol2 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3efffee3)); // p2 = 0.499991506f + static __m512 exp_pol3 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3e2aad40)); // p3 = 0.166676521f + static __m512 exp_pol4 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3d2b9d0d)); // p4 = 0.0418978221f + static __m512 exp_pol5 = _mm512_castsi512_ps(_mm512_set1_epi32(0x3c07cfce)); // p5 = 0.00828929059f + static __m512 two = _mm512_castsi512_ps(_mm512_set1_epi32(0x40000000)); // 2 + // exp(x) = + // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem + // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression + + // get mask of values lower than log(FLT_MIN) to zero them in the output + auto zero_mask = _mm512_cmp_ps_mask(src, exp_ln_flt_min_f, _CMP_LT_OS); + + // clip src + src = _mm512_min_ps(src, exp_ln_flt_max_f); + src = _mm512_max_ps(src, exp_ln_flt_min_f); + + // aux1 : r + auto aux1 = src; + + // calculate exp(x) + // fx = x * log2(e) + 0.5 + src = _mm512_mul_ps(src, exp_log2ef); + src = _mm512_add_ps(src, half); + + // tmp = floorf(fx) + src = _mm512_floor_ps(src); + + // aux1 = x - fx * ln2 + aux1 = _mm512_fnmadd_ps(src, ln2f, aux1); + + // We do not count 2^n here, because n can reach 128 and 2^128 is not + // representable by fp32, so to get around this problem, instead of computing + // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 + // and 2 are numbers representable in fp32. + + // compute 2^(n-1) + src = _mm512_sub_ps(src, one); + auto aux2_i = _mm512_cvtps_epi32(src); + aux2_i = _mm512_add_epi32(aux2_i, exponent_bias); + aux2_i = _mm512_slli_epi32 (aux2_i, n_mantissa_bits); + + // set zeroes at those points which were < log(FLT_MIN) + auto zero = _mm512_setzero_ps(); + auto aux2 = _mm512_mask_blend_ps(zero_mask, _mm512_castsi512_ps(aux2_i), zero); + + // compute polynomial + src = exp_pol5; + src = _mm512_fmadd_ps(src, aux1, exp_pol4); + src = _mm512_fmadd_ps(src, aux1, exp_pol3); + src = _mm512_fmadd_ps(src, aux1, exp_pol2); + src = _mm512_fmadd_ps(src, aux1, exp_pol1); + src = _mm512_fmadd_ps(src, aux1, one); + + // y = y * 2^n + src = _mm512_mul_ps(src, aux2); + src = _mm512_mul_ps(src, two); + } + + template + void softmax_avx512(D* dst, float* src, int N, QD quant) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + "softmax_avx512 only support output data types ov::bfloat16/uint8_t/int8_t/float"); + + static __m512 one = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f800000)); // 1.0f + auto tail = N % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + + // get max + auto x_max = _mm512_set1_ps(std::numeric_limits::lowest()); + int i; + for (i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x_max = _mm512_max_ps(x_max, x); + } + // tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x_max = _mm512_mask_max_ps(x_max, x_mask, x_max, x); + } + auto max = _mm512_reduce_max_ps(x_max); + x_max = _mm512_set1_ps(max); + + // softmax + auto sum_exp = _mm512_setzero_ps(); + for(i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_sub_ps(x, x_max); + exp_ps_avx512(x); // exp(x-x_max) + sum_exp = _mm512_add_ps(sum_exp, x); // sum(exp(x-x_max)) + _mm512_storeu_ps(src + i, x); // save exp(x-x_max) + } + + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_sub_ps(x, x_max); + exp_ps_avx512(x); + x = _mm512_mask_blend_ps(x_mask, _mm512_setzero_ps(), x); + sum_exp = _mm512_add_ps(sum_exp, x); + _mm512_mask_storeu_ps(src + i, x_mask, x); + } + + auto sum = _mm512_reduce_add_ps(sum_exp); + sum_exp = _mm512_set1_ps(sum); + auto reciprocal_sum_exp = _mm512_div_ps(one, sum_exp); // 1/sum_exp + + // divide + if (std::is_same::value) { + for(i = 0; i < N - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + _mm512_storeu_ps(dst + i, x); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + _mm512_mask_storeu_ps(dst + i, x_mask, x); + } + } + if (std::is_same::value) { + for(i = 0; i < N / 32 * 32; i += 32) { + auto x0 = _mm512_loadu_ps(src + i); + auto x1 = _mm512_loadu_ps(src + i + 16); + x0 = _mm512_mul_ps(x0, reciprocal_sum_exp); + x1 = _mm512_mul_ps(x1, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x1, x0); + _mm512_storeu_epi32(dst + i, (__m512i)out); + } + if (i < N - tail) { + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_storeu_si256(reinterpret_cast<__m256i *>(dst + i), _mm512_extracti64x4_epi64((__m512i)out, 0)); + i += 16; + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_mask_storeu_epi16(dst + i, x_mask, _mm512_extracti64x4_epi64((__m512i)out, 0)); + } + } + if (std::is_same::value) { + __m512 q; + if constexpr (std::is_same::value) + q = _mm512_set1_ps(quant); + for(i = 0; i < N - tail; i += 16) { + if constexpr (std::is_same::value) + q = _mm512_loadu_ps(quant + i); + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(dst + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + if constexpr (std::is_same::value) + q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(dst + i, x_mask, x_i); + } + } + if (std::is_same::value) { + auto zero = _mm512_setzero_epi32(); + __m512 q; + if constexpr (std::is_same::value) + q = _mm512_set1_ps(quant); + for(i = 0; i < N - tail; i += 16) { + if constexpr (std::is_same::value) + q = _mm512_loadu_ps(quant + i); + auto x = _mm512_loadu_ps(src + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(dst + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + if constexpr (std::is_same::value) + q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, reciprocal_sum_exp); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(dst + i, x_mask, x_i); + } + } + } +} // namespace llmdnn diff --git a/src/transpose_kernel_avx512.hpp b/src/transpose_kernel_avx512.hpp new file mode 100644 index 0000000..76af1c6 --- /dev/null +++ b/src/transpose_kernel_avx512.hpp @@ -0,0 +1,111 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#ifdef _WIN32 +#include +#else +#include +#include +#endif +#include "common/bf16.hpp" +#include "llm_types.hpp" +#include "utility_kernel_avx512.hpp" + +namespace llmdnn { + template + void memcpy2d_stride_avx512(D* dst, S* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr); + + template + void memcpy2d_stride_avx512(D* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant=nullptr) { + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + "memcpy2d_stride_avx512 only support output data types ov::bfloat16/uint8_t/int8_t/float"); + + auto tail = width % 16; + __mmask16 x_mask = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + + for (size_t j = 0; j < height; j++) { + size_t i; + if (std::is_same::value) { + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + _mm512_storeu_ps(reinterpret_cast(dst) + i, x); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + _mm512_mask_storeu_ps(reinterpret_cast(dst) + i, x_mask, x); + } + } + + if (std::is_same::value) { + for(i = 0; i < width / 32 * 32; i += 32) { + auto x0 = _mm512_loadu_ps(src + i); + auto x1 = _mm512_loadu_ps(src + i + 16); + auto out = _mm512_cvtne2ps_pbh(x1, x0); + _mm512_storeu_epi32(reinterpret_cast(dst) + i, (__m512i)out); + } + if (i < width - tail) { + auto x = _mm512_loadu_ps(src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(reinterpret_cast(dst) + i), + _mm512_extracti64x4_epi64(out, 0)); + i += 16; + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto out = _mm512_cvtne2ps_pbh(x, x); + _mm256_mask_storeu_epi16(reinterpret_cast<__m256i*>(reinterpret_cast(dst) + i), + x_mask, _mm512_extracti64x4_epi64(out, 0)); + } + } + + if (std::is_same::value) { + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + auto q = _mm512_loadu_ps(quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(reinterpret_cast(dst) + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + _mm512_mask_cvtsepi32_storeu_epi8(reinterpret_cast(dst) + i, x_mask, x_i); + } + } + + if (std::is_same::value) { + auto zero = _mm512_setzero_epi32(); + for(i = 0; i < width - tail; i += 16) { + auto x = _mm512_loadu_ps(src + i); + auto q = _mm512_loadu_ps(quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(reinterpret_cast(dst) + i, 0xFFFF, x_i); + } + // handle tails + if (tail) { + auto x = _mm512_maskz_loadu_ps(x_mask, src + i); + auto q = _mm512_maskz_loadu_ps(x_mask, quant + i); + x = _mm512_mul_ps(x, q); + auto x_i = _mm512_cvtps_epi32(x); + x_i = _mm512_max_epi32(x_i, zero); + _mm512_mask_cvtusepi32_storeu_epi8(reinterpret_cast(dst) + i, x_mask, x_i); + } + } + + src = reinterpret_cast(reinterpret_cast(src) + src_stride); + dst = reinterpret_cast(reinterpret_cast(dst) + dst_stride); + } + } +} // namespace llmdnn diff --git a/src/utility_kernel_amx.hpp b/src/utility_kernel_amx.hpp new file mode 100644 index 0000000..3ba2443 --- /dev/null +++ b/src/utility_kernel_amx.hpp @@ -0,0 +1,95 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +struct tileconfig_t { + uint8_t palette_id; + uint8_t startRow; + uint8_t reserved[14]; + uint16_t cols[16]; + uint8_t rows[16]; + tileconfig_t() = default; + + tileconfig_t(int palette, int _startRow, const std::initializer_list> &_rows_columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + i = 0; + for (const auto& ele : _rows_columnsBytes) { + rows[i] = ele.first; + cols[i] = ele.second; + i++; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + + tileconfig_t(int palette, int _startRow, const std::initializer_list &_rows, int columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + i = 0; + for (const auto ele : _rows) { + rows[i] = ele; + cols[i] = columnsBytes; + i++; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + + tileconfig_t(int palette, int _startRow, int numTiles, int _rows, int columnsBytes) { + palette_id = palette; + startRow = _startRow; + int i; + for(i = 0; i < 14; i++) { + reserved[i] = 0; + } + for(i = 0; i < numTiles; i++) { + rows[i] = _rows; + cols[i] = columnsBytes; + } + for(; i < 16; i++) { + cols[i] = 0; + rows[i] = 0; + } + load(); + } + + ~tileconfig_t() { + _tile_release(); + } + + void __attribute__((noinline)) load() { + _tile_loadconfig(this); + } + + void store() { + _tile_storeconfig(this); + } +} __attribute__ ((__packed__)); diff --git a/src/utility_kernel_avx2.hpp b/src/utility_kernel_avx2.hpp new file mode 100644 index 0000000..458d59e --- /dev/null +++ b/src/utility_kernel_avx2.hpp @@ -0,0 +1,45 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "common/bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +namespace llmdnn { + +inline __m256i get_mask(int N7) { + static __m256i mask[] = { + _mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0, 0), + _mm256_set_epi32( 0, 0, 0, 0, 0, 0, 0,-1), + _mm256_set_epi32( 0, 0, 0, 0, 0, 0,-1,-1), + _mm256_set_epi32( 0, 0, 0, 0, 0,-1,-1,-1), + _mm256_set_epi32( 0, 0, 0, 0,-1,-1,-1,-1), + _mm256_set_epi32( 0, 0, 0,-1,-1,-1,-1,-1), + _mm256_set_epi32( 0, 0,-1,-1,-1,-1,-1,-1), + _mm256_set_epi32( 0,-1,-1,-1,-1,-1,-1,-1), + _mm256_set_epi32(-1,-1,-1,-1,-1,-1,-1,-1), + }; + return _mm256_loadu_si256(&mask[N7]); +} + +// https://stackoverflow.com/questions/23189488/horizontal-sum-of-32-bit-floats-in-256-bit-avx-vector +static inline float _mm256_reduce_add_ps(__m256 x) { + /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + return _mm_cvtss_f32(x32); +} + +} // namespace llmdnn diff --git a/src/utility_kernel_avx512.hpp b/src/utility_kernel_avx512.hpp new file mode 100644 index 0000000..1f44ef9 --- /dev/null +++ b/src/utility_kernel_avx512.hpp @@ -0,0 +1,205 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include "common/bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#include +#endif + +namespace llmdnn { + +/// Convert Packed BF16 Data to Packed float Data. +/// +/// \headerfile +/// +/// \param __A +/// A 256-bit vector of [16 x bfloat]. +/// \returns A 512-bit vector of [16 x float] come from convertion of __A +static __inline__ __m512 _mm512_cvtpbh_ps(__m256bh __A) { + return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32( + (__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16)); +} + +// Store masks. The highest bit in each byte indicates the byte to store. +alignas(16) const unsigned char masks[16][16] = +{ + { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00 }, + { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00 } +}; + +inline void store_n(__m128i mm, unsigned int n, void* storage) +{ + _mm_maskmoveu_si128(mm, reinterpret_cast< const __m128i& >(masks[n]), static_cast< char* >(storage)); +} + +inline void quant_i8_avx512(void* dst, void* src, size_t ele_num, float scale) { + size_t i = 0; + auto* a = reinterpret_cast(src); + int8_t* d = reinterpret_cast(dst); + auto s = _mm512_set1_ps(scale); + for (; i < ele_num / 16 * 16; i += 16) { + auto a0 = _mm256_loadu_epi16(a); + auto a0_f = _mm512_cvtpbh_ps((__m256bh)a0); + auto d_f = _mm512_mul_ps(a0_f, s); + auto d_i = _mm512_cvtps_epi32(d_f); + auto d_i8 = _mm512_cvtsepi32_epi8(d_i); + _mm_storeu_si128(reinterpret_cast<__m128i *>(d), d_i8); + a += 16; + d += 16; + } + if (i != ele_num) { + // https://stackoverflow.com/questions/40391708/convert-16-bit-mask-mmask16-to-m128i-control-byte-mask-on-knl-xeon-phi-72 + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - (ele_num % 16))); + auto a0 = _mm256_maskz_loadu_epi16(msk, a); + auto a0_f = _mm512_cvtpbh_ps((__m256bh)a0); + auto d_f = _mm512_mul_ps(a0_f, s); + auto d_i = _mm512_cvtps_epi32(d_f); + auto d_i8 = _mm512_cvtsepi32_epi8(d_i); + store_n(d_i8, ele_num % 16, d); + } +} + +// NOTE: did not handle tail because there should be enough room +inline void cvt_i32_f32_avx512(float* dst, int32_t* src, size_t ele_num) { + for (size_t i = 0; i < (ele_num + 15) / 16 * 16; i += 16) { + auto a0 = _mm512_load_epi32(src); + auto a_f = _mm512_cvtepi32_ps(a0); + _mm512_storeu_ps(dst, a_f); + src += 16; + dst += 16; + } +} + +enum mul_add2_select_flag { + mul_add2_select_flag_none, + mul_add2_select_flag_add1 = 1, + mul_add2_select_flag_add2 = 2, + mul_add2_select_flag_select = 4 +}; +template +inline void _mul_add2_select_f32_avx512(float* dst, float* src, float mul, float* add1, float* add2, uint8_t* select, int ele_num) { + auto mul_f = _mm512_set1_ps(mul); + int i; + auto tail = ele_num % 16; + __mmask16 msk = _cvtu32_mask16(0xFFFFu >> (16 - tail)); + auto zero_i32 = _mm512_setzero_si512(); + auto nfltmax = _mm512_set1_ps(-__FLT_MAX__); + for (i = 0; i < ele_num - tail; i += 16) { + auto a_f = _mm512_loadu_ps(src + i); + __m512 result; + if constexpr ((flag & (mul_add2_select_flag_add1 | mul_add2_select_flag_add2)) == mul_add2_select_flag_none) + result = _mm512_mul_ps(a_f, mul_f); + else if constexpr ((flag & (mul_add2_select_flag_add1 | mul_add2_select_flag_add2)) == mul_add2_select_flag_add2) + result = _mm512_fmadd_ps(a_f, mul_f, _mm512_loadu_ps(add2 + i)); + else { + result = _mm512_fmadd_ps(a_f, mul_f, _mm512_loadu_ps(add1 + i)); + if constexpr (flag & mul_add2_select_flag_add2) + result = _mm512_add_ps(result, _mm512_loadu_ps(add2 + i)); + } + if constexpr (flag & mul_add2_select_flag_select) { + auto r_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i*>(select + i)); + auto r_maski32 = _mm512_cvtepi8_epi32(r_maski8); + r_maski32 = _mm512_sub_epi32(zero_i32, r_maski32); + auto r_maskps = _mm512_movepi32_mask(r_maski32); // -FLT_MAX if mask == 0 + if constexpr (select_nfltmax_at_0) + result = _mm512_mask_blend_ps(r_maskps, nfltmax, result); + else + result = _mm512_mask_blend_ps(r_maskps, result, nfltmax); + } + + _mm512_storeu_ps(dst + i, result); + } + if (tail) { + auto a_f = _mm512_maskz_loadu_ps(msk, src + i); + __m512 result; + if constexpr ((flag & (mul_add2_select_flag_add1 | mul_add2_select_flag_add2)) == mul_add2_select_flag_none) + result = _mm512_mul_ps(a_f, mul_f); + else if constexpr ((flag & (mul_add2_select_flag_add1 | mul_add2_select_flag_add2)) == mul_add2_select_flag_add2) + result = _mm512_fmadd_ps(a_f, mul_f, _mm512_maskz_loadu_ps(msk, add2 + i)); + else { + result = _mm512_fmadd_ps(a_f, mul_f, _mm512_maskz_loadu_ps(msk, add1 + i)); + if constexpr (flag & mul_add2_select_flag_add2) + result = _mm512_add_ps(result, _mm512_maskz_loadu_ps(msk, add2 + i)); + } + if constexpr (flag & mul_add2_select_flag_select) { + auto r_maski8 = _mm512_castsi512_si128(_mm512_maskz_loadu_epi8(msk, select + i)); + auto r_maski32 = _mm512_cvtepi8_epi32(r_maski8); + r_maski32 = _mm512_sub_epi32(zero_i32, r_maski32); + auto r_maskps = _mm512_movepi32_mask(r_maski32); // -FLT_MAX if mask == 0 + if constexpr (select_nfltmax_at_0) + result = _mm512_mask_blend_ps(r_maskps, nfltmax, result); + else + result = _mm512_mask_blend_ps(r_maskps, result, nfltmax); + } + + _mm512_mask_storeu_ps(dst + i, msk, result); + } +} + +inline void mul_add2_select_f32_avx512(float* dst, float* src, float mul, float* add1, float* add2, uint8_t* select, bool select_nfltmax_at_0, int ele_num) { + if (add1) { + if (add2) { + if (select) { + if (select_nfltmax_at_0) + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_add2 | mul_add2_select_flag_select), true>(dst, src, mul, add1, add2, select, ele_num); + else + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_add2 | mul_add2_select_flag_select)>(dst, src, mul, add1, add2, select, ele_num);; + } else { + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_add2)>(dst, src, mul, add1, add2, select, ele_num); + } + } else { + if (select) { + if (select_nfltmax_at_0) + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_select), true>(dst, src, mul, add1, add2, select, ele_num); + else + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1 | mul_add2_select_flag_select)>(dst, src, mul, add1, add2, select, ele_num); + } else { + _mul_add2_select_f32_avx512(mul_add2_select_flag_add1)>(dst, src, mul, add1, add2, select, ele_num); + } + } + } else { + if (add2) { + if (select) { + if (select_nfltmax_at_0) + _mul_add2_select_f32_avx512(mul_add2_select_flag_add2 | mul_add2_select_flag_select), true>(dst, src, mul, add1, add2, select, ele_num); + else + _mul_add2_select_f32_avx512(mul_add2_select_flag_add2 | mul_add2_select_flag_select)>(dst, src, mul, add1, add2, select, ele_num);; + } else { + _mul_add2_select_f32_avx512(mul_add2_select_flag_add2)>(dst, src, mul, add1, add2, select, ele_num); + } + } else { + if (select) { + if (select_nfltmax_at_0) + _mul_add2_select_f32_avx512(mul_add2_select_flag_select), true>(dst, src, mul, add1, add2, select, ele_num); + else + _mul_add2_select_f32_avx512(mul_add2_select_flag_select)>(dst, src, mul, add1, add2, select, ele_num); + } else { + _mul_add2_select_f32_avx512(mul_add2_select_flag_none)>(dst, src, mul, add1, add2, select, ele_num); + } + } + } +} + +} // namespace llmdnn diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..e43a428 --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,34 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required(VERSION 3.13) +project(cpu_extensions_tests) + +enable_testing() + +if (NOT TARGET gtest_main) + include(FetchContent) + FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.11.0 + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE) + FetchContent_GetProperties(googletest) + if(NOT googletest_POPULATED) + FetchContent_Populate(googletest) + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() +endif() + +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +file(GLOB_RECURSE TEST_SOURCE_FILES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) + +find_package(OpenMP REQUIRED) + +add_executable(cpu_extensions_tests ${TEST_SOURCE_FILES}) +target_include_directories(cpu_extensions_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src) +target_link_libraries(cpu_extensions_tests cpu_extensions gtest_main stdc++ OpenMP::OpenMP_CXX) +target_compile_options(cpu_extensions_tests PRIVATE ${EXTRA_CXX_FLAGS}) diff --git a/tests/script/README.md b/tests/script/README.md new file mode 100644 index 0000000..6ef1941 --- /dev/null +++ b/tests/script/README.md @@ -0,0 +1,24 @@ +# Torch extension to help test + +## usage +prepare python enviroment +``` +python3 -m venv .env +source .env/bin/activate +sudo apt install libnuma-dev +pip3 install -r requirements.txt +``` + +compile extension +``` +./build.sh +``` +debug version extension(llmdnn needs to config debug version also) +``` +DEBUG_EXT=1 ./build.sh +``` + +run test +``` +pytest +``` \ No newline at end of file diff --git a/tests/script/build.sh b/tests/script/build.sh new file mode 100755 index 0000000..b6bc751 --- /dev/null +++ b/tests/script/build.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +pip uninstall -y llmdnn +cd ../../build/ || exit +make -j 20 +cd - || exit +cd ext || exit +python setup.py clean --all install diff --git a/tests/script/ext/CMakeLists.txt b/tests/script/ext/CMakeLists.txt new file mode 100644 index 0000000..41a0951 --- /dev/null +++ b/tests/script/ext/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +cmake_minimum_required (VERSION 3.13) + +project(cpu_extensions_ext LANGUAGES CXX) + +# ref:https://stackoverflow.com/questions/68401650/how-can-i-make-a-pytorch-extension-with-cmake +execute_process(COMMAND python3 -c "import torch;print(torch.utils.cmake_prefix_path+'/Torch/',end='')" OUTPUT_VARIABLE TORCH_CMAKE_PATH) +set(CMAKE_PREFIX_PATH ${TORCH_CMAKE_PATH}) + +find_package(Python REQUIRED COMPONENTS Development) +find_package(Torch REQUIRED) +find_package(OpenMP REQUIRED) + +add_library(cpu_extensions_ext SHARED + mha_gpt.cpp + module.cpp + ../../src/test_common.cpp +) +set_target_properties(cpu_extensions_ext PROPERTIES + OUTPUT_NAME "cpu_extensions_ext" + POSITION_INDEPENDENT_CODE ON + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../) +install(TARGETS cpu_extensions_ext DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) +target_compile_features(cpu_extensions_ext PRIVATE cxx_std_14) +target_include_directories(cpu_extensions_ext PRIVATE ../../src) +target_link_libraries(cpu_extensions_ext PRIVATE ${TORCH_LIBRARIES} Python::Python cpu_extensions stdc++ OpenMP::OpenMP_CXX) diff --git a/tests/script/ext/emb_gpt.cpp b/tests/script/ext/emb_gpt.cpp new file mode 100644 index 0000000..43acafc --- /dev/null +++ b/tests/script/ext/emb_gpt.cpp @@ -0,0 +1,136 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include "alloca.h" +#include "llm_tensor.hpp" +#include "module.hpp" +#include "common/utility.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_emb_gpt.hpp" +#include "test_common.hpp" + +using namespace torch::indexing; + +void regclass_emb_gpt(pybind11::module m) { + m.def("emb_gpt", [] ( + const torch::Tensor& qkv, + const torch::Tensor& k_past, + const torch::Tensor& v_past, + const torch::Tensor& cos, + const torch::Tensor& sin, + const torch::Tensor& position2d_ids) { + // qkv: [batch, seq_len, (num_heads * 3 * head_size)] + // k_past: [batch, head_num, past_seq_len, head_size] + // q_dst: [batch, head_num, query_seq_len, head_size] + // k_dst: [batch, head_num, query_seq_len+past_seq_len, head_size] + // cos: [max_seq_len, rotary_dims] + // position2d_ids: [batch, 2, query_seq_len] + AT_ASSERT(qkv.dim() == 3 && k_past.dim() == 4 && v_past.dim() == 4); + auto batch = qkv.size(0); + auto query_seq_len = qkv.size(1); + auto head_num = k_past.size(1); + auto head_size = k_past.size(3); + auto past_seq_len = k_past.size(2); + auto kv_seq_len = query_seq_len + past_seq_len; + + torch::Tensor q_dst = qkv.new_empty({batch, head_num, query_seq_len, head_size}); + torch::Tensor k_dst = qkv.new_empty({batch, head_num, kv_seq_len, head_size}); + torch::Tensor v_dst = qkv.new_empty({batch, head_num, kv_seq_len, head_size}); + llmdnn::tensor q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_; + q_.resize({batch, query_seq_len, head_num, head_size * 3}, reinterpret_cast(qkv.data_ptr()) + head_size * 0); + q_.m_dims[3] = head_size; + k_.resize({batch, query_seq_len, head_num, head_size * 3}, reinterpret_cast(qkv.data_ptr()) + head_size * 1); + k_.m_dims[3] = head_size; + v_.resize({batch, query_seq_len, head_num, head_size * 3}, reinterpret_cast(qkv.data_ptr()) + head_size * 2); + v_.m_dims[3] = head_size; + k_past_.resize({batch, head_num, past_seq_len, head_size}, reinterpret_cast(k_past.data_ptr())); + v_past_.resize({batch, head_num, past_seq_len, head_size}, reinterpret_cast(v_past.data_ptr())); + q_dst_.resize({batch, head_num, query_seq_len, head_size}, reinterpret_cast(q_dst.data_ptr())); + k_dst_.resize({batch, head_num, kv_seq_len, head_size}, reinterpret_cast(k_dst.data_ptr())); + v_dst_.resize({batch, head_num, kv_seq_len, head_size}, reinterpret_cast(v_dst.data_ptr())); + cos_.resize({cos.size(0), cos.size(1), cos.size(2), cos.size(3)}, cos.data_ptr()); + sin_.resize({sin.size(0), sin.size(1), sin.size(2), sin.size(3)}, sin.data_ptr()); + if (position2d_ids.numel()) + position2d_ids_.resize({batch, 2, query_seq_len}, position2d_ids.data_ptr()); + + llmdnn::emb_gpt(q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_); + + return std::make_tuple(q_dst, k_dst, v_dst); + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + }, + py::arg("qkv"), + py::arg("k_past"), + py::arg("v_past"), + py::arg("cos"), + py::arg("sin"), + py::arg("position2d_ids"), + R"( + exec emb + + :param num_heads: heads number. + :type num_heads: int + )"); + m.def("emb_gpt", [] ( + const torch::Tensor& qkv, + int num_kv_heads, + const torch::Tensor& k_past, + const torch::Tensor& v_past, + const torch::Tensor& cos, + const torch::Tensor& sin) { + // qkv: [batch, seq_len, (head_num + num_kv_heads * 2) * head_size] + // k_past: [batch, head_num, past_seq_len, head_size] + // q_dst: [batch, head_num, query_seq_len, head_size] + // k_dst: [batch, head_num, query_seq_len+past_seq_len, head_size] + // cos: [max_seq_len, rotary_dims] + AT_ASSERT(qkv.dim() == 3 && k_past.dim() == 4 && v_past.dim() == 4); + auto batch = qkv.size(0); + auto query_seq_len = qkv.size(1); + auto head_num = k_past.size(1); + auto head_size = k_past.size(3); + auto past_seq_len = k_past.size(2); + auto kv_seq_len = query_seq_len + past_seq_len; + AT_ASSERT(qkv.size(2) / head_size - 2 * num_kv_heads == head_num); + + torch::Tensor q_dst = qkv.new_empty({batch, head_num, query_seq_len, head_size}); + torch::Tensor k_dst = qkv.new_empty({batch, head_num, kv_seq_len, head_size}); + torch::Tensor v_dst = qkv.new_empty({batch, head_num, kv_seq_len, head_size}); + // q, k, v will be [batch, seq_len, num_kv_heads, head_num/num_kv_heads|1, head_size] + llmdnn::tensor q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_; + q_.resize({batch, query_seq_len, num_kv_heads, qkv.size(2) / head_size / num_kv_heads, head_size}, reinterpret_cast(qkv.data_ptr()) + head_size * 0); + q_.m_dims[3] = head_num / num_kv_heads; + k_.resize({batch, query_seq_len, num_kv_heads, qkv.size(2) / head_size / num_kv_heads, head_size}, reinterpret_cast(qkv.data_ptr()) + head_size * q_.m_dims[3]); + k_.m_dims[3] = 1; + v_.resize({batch, query_seq_len, num_kv_heads, qkv.size(2) / head_size / num_kv_heads, head_size}, reinterpret_cast(qkv.data_ptr()) + head_size * (q_.m_dims[3] + 1)); + v_.m_dims[3] = 1; + k_past_.resize({batch, head_num, past_seq_len, head_size}, reinterpret_cast(k_past.data_ptr())); + v_past_.resize({batch, head_num, past_seq_len, head_size}, reinterpret_cast(v_past.data_ptr())); + q_dst_.resize({batch, head_num, query_seq_len, head_size}, reinterpret_cast(q_dst.data_ptr())); + k_dst_.resize({batch, head_num, kv_seq_len, head_size}, reinterpret_cast(k_dst.data_ptr())); + v_dst_.resize({batch, head_num, kv_seq_len, head_size}, reinterpret_cast(v_dst.data_ptr())); + cos_.resize({cos.size(0), cos.size(1), cos.size(2), cos.size(3)}, cos.data_ptr()); + sin_.resize({sin.size(0), sin.size(1), sin.size(2), sin.size(3)}, sin.data_ptr()); + + llmdnn::emb_gpt(q_, k_, v_, k_past_, v_past_, q_dst_, k_dst_, v_dst_, cos_, sin_, position2d_ids_); + + return std::make_tuple(q_dst, k_dst, v_dst); + // auto options = torch::TensorOptions().dtype(torch::kBFloat16); + // auto query = torch::from_blob(param.query, {batch, num_heads, query_seq_len, head_size}, options); + }, + py::arg("qkv"), + py::arg("num_kv_heads"), + py::arg("k_past"), + py::arg("v_past"), + py::arg("cos"), + py::arg("sin"), + R"( + exec emb + + :param num_heads: heads number. + :type num_heads: int + )"); +} \ No newline at end of file diff --git a/tests/script/ext/mha_gpt.cpp b/tests/script/ext/mha_gpt.cpp new file mode 100644 index 0000000..813f891 --- /dev/null +++ b/tests/script/ext/mha_gpt.cpp @@ -0,0 +1,74 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include "alloca.h" +#include "common/bf16.hpp" +#include "llm_tensor.hpp" +#include "module.hpp" +#include "common/utility.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_mha_gpt.hpp" +#include "test_common.hpp" + +void regclass_mha_gpt(pybind11::module m) { + py::class_> cls(m, "mha_gpt"); + cls.def(py::init<>()); + cls.def("exec", [] (llmdnn::mha_gpt& self, const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& alibi, + const torch::Tensor& attn_mask, const torch::Tensor& causal_mask, bool select_nfltmax_at_0, float normal_factor, bool use_causal) { + // q: [batch, num_heads, query_seq_len, head_size] + // k: [batch, num_heads, key_seq_len, head_size] + // v: [batch, num_heads, key_seq_len, head_size] + // attn_mask: [batch, 1, 1/query_seq_len, key_seq_len] + // out: [batch, query_seq_len, num_heads * head_size] + // alibi: [batch, num_heads, 1, key_seq_len] + // causal_mask: [batch, 1, query_seq_len, key_seq_len] + AT_ASSERT(q.dim() == 4 && k.dim() == 4 && v.dim() == 4 && attn_mask.dim() == 4); + auto batch = q.size(0); + auto num_heads = q.size(1); + auto query_seq_len = q.size(2); + auto head_size = q.size(3); + AT_ASSERT(batch == k.size(0) && batch == v.size(0) && batch == attn_mask.size(0) && + num_heads == k.size(1) && num_heads == v.size(1) && + head_size == v.size(3)); + + llmdnn::tensor q_, k_, v_, out_, attn_mask_, alibi_, causal_mask_; + q_.resize({q.size(0), q.size(1), q.size(2), q.size(3)}, reinterpret_cast(q.data_ptr())); + k_.resize({k.size(0), k.size(1), k.size(2), k.size(3)}, reinterpret_cast(k.data_ptr())); + if (k.size(2) != v.size(2)) { + // bloom k shape: [batch, num_heads, head_size, key_seq_len] + std::swap(k_.m_dims[2], k_.m_dims[3]); + std::swap(k_.m_strides[2], k_.m_strides[3]); + } + v_.resize({v.size(0), v.size(1), v.size(2), v.size(3)}, reinterpret_cast(v.data_ptr())); + auto out = q.new_empty({batch, query_seq_len, num_heads * head_size}); + out_.resize({batch, query_seq_len, num_heads * head_size}, reinterpret_cast(out.data_ptr())); + if (attn_mask.numel()) + attn_mask_.resize({attn_mask.size(0), attn_mask.size(1), attn_mask.size(2), attn_mask.size(3)}, attn_mask.data_ptr()); + if (alibi.numel()) + alibi_.resize({alibi.size(0), alibi.size(1), alibi.size(2), alibi.size(3)}, alibi.data_ptr()); + if (causal_mask.numel()) + causal_mask_.resize({causal_mask.size(0), causal_mask.size(1), causal_mask.size(2), causal_mask.size(3)}, causal_mask.data_ptr()); + self.exec(q_, k_, v_, out_, attn_mask_, alibi_, causal_mask_, select_nfltmax_at_0, normal_factor, use_causal); + + return out; + }, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("alibi"), + py::arg("attn_mask"), + py::arg("causal_mask"), + py::arg("select_nfltmax_at_0"), + py::arg("normal_factor"), + py::arg("use_causal"), + R"( + exec mha + + :param num_heads: heads number. + :type num_heads: int + )"); +} \ No newline at end of file diff --git a/tests/script/ext/module.cpp b/tests/script/ext/module.cpp new file mode 100644 index 0000000..d016470 --- /dev/null +++ b/tests/script/ext/module.cpp @@ -0,0 +1,19 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include "module.hpp" +#include "utility_kernel_amx.hpp" +#include "llm_mha_gpt.hpp" +#include "test_common.hpp" + +PYBIND11_MODULE(llmdnn, m) { + static bool initAMX = initXTILE(); + if (!initAMX) { + std::cout << "init amx failed.\n"; + } + regclass_mha_gpt(m); + regclass_emb_gpt(m); +} \ No newline at end of file diff --git a/tests/script/ext/module.hpp b/tests/script/ext/module.hpp new file mode 100644 index 0000000..f7a704a --- /dev/null +++ b/tests/script/ext/module.hpp @@ -0,0 +1,10 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +void regclass_mha_gpt(pybind11::module m); +void regclass_emb_gpt(pybind11::module m); diff --git a/tests/script/ext/setup.py b/tests/script/ext/setup.py new file mode 100644 index 0000000..133ac56 --- /dev/null +++ b/tests/script/ext/setup.py @@ -0,0 +1,49 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from setuptools import setup, Extension +from torch.utils import cpp_extension +import sys +import os + +''' +using intel compiler: +source ~/intel/oneapi/setvars.sh +export CXX=icx +export CC=icx +''' +debug = False +if 'DEBUG_EXT' in os.environ: + debug = True if os.environ['DEBUG_EXT'] == '1' else False +extra_args = ['-fopenmp', '-Wno-narrowing', '-Wno-attributes', + '-march=native'] +cpu_extensions_lib_dir = f'{os.getcwd()}/../../../build/lib' +if debug: + cpu_extensions_lib_dir = f'{os.getcwd()}/../../../build/lib' + extra_args += ['-g', '-O0'] + print('install debug version') +else: + print('install release version') + +setup(name='llmdnn', + ext_modules=[ + cpp_extension.CppExtension( + 'llmdnn', + ['module.cpp', f'../../src/test_common.cpp', + 'mha_gpt.cpp', + 'emb_gpt.cpp', + #'attn_gpt.cpp', + ], + extra_compile_args=extra_args, + include_dirs=[f'{os.getcwd()}/../../src', + f'{os.getcwd()}/../../../include', + f'{os.getcwd()}/../../../src'], + library_dirs=[f'{sys.prefix}/lib', + cpu_extensions_lib_dir], + libraries=['cpu_extensions', + 'numa', + 'stdc++']), + ], + cmdclass={'build_ext': cpp_extension.BuildExtension} + ) \ No newline at end of file diff --git a/tests/script/pytest.ini b/tests/script/pytest.ini new file mode 100644 index 0000000..e68d012 --- /dev/null +++ b/tests/script/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --ignore=models \ No newline at end of file diff --git a/tests/script/requirements.txt b/tests/script/requirements.txt new file mode 100644 index 0000000..8ec05ac --- /dev/null +++ b/tests/script/requirements.txt @@ -0,0 +1,4 @@ +-f https://download.pytorch.org/whl/torch_stable.html +numpy==1.24.2 +torch==2.0.1+cpu +pytest diff --git a/tests/script/test_mha_bloom.py b/tests/script/test_mha_bloom.py new file mode 100644 index 0000000..3c5dfee --- /dev/null +++ b/tests/script/test_mha_bloom.py @@ -0,0 +1,227 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F + +# copy from transformers/models/bloom/modeling_bloom.py +class BloomAttention(nn.Module): + def __init__(self, head_dim:int, num_heads:int): + super().__init__() + + # self.pretraining_tp = config.pretraining_tp + # self.slow_but_exact = config.slow_but_exact + + # self.hidden_size = config.hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + # self.split_size = self.hidden_size + # self.hidden_dropout = config.hidden_dropout + + # if self.head_dim * self.num_heads != self.hidden_size: + # raise ValueError( + # f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + # f" {self.num_heads})." + # ) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(head_dim) + self.beta = 1.0 + + # self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) + # self.dense = nn.Linear(self.hidden_size, self.hidden_size) + # self.attention_dropout = nn.Dropout(config.attention_dropout) + + # def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # """ + # Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + # storage as `fused_qkv` + + # Args: + # fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + # Returns: + # query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + # value: [batch_size, seq_length, num_heads, head_dim] + # """ + # batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + # fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + # return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + + def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + """ + Merge heads together over the last dimenstion + + Args: + x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim] + + Returns: + torch.tensor: [batch_size, seq_length, num_heads * head_dim] + """ + # What we want to achieve is: + # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim + batch_size_and_num_heads, seq_length, _ = x.shape + batch_size = batch_size_and_num_heads // self.num_heads + + # First view to decompose the batch size + # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim + x = x.view(batch_size, self.num_heads, seq_length, self.head_dim) + + # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim + x = x.permute(0, 2, 1, 3) + + # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim + return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim) + + def forward( + self, + query_layer: torch.Tensor, # [batch * head_num, q_len, head_size] + key_layer: torch.Tensor, # [batch * head_num, head_size, q_len+kv_len] + value_layer: torch.Tensor, # [batch * head_num, q_len+kv_len, head_size] + alibi: torch.Tensor, # [batch * head_num, 1, q_len+kv_len] + attention_mask: torch.Tensor, # [batch * head_num, q_len, q_len+kv_len] + ): + # fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # # 3 x [batch_size, seq_length, num_heads, head_dim] + # (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, _, q_length, _ = query_layer.shape + + query_layer = query_layer.reshape(-1, q_length, self.head_dim) + key_layer = key_layer.reshape(-1, key_layer.size(2), key_layer.size(3)) + value_layer = value_layer.reshape(-1, value_layer.size(2), value_layer.size(3)) + # if layer_past is not None: + # past_key, past_value = layer_past + # # concatenate along seq_length dimension: + # # - key: [batch_size * self.num_heads, head_dim, kv_length] + # # - value: [batch_size * self.num_heads, kv_length, head_dim] + # key_layer = torch.cat((past_key, key_layer), dim=2) + # value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + + # if use_cache is True: + # present = (key_layer, value_layer) + # else: + # present = None + + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + matmul_result = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + #attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) + attn_weights = attention_scores + attention_mask + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs_reshaped, value_layer) + + # change view [batch_size, num_heads, q_length, head_dim] + context_layer = self._merge_heads(context_layer) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + # if self.pretraining_tp > 1 and self.slow_but_exact: + # slices = self.hidden_size / self.pretraining_tp + # output_tensor = torch.zeros_like(context_layer) + # for i in range(self.pretraining_tp): + # output_tensor = output_tensor + F.linear( + # context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + # self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + # ) + # else: + # output_tensor = self.dense(context_layer) + + # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + # outputs = (output_tensor, present) + # if output_attentions: + # outputs += (attention_probs,) + + return context_layer + +class BloomAttentionExt: + def __init__(self): + self.mha = ld.mha_gpt() + + def forward(self, query, key, value, alibi, attention_mask, normal_factor): + return self.mha.exec(query, key, value, alibi, attention_mask, torch.tensor([]), False, normal_factor, False) + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +def get_ref_model(): + ref_net = BloomAttention(SIZE_PER_HEAD, HEAD_NUM) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_bloom(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, head_size, key_seq_len] + # v: [batch, num_heads, value_seq_len, head_size] + # alibi: [batch, num_heads, 1, key_seq_len] + # attn: [2, 1, query_seq_len, key_seq_len] + (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, SIZE_PER_HEAD, 32]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 1, 32]).astype(np.float32), + np.zeros([2, 1, 2, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, SIZE_PER_HEAD, 200]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 1, 200]).astype(np.float32), + np.zeros([2, 1, 200, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, SIZE_PER_HEAD, 200]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 1, 200]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = BloomAttentionExt() + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, alibi, attn_mask = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + alibi = torch.from_numpy(alibi) # to(torch.bfloat16) + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + output = net.forward(q, k, v, alibi, attn_mask, normal_factor = 1.0 / math.sqrt(SIZE_PER_HEAD)) + alibi = alibi.view(-1, alibi.size(2), alibi.size(3)) + ref_output = ref_net.forward(q, k, v, alibi, attn_mask) + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_bloom() diff --git a/tests/script/test_mha_gpt.py b/tests/script/test_mha_gpt.py new file mode 100644 index 0000000..fc5a3aa --- /dev/null +++ b/tests/script/test_mha_gpt.py @@ -0,0 +1,222 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn + +# copy from transformers/models/gpt_neox/modeling_gpt_neox.py +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + ) + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + + def forward(self, query, key, value, attention_mask, q_quant=None, k_quant=None, qk_quant=None, v_quant=None, requant=None): + if q_quant: + # quant + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant) + if q_quant: + attn_output = (attn_output * requant).round().clamp(-128, 127).to(torch.int8) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + + return attn_output + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask, q_quant, k_quant, qk_quant, v_quant): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + if q_quant: + norm_factor = torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor * (1 / q_quant) * (1 / k_quant) + else: + norm_factor = torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=(norm_factor), + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + mask_value = torch.finfo(attn_scores.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) + attn_scores = torch.where(causal_mask, attn_scores, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + if q_quant: + attn_weights = (attn_weights * qk_quant).round().clamp(0, 255).to(torch.uint8).to(torch.float32) + attn_output = torch.matmul(attn_weights, value) + if q_quant: + attn_output = attn_output * ((1 / qk_quant) * (1 / v_quant)) + return attn_output, attn_weights + +class GPTNeoXAttentionExt: + def __init__(self): + self.mha = ld.mha_gpt() + + def forward(self, query, key, value, attention_mask, normal_factor, causal_mask = torch.tensor([]), select_nfltmax_at_0 = False): + return self.mha.exec(query, key, value, torch.tensor([]), attention_mask, causal_mask, select_nfltmax_at_0, normal_factor, False if causal_mask.numel() > 0 else True) + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +def get_ref_model(): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_gpt_neox(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [2, 1, 1, key_seq_len] + (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt() + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + ref_output = ref_net.forward(q, k, v, attn_mask) + output = net.forward(q, k, v, attn_mask, normal_factor = 1.0 / math.sqrt(SIZE_PER_HEAD)) + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +def test_gpt_neox_with_causal(): + inputs = [ + # q, k, v, attn_mask + # q: [batch, num_heads, query_seq_len, head_size] + # k: [batch, num_heads, key_seq_len, head_size] + # v: [batch, num_heads, value_seq_len, head_size] + # attn: [2, 1, 1, key_seq_len] + # causal: [2, 1, query_seq_len, key_seq_len] + # (np.random.random(size=[2, HEAD_NUM, 2, SIZE_PER_HEAD]).astype(np.float32), + # np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + # np.random.random(size=[2, HEAD_NUM, 32, SIZE_PER_HEAD]).astype(np.float32), + # np.zeros([2, 1, 1, 32], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + (np.random.random(size=[2, HEAD_NUM, 1, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.zeros([2, 1, 1, 200], dtype=np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt() + causal_mask = torch.tril(torch.ones((MAX_POSITION_EMBEDDINGS, MAX_POSITION_EMBEDDINGS), dtype=torch.uint8)).view( + 1, 1, MAX_POSITION_EMBEDDINGS, MAX_POSITION_EMBEDDINGS + ) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + q, k, v, attn_mask = input + q = torch.from_numpy(q).to(torch.bfloat16) + k = torch.from_numpy(k).to(torch.bfloat16) + v = torch.from_numpy(v).to(torch.bfloat16) + + batch_size, num_attention_heads, query_length, attn_head_size = q.size() + key_length = k.size(-2) + causal_mask_sub = causal_mask[:, :, key_length - query_length : key_length, :key_length].contiguous() + # 0 means -fltmax + select_nfltmax_at_0 = True + if i == 0: + # 1 means -fltmax + causal_mask_sub = 1 - causal_mask_sub + select_nfltmax_at_0 = False + attn_mask = torch.from_numpy(attn_mask) + attn_mask[:,:,:,-2:] = torch.finfo(torch.float32).min + ref_output = ref_net.forward(q, k, v, attn_mask) + output = net.forward(q, k, v, attn_mask, normal_factor = 1.0 / math.sqrt(SIZE_PER_HEAD), causal_mask = causal_mask_sub, select_nfltmax_at_0 = select_nfltmax_at_0) + if not torch.allclose(ref_output, output, rtol=0.001, atol=0.01): + print(f"error at index {i} ref:\n{ref_output} \ncur:\n {output} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_gpt_neox_with_causal() diff --git a/tests/script/test_rotary_pastkv.py b/tests/script/test_rotary_pastkv.py new file mode 100644 index 0000000..edb1866 --- /dev/null +++ b/tests/script/test_rotary_pastkv.py @@ -0,0 +1,204 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn + +# copy from transformers/models/gpt_neox/modeling_gpt_neox.py +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos = cos[..., offset : q.shape[-2] + offset, :] + sin = sin[..., offset : q.shape[-2] + offset, :] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.rotary_ndims = int(self.head_size * config.rotary_pct) + self.rotary_emb = RotaryEmbedding( + self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base + ) + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + # return: (key, value), key/value: [batch, num_attention_heads, seq_len+layer_past[0].shape[-2], head_size] + def forward(self, qkv, layer_past): + has_layer_past = layer_past is not None + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + offset = 0 + if has_layer_past: + offset = layer_past[0].shape[-2] + seq_len += offset + cos, sin = self.rotary_emb(value, seq_len=seq_len) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset) + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) + + return present, query, key, value + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + rotary_ndims = int(head_size * rotary_pct) + + inv_freq = 1.0 / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # return: + # 0: (k, v): ([batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned], [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned]) + # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] + # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, k_past, v_past): + return ld.emb_gpt(qkv, k_past, v_past, self.cos_cached, self.sin_cached, torch.tensor([])) + + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 80 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.25 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + class FakeConfig: + def __init__(self): + self.num_attention_heads = HEAD_NUM + self.hidden_size = HIDDEN_SIZE + self.rotary_pct = ROTARY_PCT + self.max_position_embeddings = MAX_POSITION_EMBEDDINGS + self.rotary_emb_base = ROTARY_EMB_BASE + config = FakeConfig() + ref_net = GPTNeoXAttention(config) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_gpt_neox(): + inputs = [ + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32)), + (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + + _, query_ref, key_ref, value_ref = ref_net.forward(qkv, (layer_past_key, layer_past_value)) + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + + # no prealloc past kv + query, key, value = net.forward(qkv, layer_past_key, layer_past_value) + # check query + if not torch.allclose(query_ref, query, rtol=0.001, atol=0.01): + print(f"error at sequence query index {i} ref:\n{query_ref} \ncur:\n {query} ") + assert(False) + # check key + if not torch.allclose(key_ref, key, rtol=0.001, atol=0.01): + print(f"error at sequence key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value, rtol=0.001, atol=0.01): + print(f"error at sequence value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_gpt_neox() \ No newline at end of file diff --git a/tests/script/test_rotary_pastkv_chatglm.py b/tests/script/test_rotary_pastkv_chatglm.py new file mode 100644 index 0000000..92f25f7 --- /dev/null +++ b/tests/script/test_rotary_pastkv_chatglm.py @@ -0,0 +1,353 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +# copy from chatglm-6b/modeling_chatglm.py +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings, base=10000, precision=torch.half, learnable=False): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + #inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] + + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + if self.precision == torch.bfloat16: + cos_cached = cos_cached.bfloat16() + sin_cached = sin_cached.bfloat16() + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + def _apply(self, fn): + if self.cos_cached is not None: + self.cos_cached = fn(self.cos_cached) + if self.sin_cached is not None: + self.sin_cached = fn(self.sin_cached) + return super()._apply(fn) + + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [b, sq], q, k: [b, sq, np, hn], cos: [sq, 1, hn] -> [b, sq, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + +class SelfAttention(torch.nn.Module): + def __init__(self, max_position_embeddings, hidden_size, num_attention_heads, + layer_id, hidden_size_per_attention_head=None, bias=True, + params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + super(SelfAttention, self).__init__() + + self.layer_id = layer_id + self.hidden_size = hidden_size + self.hidden_size_per_partition = hidden_size + self.num_attention_heads = num_attention_heads + self.num_attention_heads_per_partition = num_attention_heads + self.position_encoding_2d = position_encoding_2d + self.rotary_emb = RotaryEmbedding( + self.hidden_size // (self.num_attention_heads * 2) + if position_encoding_2d + else self.hidden_size // self.num_attention_heads, + max_position_embeddings, + base=10000, + precision=torch.half, + learnable=False, + ) + + self.scale_mask_softmax = None + + if hidden_size_per_attention_head is None: + self.hidden_size_per_attention_head = hidden_size // num_attention_heads + else: + self.hidden_size_per_attention_head = hidden_size_per_attention_head + + self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head + + @staticmethod + def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + def split_tensor_along_last_dim(self, tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + layer_past=None + ): + # batch, seqlen, num_attention_heads, hidden_size_per_attention_head + b, seq_len, nh, hidden_size = key_layer.shape + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) + + # # [sq, b, np, hn] -> [sq, b * np, hn] + # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # # [sk, b, np, hn] -> [sk, b * np, hn] + # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + # [b, sq, np, hn] -> [b, np, sq, hn] + query_layer = query_layer.transpose(1, 2) + # [b, sk, np, hn] -> [b, np, sk, hn] + key_layer = key_layer.transpose(1, 2) + # [b, np, sq, hn] -> [b * np, sq, hn] + #query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) + # [b, np, sk, hn] -> [b * np, sk, hn] + #key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + # value_layer -> context layer. + # [b, sk, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(1), value_layer.size(3)) + + # # change view [sk, b * np, hn] + # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # [b, sk, np, hn] -> [b, np, sk, hn] + value_layer = value_layer.transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=2) + + # [b, np, sk, hn] -> [b * np, sk, hn] + #value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), output_size[3]) + return query_layer, key_layer, value_layer + + def forward( + self, + qkv: torch.Tensor, # [batch, seq_len, 3 * hidden_size] + position_ids, # [batch, 2, query_seq_len] + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ): + """ + hidden_states: [batch, seq_len, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [batch, seq_len, 3 * hidden_size] + mixed_raw_layer = qkv # self.query_key_value(hidden_states) + + # [batch, seq_len, 3 * hidden_size] --> [batch, seq_len, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + # xxx + position_ids, block_position_ids = position_ids[:, 0, :].contiguous(), \ + position_ids[:, 1, :].contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [batch, seq_len, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) + + query_layer = query_layer.to(dtype=torch.bfloat16) + key_layer = key_layer.to(dtype=torch.bfloat16) + + # [batch, seq_len, hidden_size] + query_layer, key_layer, value_layer = self.attention_fn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + layer_past=layer_past, + ) + + return query_layer, key_layer, value_layer + + +class GPTNeoXAttentionExt: + def __init__(self, num_attention_heads, hidden_size, head_size_aligned, max_position_embeddings, rotary_emb_base, rotary_pct, is_int8=False): + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + rotary_ndims = int(head_size * rotary_pct) + + inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + #inv_freq = inv_freq.half() + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # return: + # 0: (k, v): ([batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned], [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned]) + # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] + # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, k_past, v_past, position_ids): + return ld.emb_gpt(qkv, k_past, v_past, self.cos_cached, self.sin_cached, position_ids) + + +HEAD_NUM = 32 +SIZE_PER_HEAD = 80 +SIZE_PER_HEAD_ALIGN = 80 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.5 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + ref_net = SelfAttention(MAX_POSITION_EMBEDDINGS, HIDDEN_SIZE, HEAD_NUM, 0, None) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_chatglm(): + inputs = [ + # qkv: [batch, seq_len, (num_heads * 3 * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + (np.random.random(size=[2, 200, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32)), + (np.random.random(size=[2, 1, 3 * HEAD_NUM * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32)), + ] + ref_net = get_ref_model() + net = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD_ALIGN, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + net_seq = GPTNeoXAttentionExt(HEAD_NUM, HIDDEN_SIZE, SIZE_PER_HEAD, MAX_POSITION_EMBEDDINGS, ROTARY_EMB_BASE, ROTARY_PCT) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + + if qkv.size(1) > 1: + seq_batch1 = torch.arange(end=qkv.size(1) - 1, dtype=torch.int32) + seq_batch1 = torch.concat((seq_batch1, seq_batch1[-1:])) + block_batch1 = torch.concat((torch.zeros(qkv.size(1) -1, dtype=torch.int32), torch.ones(1, dtype=torch.int32))) + else: + seq_batch1 = torch.tensor([3], dtype=torch.int32) + block_batch1 = torch.tensor([5], dtype=torch.int32) + seq_ids = torch.empty((qkv.size(0), 2, qkv.size(1)), dtype=torch.int32) + seq_ids[:, 0, :] = seq_batch1 + seq_ids[:, 1, :] = block_batch1 + query_ref, key_ref, value_ref = ref_net.forward(qkv, seq_ids, (layer_past_key, layer_past_value)) + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + + # no prealloc past kv + query, key, value = net_seq.forward(qkv, layer_past_key, layer_past_value, seq_ids) + # check query + if not torch.allclose(query_ref, query, rtol=0.001, atol=0.01): + print(f"error at sequence query index {i} ref:\n{query_ref} \ncur:\n {query} ") + assert(False) + # check key + if not torch.allclose(key_ref, key, rtol=0.001, atol=0.01): + print(f"error at sequence key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value, rtol=0.001, atol=0.01): + print(f"error at sequence value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_chatglm() \ No newline at end of file diff --git a/tests/script/test_rotary_pastkv_falcon.py b/tests/script/test_rotary_pastkv_falcon.py new file mode 100644 index 0000000..b9f951e --- /dev/null +++ b/tests/script/test_rotary_pastkv_falcon.py @@ -0,0 +1,262 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import math +import sys +import torch +import numpy as np +import llmdnn as ld +from torch import nn +import torch.nn.functional as F +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any + +# copy from transformers/models/falcon/modeling_falcon.py +# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...) +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class FalconRotaryEmbedding(nn.Module): + """Implementation of RotaryEmbedding from GPT-NeoX. + This implementation is designed to operate on queries and keys that are compatible with `[batch_size, + n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format). + """ + + def __init__(self, head_dim: int, base=10000): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.head_dim = head_dim + self.seq_len_cached = -1 + self.cos_cached: torch.Tensor | None = None + self.sin_cached: torch.Tensor | None = None + + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: + total_length = seq_len + past_key_values_length + if total_length > self.seq_len_cached: + self.seq_len_cached = total_length + t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, :] + self.sin_cached = emb.sin()[None, :, :] + + # self.cos_cached = self.cos_cached.type(dtype) + # self.sin_cached = self.sin_cached.type(dtype) + + return ( + self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length], + self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length], + ) + + def forward(self, query, key, past_key_values_length=0): + batch, seq_len, head_dim = query.shape + cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) + return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) + +class FalconAttention(torch.nn.Module): + def __init__(self, hidden_size, num_attention_heads, num_kv_heads, new_decoder_architecture=True): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." + ) + + self.maybe_rotary = FalconRotaryEmbedding(self.head_dim) #if config.rotary else lambda q, k, t: (q, k) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = self.inv_norm_factor + if new_decoder_architecture: + qkv_out_dim = (num_kv_heads * 2 + num_attention_heads) * self.head_dim + elif config.multi_query: + qkv_out_dim = self.hidden_size + 2 * self.head_dim + else: + qkv_out_dim = 3 * self.hidden_size + self.new_decoder_architecture = new_decoder_architecture + self.num_kv_heads = num_kv_heads #if (self.new_decoder_architecture or not self.multi_query) else 1 + + def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + if self.new_decoder_architecture: + batch, seq_len, _ = fused_qkv.shape + qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim) + query = qkv[:, :, :, :-2] + key = qkv[:, :, :, [-2]] + value = qkv[:, :, :, [-1]] + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) + + query, key, value = [x.flatten(2, 3) for x in (query, key, value)] + return query, key, value + elif not self.multi_query: + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + else: + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) + return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] + + def forward( + self, + fused_qkv: torch.Tensor, # [batch_size, seq_length, 9216] + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape( + batch_size * num_kv_heads, + query_length, + self.head_dim, + ) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + + layer_past = [item.view(item.size(0) * item.size(1), item.size(2), item.size(3)) for item in layer_past] + past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, kv_length, head_dim] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, kv_length, _ = key_layer.shape + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) + key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + + return query_layer_, key_layer_, value_layer_ + + +class FalconAttentionExt: + def __init__(self, num_attention_heads, hidden_size, max_position_embeddings, rotary_ndims, rotary_emb_base=10000): + num_heads = num_attention_heads + head_size = hidden_size // num_attention_heads + max_seq_len = max_position_embeddings + + inv_freq = 1. / (rotary_emb_base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims)) + #inv_freq = inv_freq.half() + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + # use f32 to pass accuracy test + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + # qkv: [batch, seq_len, ((num_heads + num_kv_heads * 2) * head_size)] + # layer_past_padded: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # past_seq_len: past_seq_len==layer_past.shape[-2] + # return: + # 0: (k, v): ([batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned], [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned]) + # 1: query: [batch, num_attention_heads, seq_len, head_size_aligned] + # 2: k: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + # 3: v: [batch, num_attention_heads, MAX_SEQ_LEN, head_size_aligned] + def forward(self, qkv, num_kv_heads, k_past, v_past): + return ld.emb_gpt(qkv, num_kv_heads, k_past, v_past, self.cos_cached, self.sin_cached) + + +HEAD_NUM = 128 +SIZE_PER_HEAD = 64 +SIZE_PER_HEAD_ALIGN = 64 +NUM_KV_HEADS = 8 +HIDDEN_SIZE = HEAD_NUM * SIZE_PER_HEAD +MAX_POSITION_EMBEDDINGS = 1024 #2048 +ROTARY_EMB_BASE = 10000 +ROTARY_PCT = 0.5 +MAX_SEQ_LEN = 1024 +def get_ref_model(): + ref_net = FalconAttention(hidden_size=HIDDEN_SIZE, num_attention_heads=HEAD_NUM, num_kv_heads=NUM_KV_HEADS, new_decoder_architecture=True) + ref_net.maybe_rotary.cos_sin(0, MAX_SEQ_LEN) + ref_net = ref_net.to(dtype=torch.bfloat16) + return ref_net + +def test_falcon(): + inputs = [ + # qkv: [batch, seq_len, (num_heads + 2 * num_kv_heads) * head_size)] + # layer_past: [batch, num_attention_heads, past_seq_len, head_size] + (np.random.random(size=[2, 200, (HEAD_NUM + 2 * NUM_KV_HEADS) * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 0, SIZE_PER_HEAD]).astype(np.float32)), + (np.random.random(size=[2, 1, (HEAD_NUM + 2 * NUM_KV_HEADS) * SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32), + np.random.random(size=[2, HEAD_NUM, 200, SIZE_PER_HEAD]).astype(np.float32)), + ] + ref_net = get_ref_model() + net_seq = FalconAttentionExt(HEAD_NUM, HIDDEN_SIZE, MAX_POSITION_EMBEDDINGS, SIZE_PER_HEAD, ROTARY_EMB_BASE) + with torch.cpu.amp.autocast(): + for (i, input) in enumerate(inputs): + qkv, layer_past_key, layer_past_value = input + qkv = torch.from_numpy(qkv).to(torch.bfloat16) + layer_past_key = torch.from_numpy(layer_past_key).to(torch.bfloat16) + layer_past_value = torch.from_numpy(layer_past_value).to(torch.bfloat16) + + query_ref, key_ref, value_ref = ref_net.forward(qkv, (layer_past_key, layer_past_value)) + query_ref = query_ref.to(dtype=torch.bfloat16) + key_ref = key_ref.to(dtype=torch.bfloat16) + + # no prealloc past kv + query, key, value = net_seq.forward(qkv, NUM_KV_HEADS, layer_past_key, layer_past_value) + # check query + if not torch.allclose(query_ref, query, rtol=0.001, atol=0.01): + print(f"error at sequence query index {i} ref:\n{query_ref} \ncur:\n {query} ") + assert(False) + # check key + if not torch.allclose(key_ref, key, rtol=0.001, atol=0.01): + print(f"error at sequence key index {i} ref:\n{key_ref} \ncur:\n {key} ") + assert(False) + # check value + if not torch.allclose(value_ref, value, rtol=0.001, atol=0.01): + print(f"error at sequence value index {i} ref:\n{value_ref} \ncur:\n {value} ") + assert(False) + + print('done.') + return + +if __name__ == "__main__": + test_falcon() \ No newline at end of file diff --git a/tests/src/test_common.cpp b/tests/src/test_common.cpp new file mode 100644 index 0000000..8f14d26 --- /dev/null +++ b/tests/src/test_common.cpp @@ -0,0 +1,78 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/simple_parallel.hpp" +#include "test_common.hpp" + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE /* See feature_test_macros(7) */ +#endif +#include +#include /* For SYS_xxx definitions */ + +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 + +using namespace std; +using namespace llmdnn; + +std::string dtype_to_str(data_type_t type) { + switch (type) { + case llmdnn_data_type_undef: return "undef"; + case llmdnn_f16: return "f16"; + case llmdnn_bf16: return "bf16"; + case llmdnn_f32: return "f32"; + case llmdnn_s32: return "s32"; + case llmdnn_s8: return "s8"; + case llmdnn_u8: return "u8"; + case llmdnn_f64: return "f64"; + default: return "unkown"; + } +} + +bool initXTILE() { + unsigned long bitmask = 0; + long status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (0 != status) return false; + if (bitmask & XFEATURE_MASK_XTILEDATA) return true; + + status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + if (0 != status) + return false; // XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed + status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + + // XFEATURE_XTILEDATA setup is failed, can't use TMUL + if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) return false; + + // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed + return true; +} + +namespace llmdnn { + +size_t get_total_threads() { + return omp_get_max_threads(); +} + +void simple_parallel_for(const size_t total, const std::function& fn) { + #pragma omp parallel for + for(size_t i = 0; i < total; i++) { + fn(i); + } +} + +} \ No newline at end of file diff --git a/tests/src/test_common.hpp b/tests/src/test_common.hpp new file mode 100644 index 0000000..22dbd83 --- /dev/null +++ b/tests/src/test_common.hpp @@ -0,0 +1,218 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "llm_types.hpp" +#include "llm_fc.hpp" +#include "common/tensor2d.hpp" +#include "common/bf16.hpp" +#ifdef _WIN32 +#include +#else +#include +#endif +#include + +std::string dtype_to_str(llmdnn::data_type_t type); + +using func_act = std::function; + +bool initXTILE(); + +template +void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + int k; + for (k = 0; (k + 32) <= K; k += 32) { + float psum0 = 0; + float psum1 = 0; + for(int p = 0; p < 32; p+=2) { + psum0 += static_cast(A(m,k+p)) * static_cast(B(k+p,n)); + psum1 += static_cast(A(m,k+p+1)) * static_cast(B(k+p+1,n)); + } + sum += (psum0 + psum1); + } + for(; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + //std::cout << m << "," << n << std::endl; + C(m,n) = sum; + } + } +} + +inline void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + for(int k = 0; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + C(m,n) = sum; + } + } +} + +template +void matmul(tensor2D & A, + tensor2D & B, + tensor2D & C, + float * dq = nullptr, + float * bias = nullptr, + func_act act = func_act(), + float * q = nullptr) { + int M = C.dims[0]; + int N = C.dims[1]; + int K = A.dims[1]; + assert(B.dims[0] == K); + assert(B.dims[1] == N); + for(int m = 0; m < M; m++) { + for(int n = 0; n < N; n++) { + float sum = C(m,n); + for(int k = 0; k < K; k++) { + sum += static_cast(A(m,k)) * static_cast(B(k,n)); + } + if (dq) { + sum *= dq[n]; + } + if (bias) { + sum += bias[n]; + } + if (act) { + sum = act(sum); + } + if (q) { + sum *= q[n]; + sum = std::min(static_cast(std::numeric_limits::max()), sum); + sum = std::max(static_cast(std::numeric_limits::min()), sum); + } + C(m,n) = sum; + } + } +} + +struct ANSIcolor { + const char * code; + ANSIcolor(const char * code = "0") : code(code){ + } + friend std::ostream& operator<<(std::ostream& out, const ANSIcolor& obj) { + out << "\033[" << obj.code << "m"; + return out; + } +}; + +struct pretty_size { + double sz; + std::string txt; + pretty_size(double sz, const char * unit = "") : sz(sz) { + std::stringstream ss; + ss << std::setprecision(5) << std::setw(5); + if (sz < 1024) + ss << sz; + else if (sz < 1024 * 1024) + ss << (sz / 1024) << " K"; + else if (sz < 1024 * 1024 * 1024) + ss << (sz / 1024/1024) << " M"; + else + ss << (sz / 1024 / 1024/1024) << " G"; + ss << unit; + txt = ss.str(); + } + friend std::ostream& operator<<(std::ostream& os, const pretty_size& ps) { + os << ps.txt; + return os; + } +}; + +inline int readenv(const char * name) { + int v = 0; + auto * p = std::getenv(name); + if (p) + v = std::atoi(p); + std::cout << ANSIcolor("32") << "ENV: " << name << " = " << v << std::endl << ANSIcolor(); + return v; +} + +template +struct TypeName { + static const char* get(){return typeid(T).name();} +}; + +// a specialization for each type of those you want to support +// and don't like the string returned by typeid +template <> +struct TypeName{ + static const char* get(){return "int32_t";} +}; +template <> +struct TypeName{ + static const char* get(){return "foat";} +}; +template <> +struct TypeName{ + static const char* get(){return "bfloat16";} +}; +template <> +struct TypeName{ + static const char* get(){return "int8_t";} +}; + +inline std::ostream & operator<<(std::ostream & os, const llmdnn::postops_types & steps) { + if (steps == llmdnn::NONE) + os << "NONE"; + if (steps & llmdnn::DEQUANT) + os << "_DEQUANT"; + if (steps & llmdnn::BIAS) + os << "_BIAS"; + if (steps & llmdnn::GELU) + os << "_GELU"; + if (steps & llmdnn::GELU_TANH) + os << "_GELU_TANH"; + if (steps & llmdnn::QUANT) + os << "_QUANT"; + return os; +} diff --git a/tests/src/test_fc_amx.cpp b/tests/src/test_fc_amx.cpp new file mode 100644 index 0000000..5a4b1d3 --- /dev/null +++ b/tests/src/test_fc_amx.cpp @@ -0,0 +1,345 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_fc.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "llm_tensor.hpp" +#include "llm_types.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using FCTestShape = std::tuple; +using FCTestDTPost = std::tuple; +using FCTestParamSet = std::tuple< + FCTestDTPost, // a, b, c data type, postops + bool, // b needs transpose + FCTestShape // M, N, K + >; + +class FCTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + FCTestDTPost types; + bool is_transpose; + postops_types postops_type; + data_type_t dt_a, dt_b, dt_c, dt_weight; + FCTestShape shape; + int M, N, K; + std::tie(types, is_transpose, shape) = obj.param; + std::tie(M, N, K) = shape; + std::tie(dt_a, dt_b, dt_c, dt_weight, postops_type) = types; + + std::ostringstream result; + result << "A_" << dtype_to_str(dt_a) << "_B_" << dtype_to_str(dt_b) + << "_C_" << dtype_to_str(dt_c) << "_WEIGHT_" << dtype_to_str(dt_weight) + << (is_transpose ? "_transpose" : "") + << "_postops_" << postops_type << "_M_" << M << "_N_" << N << "_K_" << K; + return result.str(); + } + +protected: + virtual void SetUp() override { + initXTILE(); + + FCTestShape shape; + FCTestDTPost types; + std::tie(types, _is_transpose, shape) = GetParam(); + std::tie(_M, _N, _K) = shape; + std::tie(_dt_a, _dt_b, _dt_c, _dt_weight, _postops_type) = types; + }; + + template + void do_test() { + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + llmdnn::fc fc; + ASSERT_TRUE(fc.init(param)); + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D q(1, _N); + tensor2D bias(1, _N); + + fill_rnd(A); + fill_rnd(B); + dq = 2; + q = 2; + fill_rnd(bias); + bias = 1; + + tensor2D BT = B.Tr(true); + TB* ptr_B; + size_t ldb; + tensor weight; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + weight.resize({ static_cast(BT.dims[0]), static_cast(BT.dims[1]) }, static_cast(ptr_B)); + } else { + ptr_B = B.data; + ldb = B.stride; + weight.resize({ static_cast(B.dims[0]), static_cast(B.dims[1]) }, static_cast(ptr_B)); + } + fc.pack_weight(weight); + tensor input, output, bias_t, q_t, dq_t; + input.resize({ static_cast(A.dims[0]), static_cast(A.dims[1]) }, static_cast(A.data)); + output.resize({ static_cast(C.dims[0]), static_cast(C.dims[1]) }, static_cast(C.data)); + dq_t.resize({ static_cast(dq.dims[0]), static_cast(dq.dims[1]) }, dq.data); + q_t.resize({ static_cast(q.dims[0]), static_cast(q.dims[1]) }, q.data); + bias_t.resize({ static_cast(bias.dims[0]), static_cast(bias.dims[1]) }, bias.data); + ASSERT_TRUE(fc.exec(input, output, dq_t, q_t, bias_t) == llmdnn::status_ok); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + if ((_postops_type & DEQUANT) && _dt_a == llmdnn::llmdnn_s8) { + ptr_dq = dq.data; + } + if (_postops_type & QUANT) { + ptr_q = q.data; + } + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } + + matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + + template + void do_test_wc() { + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + // bf16 needs divisible by 2 + if (_K % 2 == 1) _K += 1; + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D zp(1, _N); + tensor2D bias(1, _N); + + fill_rnd(A); + for (int i = 0; i < _N; i++) { + // make all weight 1: w - zp == 1 + zp.data[i] = static_cast(i % 255) - 1; + dq.data[i] = (i % 3) * 0.5; + for (int j = 0; j < _K; j++) { + B(j, i) = i % 255; + } + } + fill_rnd(bias); + param.scale = dq.data; + param.zp = zp.data; + param.scale_zp_size = _N; + llmdnn::fc fc; + ASSERT_TRUE(fc.init(param)); + + tensor2D BT = B.Tr(true); + uint8_t* ptr_B; + size_t ldb; + tensor weight; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + weight.resize({ static_cast(BT.dims[0]), static_cast(BT.dims[1]) }, static_cast(ptr_B)); + } else { + ptr_B = B.data; + ldb = B.stride; + weight.resize({ static_cast(B.dims[0]), static_cast(B.dims[1]) }, static_cast(ptr_B)); + } + fc.pack_weight(weight); + tensor input, output, bias_t, q_t, dq_t; + input.resize({ static_cast(A.dims[0]), static_cast(A.dims[1]) }, static_cast(A.data)); + output.resize({ static_cast(C.dims[0]), static_cast(C.dims[1]) }, static_cast(C.data)); + dq_t.resize({ static_cast(dq.dims[0]), static_cast(dq.dims[1]) }, dq.data); + bias_t.resize({ static_cast(bias.dims[0]), static_cast(bias.dims[1]) }, bias.data); + ASSERT_TRUE(fc.exec(input, output, dq_t, q_t, bias_t) == llmdnn::status_ok); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + ptr_dq = dq.data; + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } + + B = 1; + matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + + int _M, _N, _K; + bool _is_transpose; + postops_types _postops_type; + data_type_t _dt_a, _dt_b, _dt_c, _dt_weight; +}; + +TEST_P(FCTest, Func) { + if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_s8) { + do_test(); + } else if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_f32) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_bf16 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_bf16 && _dt_c == llmdnn_f32) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_f32) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_u8 && _dt_c == llmdnn_bf16) { + do_test_wc(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_u8 && _dt_c == llmdnn_f32) { + do_test_wc(); + } else { + ASSERT_TRUE(false); + } +} + +// supported: +// (s8,s8,s8),dq,[bias],[gelu],q +// (s8,s8,bf16),dq,[bias],[gelu] +// (s8,s8,f32),dq,[bias],[gelu] +// (bf16,bf16,bf16),[bias],[gelu] +// (bf16,bf16,f32),[bias],[gelu] +// (bf16,u8,f32),dq,[bias],[gelu] +// (bf16,u8,bf16),dq,[bias],[gelu] +const std::vector types = { + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU_TANH }, + // weight compression + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_BIAS_GELU }, +}; + +// M, N, K +const std::vector shapes = { + // normal + {256, 128, 448}, + // M < 16 + {15, 129, 447}, + {15, 129, 448}, + // M in (16, 32] + {31, 129, 447}, + {31, 129, 448}, + // all tail + {256 + 9, 129, 449}, + {256 + 9, 129, 448}, + // gemv, K <= 64(32)*6 + {256, 1, 80}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_FC, FCTest, + ::testing::Combine(ValuesIn(types), + Values(true, false), + ValuesIn(shapes)), + FCTest::getTestCaseName); diff --git a/tests/src/test_fc_kernel_amx.cpp b/tests/src/test_fc_kernel_amx.cpp new file mode 100644 index 0000000..ad23884 --- /dev/null +++ b/tests/src/test_fc_kernel_amx.cpp @@ -0,0 +1,329 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_fc.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using FCKernelTestShape = std::tuple; +using FCKernelTestDTPost = std::tuple; +using FCKernelTestParamSet = std::tuple< + FCKernelTestDTPost, // a, b, c data type, postops + bool, // b needs transpose + FCKernelTestShape // M, N, K + >; + +class FCKernelTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + FCKernelTestDTPost types; + bool is_transpose; + postops_types postops_type; + data_type_t dt_a, dt_b, dt_c, dt_weight; + FCKernelTestShape shape; + int M, N, K; + std::tie(types, is_transpose, shape) = obj.param; + std::tie(M, N, K) = shape; + std::tie(dt_a, dt_b, dt_c, dt_weight, postops_type) = types; + + std::ostringstream result; + result << "A_" << dtype_to_str(dt_a) << "_B_" << dtype_to_str(dt_b) + << "_C_" << dtype_to_str(dt_c) << "_WEIGHT_" << dtype_to_str(dt_weight) + << (is_transpose ? "_transpose" : "") + << "_postops_" << postops_type << "_M_" << M << "_N_" << N << "_K_" << K; + return result.str(); + } + +protected: + virtual void SetUp() override { + initXTILE(); + + FCKernelTestShape shape; + FCKernelTestDTPost types; + std::tie(types, _is_transpose, shape) = GetParam(); + std::tie(_M, _N, _K) = shape; + std::tie(_dt_a, _dt_b, _dt_c, _dt_weight, _postops_type) = types; + }; + + template + void do_test() { + fc_kernel* fc; + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + ASSERT_TRUE(fc_kernel_create(&fc, ¶m) == llmdnn::status_ok); + auto gemm = std::shared_ptr(fc, [](fc_kernel* p) { fc_kernel_destroy(p); }); + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D q(1, _N); + tensor2D bias(1, _N); + + fill_rnd(A); + fill_rnd(B); + dq = 2; + q = 2; + fill_rnd(bias); + bias = 1; + + tensor2D BT = B.Tr(); + TB* ptr_B; + size_t ldb; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + } else { + ptr_B = B.data; + ldb = B.stride; + } + fc_kernel_pack_weight(gemm.get(), ptr_B, _dt_weight, _N, _K, ldb, 0, _N); + fc_kernel_execute(gemm.get(), A.data, nullptr, C.data, A.stride, + C.stride, _M, _N, _K, 0, _N, dq.data, q.data, bias.data); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + if ((_postops_type & DEQUANT) && _dt_a == llmdnn::llmdnn_s8) { + ptr_dq = dq.data; + } + if (_postops_type & QUANT) { + ptr_q = q.data; + } + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } + + matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + + template + void do_weight_compress_test() { + fc_kernel* fc; + fc_create_param param = { + _dt_a, _dt_b, _dt_c, + _is_transpose, _postops_type + }; + // bf16 needs divisible by 2 + if (_K % 2 == 1) _K += 1; + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + tensor2D dq(1, _N); + tensor2D zp(1, _N * 2); + tensor2D bias(1, _N); + + fill_rnd(A); + dq = 999999; + for (int i = 0; i < _N; i++) { + zp.data[i * 2 + 0] = static_cast(i % 255); + zp.data[i * 2 + 1] = static_cast(i % 255); + for (int j = 0; j < _K; j++) { + B(j, i) = i % 255; + } + } + bias = 0; + param.scale = dq.data; + param.zp = zp.data; + param.scale_zp_size = _N; + ASSERT_TRUE(fc_kernel_create(&fc, ¶m) == llmdnn::status_ok); + auto gemm = std::shared_ptr(fc, [](fc_kernel* p) { fc_kernel_destroy(p); }); + + tensor2D BT = B.Tr(true); + uint8_t* ptr_B; + size_t ldb; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + } else { + ptr_B = B.data; + ldb = B.stride; + } + fc_kernel_pack_weight(gemm.get(), ptr_B, _dt_weight, _N, _K, ldb, 0, _N); + fc_kernel_execute(gemm.get(), A.data, nullptr, C.data, A.stride, + C.stride, _M, _N, _K, 0, _N, dq.data, nullptr, bias.data); + C_Ref = 0; + float* ptr_dq = nullptr; + float* ptr_q = nullptr; + float* ptr_bias = nullptr; + func_act act = func_act(); + ptr_dq = dq.data; + if (_postops_type & BIAS) { + ptr_bias = bias.data; + } + if (_postops_type & GELU) { + act = [] (float x) { + return x * 0.5 * (1 + std::erf(x / std::sqrt(2))); + }; + } + if (_postops_type & GELU_TANH) { + act = [] (float x) { + return 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + }; + } + + //matmul(A, B, C_Ref, ptr_dq, ptr_bias, act, ptr_q); + float thresh = 0.0001f; + if (std::is_same::value || std::is_same::value) + thresh = 1.1f; + if (std::is_same::value) + thresh = 0.01f; + ASSERT_TRUE(compare(C, C_Ref, thresh)); + } + + int _M, _N, _K; + bool _is_transpose; + postops_types _postops_type; + data_type_t _dt_a, _dt_b, _dt_c, _dt_weight; +}; + +TEST_P(FCKernelTest, Func) { + if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_s8) { + do_test(); + } else if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_s8 && _dt_weight == llmdnn_s8 && _dt_c == llmdnn_f32) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_bf16 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_bf16 && _dt_c == llmdnn_f32) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_bf16) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_f32 && _dt_c == llmdnn_f32) { + do_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_u8 && _dt_c == llmdnn_bf16) { + do_weight_compress_test(); + } else if (_dt_a == llmdnn_bf16 && _dt_weight == llmdnn_u8 && _dt_c == llmdnn_f32) { + do_weight_compress_test(); + } else { + ASSERT_TRUE(false); + } +} + +// supported: +// (s8,s8,s8),dq,[bias],[gelu],q +// (s8,s8,bf16),dq,[bias],[gelu] +// (s8,s8,f32),dq,[bias],[gelu] +// (bf16,bf16,bf16),[bias],[gelu] +// (bf16,bf16,f32),[bias],[gelu] +// (bf16,u8,f32),dq,[bias],[gelu] +// (bf16,u8,bf16),dq,[bias],[gelu] +const std::vector types = { + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_s8, llmdnn_s8, DEQUANT_BIAS_GELU_TANH_QUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_bf16, llmdnn_s8, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS_GELU }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_GELU_TANH }, + { llmdnn_s8, llmdnn_s8, llmdnn_f32, llmdnn_s8, DEQUANT_BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_bf16, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_bf16, llmdnn_f32, BIAS_GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, NONE }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, GELU_TANH }, + { llmdnn_bf16, llmdnn_bf16, llmdnn_f32, llmdnn_f32, BIAS_GELU_TANH }, + // weight compression + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_f32, llmdnn_u8, DEQUANT_BIAS_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_BIAS }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_GELU }, + { llmdnn_bf16, llmdnn_u8, llmdnn_bf16, llmdnn_u8, DEQUANT_BIAS_GELU }, +}; + +// M, N, K +const std::vector shapes = { + // normal + {256, 128, 448}, + // M < 16 + {15, 129, 447}, + {15, 129, 448}, + // M in (16, 32] + {31, 129, 447}, + {31, 129, 448}, + // all tail + {256 + 9, 129, 449}, + {256 + 9, 129, 448}, + // gemv, K <= 64(32)*6 + {256, 1, 80}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_FCKernel, FCKernelTest, + ::testing::Combine(ValuesIn(types), + Values(true, false), + ValuesIn(shapes)), + FCKernelTest::getTestCaseName); diff --git a/tests/src/test_gelu_kernel_avx512.cpp b/tests/src/test_gelu_kernel_avx512.cpp new file mode 100644 index 0000000..70a9a27 --- /dev/null +++ b/tests/src/test_gelu_kernel_avx512.cpp @@ -0,0 +1,97 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "gelu_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using GeluTestParamSet = std::tuple< + std::string // data type + >; + +class GeluTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + std::string types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << types; + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + void gelu_ref(float* s, float* d, int n) { + if (_types == "tanh") { + for (int i = 0; i < n; i++) { + auto x = s[i]; + d[i] = 0.5f * x * (1.0f + std::tanh(std::sqrt(2.0f / 3.1415926f) * x * (1 + 0.044715f * x * x))); + } + } else { + for (int i = 0; i < n; i++) { + auto x = s[i]; + d[i] = 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); + } + } + } + + void test(float thresh) { + tensor2D src(1, 16, true); + tensor2D dst(1, 16, true); + tensor2D ref(1, 16, true); + for (int i = 0; i < 16; i++) { + src[i] = std::sqrt(i) - 2; + } + __m512 s = _mm512_loadu_ps(src.data); + __m512 d; + if (_types == "tanh") + d = gelu_tanh_avx512(s); + else + d = gelu_erf_minmax_approx_avx512(s); + _mm512_storeu_ps(dst.data, d); + gelu_ref(src.data, ref.data, 16); + for (int i = 0; i < src.dims[1]; i++) { + float r = ref[i]; + float c = dst[i]; + if (std::abs(r - c) > thresh) { + FAIL() << " cur is not equal, pos: " << i << " opt: " << c << " ref: " << r; + } + } + } + + std::string _types; +}; + +TEST_P(GeluTest, Gelu) { + test(0.01f); +} + +const std::vector types = { + "tanh", "erf" +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Gelu, GeluTest, + ::testing::Combine(ValuesIn(types)), + GeluTest::getTestCaseName); diff --git a/tests/src/test_mm_kernel_amx.cpp b/tests/src/test_mm_kernel_amx.cpp new file mode 100644 index 0000000..da1ec6f --- /dev/null +++ b/tests/src/test_mm_kernel_amx.cpp @@ -0,0 +1,140 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "llm_types.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using MMKernelTestShape = std::tuple; +using MMKernelTestParamSet = std::tuple< + std::pair, // a, b data type + bool, // b needs transpose + MMKernelTestShape // M, N, K + >; + +class GemmKernelTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + std::pair types; + bool is_transpose; + MMKernelTestShape shape; + int M, N, K; + std::tie(types, is_transpose, shape) = obj.param; + std::tie(M, N, K) = shape; + + std::ostringstream result; + result << "A_" << dtype_to_str(types.first) << "_B_" << dtype_to_str(types.second) << "_" + << (is_transpose ? "transpose_" : "") + << "M_" << M << "_N_" << N << "_K_" << K; + return result.str(); + } + +protected: + virtual void SetUp() override { + initXTILE(); + + MMKernelTestShape shape; + std::tie(_types, _is_transpose, shape) = GetParam(); + std::tie(_M, _N, _K) = shape; + }; + + template + void test() { + if (_N == 1 && (_is_transpose || _types.first == llmdnn_u8)) { + GTEST_SKIP() << "gemv does not support transpose or u8s8."; + } + mm_kernel* mm; + mm_create_param param = { + _types.first, _types.second, + _N == 1, _is_transpose + }; + ASSERT_TRUE(mm_kernel_create(&mm, ¶m) == llmdnn::status_ok); + auto gemm = std::shared_ptr(mm, [](mm_kernel* p) { mm_kernel_destroy(p); }); + + tensor2D A(_M, _K, true); + tensor2D B(_K, _N, true); + tensor2D C(_M, _N, true); + tensor2D C_Ref(_M, _N, true); + + fill_rnd(A); + fill_rnd(B); + tensor2D BT = B.Tr(); + TB* ptr_B; + size_t ldb; + if (_is_transpose) { + ptr_B = BT.data; + ldb = BT.stride; + } else { + ptr_B = B.data; + ldb = B.stride; + } + mm_kernel_execute(gemm.get(), A.data, ptr_B, C.data, A.stride, ldb, + C.stride, _M, _N, _K); + C_Ref = 0; + matmul(A, B, C_Ref); + ASSERT_TRUE(C_Ref == C); + } + + int _M, _N, _K; + std::pair _types; + bool _is_transpose; +}; + +TEST_P(GemmKernelTest, Func) { + if (_types.first == llmdnn_u8 && _types.second == llmdnn_s8) { + test(); + } else if (_types.first == llmdnn_s8 && _types.second == llmdnn_s8) { + test(); + } else { + test(); + } +} + +const std::vector> types = { + { llmdnn_u8, llmdnn_s8 }, + { llmdnn_s8, llmdnn_s8 }, + { llmdnn_bf16, llmdnn_bf16 }, +}; + +// M, N, K +const std::vector shapes = { + // normal + {256, 48, 448}, + // k < 32 + {256, 48, 15}, + // k tail + {256, 48, 449}, + // M tail == unroll 8 + {256 + 8, 48, 449}, + // M tail == unroll 8 + 2 + {256 + 10, 48, 449}, + // N tail + {256, 40, 448}, + // all tail + {256 + 9, 47, 449}, + // gemv, K <= 64(32)*6 + {256, 1, 160}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_GemmKernel, GemmKernelTest, + ::testing::Combine(ValuesIn(types), + Values(true, false), + ValuesIn(shapes)), + GemmKernelTest::getTestCaseName); diff --git a/tests/src/test_rotary_kernel_avx2.cpp b/tests/src/test_rotary_kernel_avx2.cpp new file mode 100644 index 0000000..4718162 --- /dev/null +++ b/tests/src/test_rotary_kernel_avx2.cpp @@ -0,0 +1,109 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "rotary_kernel_avx2.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using RotaryTestAVX2ParamSet = std::tuple< + data_type_t // data type + >; + +class RotaryTestAVX2 : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void rotary_emb(size_t rotaryNdims, float* cos, float* sin, T* q_src, T* k_src, T* q_dst, T* k_dst) { + auto halfRotaryNdims = rotaryNdims / 2; + for (size_t i = 0; i < halfRotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] - q_src[i + halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] - k_src[i + halfRotaryNdims] * sin[i]; + } + for (size_t i = halfRotaryNdims; i < rotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] + q_src[i - halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] + k_src[i - halfRotaryNdims] * sin[i]; + } + } + + template + void test(float thresh) { + for (int n = 6; n < 129; n += 2) { + tensor2D q_src(1, n, true); + tensor2D k_src(1, n, true); + tensor2D q_dst(1, n, true); + tensor2D k_dst(1, n, true); + tensor2D q_dst_ref(1, n, true); + tensor2D k_dst_ref(1, n, true); + tensor2D cos(1, n, true); + tensor2D sin(1, n, true); + for (int i = 0; i < n; i++) { + q_src[i] = i % 19 - 10; + k_src[i] = i % 19 - 9; + cos[i] = i % 19 - 8; + sin[i] = i % 19 - 7; + } + rotary_emb(n, cos.data, sin.data, q_src.data, k_src.data, q_dst_ref.data, k_dst_ref.data); + rotary_avx2(n, cos.data, sin.data, q_src.data, k_src.data, q_dst.data, k_dst.data); + for (int i = 0; i < n; i++) { + float q = q_dst[i]; + float q_ref = q_dst_ref[i]; + float k = k_dst[i]; + float k_ref = k_dst_ref[i]; + if (std::abs(q - q_ref) > thresh) { + FAIL() << " q is not equal, N: " << n << " pos: " << i << " opt: " << q << " ref: " << q_ref; + } + if (std::abs(k - k_ref) > thresh) { + FAIL() << " k is not equal, N: " << n << " pos: " << i << " opt: " << k << " ref: " << k_ref; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(RotaryTestAVX2, rotary) { + if (_types == llmdnn_s8) { + ASSERT_TRUE(false); + } else { + test(0.01f); + } +} + +const std::vector types = { + llmdnn_f32 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Rotary, RotaryTestAVX2, + ::testing::Combine(ValuesIn(types)), + RotaryTestAVX2::getTestCaseName); diff --git a/tests/src/test_rotary_kernel_avx512.cpp b/tests/src/test_rotary_kernel_avx512.cpp new file mode 100644 index 0000000..6cd196b --- /dev/null +++ b/tests/src/test_rotary_kernel_avx512.cpp @@ -0,0 +1,109 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "rotary_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using RotaryTestParamSet = std::tuple< + data_type_t // data type + >; + +class RotaryTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void rotary_emb(size_t rotaryNdims, float* cos, float* sin, T* q_src, T* k_src, T* q_dst, T* k_dst) { + auto halfRotaryNdims = rotaryNdims / 2; + for (size_t i = 0; i < halfRotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] - q_src[i + halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] - k_src[i + halfRotaryNdims] * sin[i]; + } + for (size_t i = halfRotaryNdims; i < rotaryNdims; i++) { + q_dst[i] = q_src[i] * cos[i] + q_src[i - halfRotaryNdims] * sin[i]; + k_dst[i] = k_src[i] * cos[i] + k_src[i - halfRotaryNdims] * sin[i]; + } + } + + template + void test(float thresh) { + for (int n = 6; n < 129; n += 2) { + tensor2D q_src(1, n, true); + tensor2D k_src(1, n, true); + tensor2D q_dst(1, n, true); + tensor2D k_dst(1, n, true); + tensor2D q_dst_ref(1, n, true); + tensor2D k_dst_ref(1, n, true); + tensor2D cos(1, n, true); + tensor2D sin(1, n, true); + for (int i = 0; i < n; i++) { + q_src[i] = i % 19 - 10; + k_src[i] = i % 19 - 9; + cos[i] = i % 19 - 8; + sin[i] = i % 19 - 7; + } + rotary_emb(n, cos.data, sin.data, q_src.data, k_src.data, q_dst_ref.data, k_dst_ref.data); + rotary_avx512(n, cos.data, sin.data, q_src.data, k_src.data, q_dst.data, k_dst.data); + for (int i = 0; i < n; i++) { + float q = q_dst[i]; + float q_ref = q_dst_ref[i]; + float k = k_dst[i]; + float k_ref = k_dst_ref[i]; + if (std::abs(q - q_ref) > thresh) { + FAIL() << " q is not equal, N: " << n << " pos: " << i << " opt: " << q << " ref: " << q_ref; + } + if (std::abs(k - k_ref) > thresh) { + FAIL() << " k is not equal, N: " << n << " pos: " << i << " opt: " << k << " ref: " << k_ref; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(RotaryTest, rotary) { + if (_types == llmdnn_s8) { + ASSERT_TRUE(false); + } else { + test(0.01f); + } +} + +const std::vector types = { + llmdnn_bf16 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Rotary, RotaryTest, + ::testing::Combine(ValuesIn(types)), + RotaryTest::getTestCaseName); diff --git a/tests/src/test_softmax_kernel_avx512.cpp b/tests/src/test_softmax_kernel_avx512.cpp new file mode 100644 index 0000000..34942ba --- /dev/null +++ b/tests/src/test_softmax_kernel_avx512.cpp @@ -0,0 +1,133 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "softmax_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using SoftmaxTestParamSet = std::tuple< + data_type_t // data type + >; + +class SoftmaxTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(tensor2D& x, tensor2D& out, tensor2D& quant) { + tensor2D y = x.clone(); + float x_max = std::numeric_limits::lowest(); + for(int i = 0; i < x.dims[1]; i++) { + x_max = std::max(x_max, x[i]); + } + float sum = 0; + for(int i = 0; i < x.dims[1]; i++) { + y[i] = expf(x[i] - x_max); + sum += y[i]; + } + for(int i = 0; i < x.dims[1]; i++) { + y[i] = y[i] / sum; + } + out.resize(x.dims[0], x.dims[1], true); + if (std::is_same::value) { + memcpy(static_cast(out.data), y.data, x.dims[0] * x.dims[1] * sizeof(float)); + } + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + out[i] = y[i]; + } + } +#define CLIP(x, low, high) \ + (x < low ? low : (x > high ? high : x)) + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + auto tmp = y[i] * quant[i]; + out[i] = static_cast(CLIP(tmp, -128, 127)); + } + } + if (std::is_same::value) { + for(int i = 0; i < x.dims[1]; i++) { + auto tmp = y[i] * quant[i]; + out[i] = static_cast(CLIP(tmp, 0, 255)); + } + } +#undef CLIP + } + + template + void test(float thresh) { + for (int n = 1; n < 129; n++) { + tensor2D A(1, n, true), A_scalar(1, n, true); + tensor2D quant(1, n, true); + tensor2D out(1, n, true), out_ref, out_scalar(1, n, true); + for (int i = 0; i < n; i++) { + A[i] = static_cast(i) - n / 2; + A_scalar[i] = A[i]; + } + quant = 128.f; + gen_ref(A, out_ref, quant); + llmdnn::softmax_avx512(out.data, A.data, n, quant.data); + for (int i = 0; i < n; i++) { + float a = out[i]; + float b = out_ref[i]; + if (std::abs(a - b) > thresh) { + FAIL() << " N: " << n << " pos: " << i << " opt: " << a << " ref: " << b; + } + } + // input is scalar + llmdnn::softmax_avx512(out_scalar.data, A_scalar.data, n, quant.data[0]); + ASSERT_TRUE(out == out_scalar); + } + } + + data_type_t _types; +}; + +TEST_P(SoftmaxTest, Func) { + if (_types == llmdnn_s8) { + test(1.1f); + } else if (_types == llmdnn_u8) { + test(1.1f); + } else if (_types == llmdnn_f32) { + test(0.00001f); + } else { + test(0.01f); + } +} + +const std::vector types = { + llmdnn_s8, llmdnn_bf16, llmdnn_u8, llmdnn_f32 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Softmax, SoftmaxTest, + ::testing::Combine(ValuesIn(types)), + SoftmaxTest::getTestCaseName); diff --git a/tests/src/test_transpose_kernel_avx512.cpp b/tests/src/test_transpose_kernel_avx512.cpp new file mode 100644 index 0000000..702fdc3 --- /dev/null +++ b/tests/src/test_transpose_kernel_avx512.cpp @@ -0,0 +1,131 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "transpose_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using TransposeTestParamSet = std::tuple< + data_type_t // data type + >; + +class TransposeTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + data_type_t types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << dtype_to_str(types); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(T* dst, float* src, size_t height, size_t width, size_t src_stride, size_t dst_stride, float* quant) { + for (size_t j = 0; j < height; j++) { + if (std::is_same::value) { + memcpy(static_cast(dst), src, width * sizeof(float)); + } + if (std::is_same::value) { + for(size_t i = 0; i < width; i++) { + dst[i] = src[i]; + } + } + #define CLIP(x, low, high) \ + (x < low ? low : (x > high ? high : x)) + if (std::is_same::value) { + for(size_t i = 0; i < width; i++) { + auto tmp = src[i] * quant[i]; + dst[i] = static_cast(CLIP(tmp, -128, 127)); + } + } + if (std::is_same::value) { + for(size_t i = 0; i < width; i++) { + auto tmp = src[i] * quant[i]; + dst[i] = static_cast(CLIP(tmp, 0, 255)); + } + } + #undef CLIP + src = reinterpret_cast(reinterpret_cast(src) + src_stride); + dst = reinterpret_cast(reinterpret_cast(dst) + dst_stride); + } + } + + template + void test(float thresh) { + // [num_heads, query_seq_len, head_size] => [query_seq_len, num_heads * head_size] + int num_heads = 2, query_seq_len = 10; + for (int head_size = 1; head_size < 129; head_size++) { + tensor2D src(num_heads, head_size * query_seq_len, true); + tensor2D quant(1, head_size, true); + tensor2D dst(query_seq_len, num_heads * head_size, true); + tensor2D dst_ref(query_seq_len, num_heads * head_size, true); + for (int i = 0; i < num_heads * head_size * query_seq_len; i++) { + src[i] = i % 253 - 127; + } + quant = 1.28f; + auto* dst_p = dst.data; + auto* dst_p_ref = dst_ref.data; + for (int i = 0; i < num_heads; i++) { + auto* src_p = &src(i, 0); + llmdnn::memcpy2d_stride_avx512(dst_p, src_p, query_seq_len, head_size, + head_size * sizeof(float), num_heads * head_size * sizeof(T), quant.data); + gen_ref(dst_p_ref, src_p, query_seq_len, head_size, + head_size * sizeof(float), num_heads * head_size * sizeof(T), quant.data); + dst_p += head_size; + dst_p_ref += head_size; + } + for (int i = 0; i < num_heads * head_size * query_seq_len; i++) { + float a = dst[i]; + float b = dst_ref[i]; + if (std::abs(a - b) > thresh) { + FAIL() << " N: " << head_size << " pos: " << i << " opt: " << a << " ref: " << b; + } + } + } + } + + data_type_t _types; +}; + +TEST_P(TransposeTest, memcpy2d) { + if (_types == llmdnn_s8) { + test(1.1f); + } else if (_types == llmdnn_u8) { + test(1.1f); + } else if (_types == llmdnn_f32) { + test(0.00001f); + } else { + test(0.01f); + } +} + +const std::vector types = { + llmdnn_s8, llmdnn_bf16, llmdnn_u8, llmdnn_f32 +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Transpose, TransposeTest, + ::testing::Combine(ValuesIn(types)), + TransposeTest::getTestCaseName); diff --git a/tests/src/test_utility_kernel_avx512.cpp b/tests/src/test_utility_kernel_avx512.cpp new file mode 100644 index 0000000..c0806e1 --- /dev/null +++ b/tests/src/test_utility_kernel_avx512.cpp @@ -0,0 +1,42 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "utility_kernel_avx512.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::Values; +using ::testing::ValuesIn; + +TEST(smoke_Utility, muladd) { + float normal_factor = 1.2f; + for (size_t len = 1; len < 129; len++) { + std::vector x(len), x_out(len), bias(len), ref(len); + std::vector mask(len, 1); + mask[0] = 0; + for (size_t i = 0; i < x.size(); i++) { + x[i] = -10.0f + i; + bias[i] = -100.0f + i; + ref[i] = x[i] * normal_factor + bias[i] + bias[i]; + if (mask[i] == 0) + ref[i] = -FLT_MAX; + } + mul_add2_select_f32_avx512(x_out.data(), x.data(), normal_factor, bias.data(), bias.data(), mask.data(), true, len); + for (size_t i = 0; i < x.size(); i++) { + ASSERT_TRUE(std::abs(x_out[i] - ref[i]) < 0.0001f) << " length: " << len << " pos: " << i << " cur: " << x[i] << " ref: " << ref[i]; + } + } +} \ No newline at end of file diff --git a/tests/src/test_utility_kernel_repack1x2_avx512.cpp b/tests/src/test_utility_kernel_repack1x2_avx512.cpp new file mode 100644 index 0000000..6d48cfd --- /dev/null +++ b/tests/src/test_utility_kernel_repack1x2_avx512.cpp @@ -0,0 +1,213 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "llm_mm.hpp" +#include "common/tensor2d.hpp" +#include "common/tensor2d_helper.hpp" +#include "mm_kernel_common_amx.hpp" +#include "test_common.hpp" + +using namespace std; +using namespace llmdnn; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using RepackTestParamSet = std::tuple< + std::pair // data type + >; + +class RepackTest : public TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + std::pair types; + std::tie(types) = obj.param; + + std::ostringstream result; + result << "IN_" << dtype_to_str(types.first) << "_OUT_" << dtype_to_str(types.second); + return result.str(); + } + +protected: + virtual void SetUp() override { + std::tie(_types) = GetParam(); + }; + + template + static void gen_ref(tensor2D& in_t, tensor2D& out) { + int K = in_t.dims[1]; + int N = in_t.dims[0]; + int kStep = 64 / sizeof(T); + int K_padded = (K + kStep - 1) / kStep * kStep; + int Ktails = K % kStep; + + // N_padded : round up to multiple of (2*16) + int N_unit = 2 * 16; + int N_padded = (N + N_unit - 1) / N_unit * N_unit; + + // Bo(ni, 0) is a vector flattened from a slice of shape [K_padded x N_unit] + out.resize(N_padded / N_unit, K_padded * N_unit, true, true); + + tensor2D in_padded; + in_padded.resize(N_padded, K_padded, true, true); + for (int i = 0; i < N; i++) { + // full range + memcpy(&in_padded(i, 0), &in_t(i, 0), (K - Ktails) * sizeof(T)); + // k tail needs to right aligned + memcpy(&in_padded(i, K_padded - Ktails), &in_t(i, K - Ktails), Ktails * sizeof(T)); + } + for (int n = 0; n < N_padded; n += N_unit) { + for (int k = 0; k < K_padded; k += kStep) { + // bf16 as example: + // [N, K], 2*[16(N), 32(K)] => + // [K, N], 2*[16(K), 16(N)*2(K)] + for (int m = 0; m < 2; m++) { + auto* src = reinterpret_cast(&in_padded(n + m * 16, k)); + auto* dst = reinterpret_cast(&out(n / N_unit, k * N_unit)); + dst += m * (1024 / sizeof(int)); + for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + dst[i * 16 + j] = src[j * in_padded.stride / sizeof(int) + i]; + } + } + } + } + } + } + + template + void test() { + auto testone = [] (int k, int n, std::string prefix) { + tensor2D A(k, n, true); + + fill_rnd(A); + tensor2D AT = A.Tr(true); + tensor2D A_out, AT_out, A_ref; + amx_kernel::repackB_1x2(A, false, A_out, false); + amx_kernel::repackB_1x2(AT, true, AT_out, false); + if constexpr (std::is_same_v) { + tensor2D AT_bf16(n, k, true); + amx_kernel::functional::f32_to_bf16_tensor(AT_bf16, AT); + gen_ref(AT_bf16, A_ref); + } else { + gen_ref(AT, A_ref); + } + ASSERT_TRUE(A_out == A_ref) << " " << prefix << " without transform K: " << k << " N: " << n; + ASSERT_TRUE(AT_out == A_ref) << " " << prefix << " with transform K: " << k << " N: " << n; + }; + // n tail: transpose case needs from 1 to 31, without transpose needs one + int k = 32; + int n; + for (n = 1; n < 32; n++) { + testone(k, n, "ntail"); + } + for (n = 32 + 1; n < 32 + 32; n++) { + testone(k, n, "ntail"); + } + // k tail: transpose case needs 1, without transpose needs from 1 to 31 + n = 32; + for (k = 1; k < 32; k++) { + testone(k, n, "ktail"); + } + for (k = 32 + 1; k < 32 + 32; k++) { + testone(k, n, "ktail"); + } + // k, n normal + testone(32, 32, "normal"); + testone(64, 128, "normal"); + // k, n tail + testone(64, 128 + 5, "ntail"); + testone(64 + 3, 128, "ktail"); + testone(64 + 3, 128 + 5, "alltail"); + testone(64, 128 + 16 + 5, "ntail"); + testone(64 + 16 + 3, 128, "ktail"); + testone(64 + 16 + 3, 128 + 16 + 5, "alltail"); + } + + void test_wc() { + auto testone = [] (int k, int n, std::string prefix) { + tensor2D A(k, n, true); + + // get ref result + tensor2D A_ref; + tensor2D A_bf16(k, n, true), A_ref_bf16; + for (int i = 0; i < k * n; i++) { + A.data[i] = i % 23; + A_bf16.data[i] = ov::bfloat16(i % 23); + } + amx_kernel::repackB_1x2(A_bf16, false, A_ref_bf16, true); + A_ref.resize(A_ref_bf16.dims[0], A_ref_bf16.dims[1], true); + for (int i = 0; i < A_ref_bf16.dims[0] * A_ref_bf16.dims[1]; i++) { + A_ref.data[i] = static_cast(float(A_ref_bf16.data[i])); + } + + tensor2D AT = A.Tr(true); + tensor2D A_out, AT_out; + amx_kernel::repackB_1x2_compressed(A, false, A_out, true); + amx_kernel::repackB_1x2_compressed(AT, true, AT_out, true); + ASSERT_TRUE(A_out == A_ref) << " " << prefix << " without transform K: " << k << " N: " << n; + ASSERT_TRUE(AT_out == A_ref) << " " << prefix << " with transform K: " << k << " N: " << n; + }; + // n tail: transpose case needs from 1 to 31, without transpose needs one + int k = 32; + int n; + for (n = 1; n < 32; n++) { + testone(k, n, "ntail"); + } + for (n = 32 + 1; n < 32 + 32; n++) { + testone(k, n, "ntail"); + } + // k tail: transpose case needs 1, without transpose needs from 1 to 31 + n = 32; + for (k = 1; k < 32; k++) { + testone(k, n, "ktail"); + } + for (k = 32 + 1; k < 32 + 32; k++) { + testone(k, n, "ktail"); + } + // k, n normal + testone(32, 32, "normal"); + testone(64, 128, "normal"); + // k, n tail + testone(64, 128 + 5, "ntail"); + testone(64 + 3, 128, "ktail"); + testone(64 + 3, 128 + 5, "alltail"); + testone(64, 128 + 16 + 5, "ntail"); + testone(64 + 16 + 3, 128, "ktail"); + testone(64 + 16 + 3, 128 + 16 + 5, "alltail"); + } + + std::pair _types; +}; + +TEST_P(RepackTest, Func) { + if (_types.first == llmdnn_s8 && _types.second == llmdnn_s8) { + test(); + } else if (_types.first == llmdnn_u8 && _types.second == llmdnn_u8) { + test_wc(); + } else if (_types.first == llmdnn::llmdnn_bf16 && _types.second == llmdnn_bf16) { + test(); + } else { + test(); + } +} + +const std::vector> types = { + {llmdnn_u8, llmdnn_u8}, // compress weight + {llmdnn_s8, llmdnn_s8}, + {llmdnn_bf16, llmdnn_bf16}, + {llmdnn_f32, llmdnn_bf16}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Repack, RepackTest, + ::testing::Combine(ValuesIn(types)), + RepackTest::getTestCaseName);