diff --git a/common_level3.h b/common_level3.h
index fc4909e693..19d7d120f7 100644
--- a/common_level3.h
+++ b/common_level3.h
@@ -59,6 +59,19 @@ void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K,
                              float beta, float * R, BLASLONG strideR);
 
+void ssymm_direct_alpha_betaLU(BLASLONG M, BLASLONG N,
+                               float alpha,
+                               float * A, BLASLONG strideA,
+                               float * B, BLASLONG strideB,
+                               float beta,
+                               float * R, BLASLONG strideR);
+void ssymm_direct_alpha_betaLL(BLASLONG M, BLASLONG N,
+                               float alpha,
+                               float * A, BLASLONG strideA,
+                               float * B, BLASLONG strideB,
+                               float beta,
+                               float * R, BLASLONG strideR);
+
 int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
 
 int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
diff --git a/common_param.h b/common_param.h
index 0145f667a1..54bb896fc5 100644
--- a/common_param.h
+++ b/common_param.h
@@ -257,6 +257,8 @@
   int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
 #ifdef ARCH_ARM64
   void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
   void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
+  void (*ssymm_direct_alpha_betaLU) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
+  void (*ssymm_direct_alpha_betaLL) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
 #endif
diff --git a/common_s.h b/common_s.h
index 88b4732f51..a73b0082f3 100644
--- a/common_s.h
+++ b/common_s.h
@@ -50,6 +50,8 @@
 #define SGEMM_DIRECT_PERFORMANT    sgemm_direct_performant
 #define SGEMM_DIRECT               sgemm_direct
 #define SGEMM_DIRECT_ALPHA_BETA    sgemm_direct_alpha_beta
+#define SSYMM_DIRECT_ALPHA_BETA_LU ssymm_direct_alpha_betaLU
+#define SSYMM_DIRECT_ALPHA_BETA_LL ssymm_direct_alpha_betaLL
 
 #define SGEMM_ONCOPY               sgemm_oncopy
 #define SGEMM_OTCOPY               sgemm_otcopy
 
@@ -220,6 +222,8 @@
 #define SGEMM_DIRECT_PERFORMANT    sgemm_direct_performant
 #define SGEMM_DIRECT               gotoblas -> sgemm_direct
 #define SGEMM_DIRECT_ALPHA_BETA    gotoblas -> sgemm_direct_alpha_beta
+#define SSYMM_DIRECT_ALPHA_BETA_LU gotoblas -> ssymm_direct_alpha_betaLU
+#define SSYMM_DIRECT_ALPHA_BETA_LL gotoblas -> ssymm_direct_alpha_betaLL
 #endif
 
 #define SGEMM_ONCOPY               gotoblas -> sgemm_oncopy
diff --git a/interface/symm.c b/interface/symm.c
index 3e6e0fd488..04a8fab7de 100644
--- a/interface/symm.c
+++ b/interface/symm.c
@@ -371,6 +371,24 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo,
     return;
   }
 
+#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
+#if defined(ARCH_ARM64) && (defined(USE_SSYMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
+  if (args.m == 0 || args.n == 0) return;
+#if defined(DYNAMIC_ARCH)
+  if (support_sme1())
+#endif
+  if (order == CblasRowMajor && m == lda && n == ldb && n == ldc)
+  {
+    if (Side == CblasLeft && Uplo == CblasUpper) {
+      SSYMM_DIRECT_ALPHA_BETA_LU(m, n, alpha, a, lda, b, ldb, beta, c, ldc); return;
+    }
+    else if (Side == CblasLeft && Uplo == CblasLower) {
+      SSYMM_DIRECT_ALPHA_BETA_LL(m, n, alpha, a, lda, b, ldb, beta, c, ldc); return;
+    }
+  }
+#endif
+#endif
+
 #endif
 
   if (args.m == 0 || args.n == 0) return;
diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt
index a2e349d32d..d73dc27e29 100644
--- a/kernel/CMakeLists.txt
+++ b/kernel/CMakeLists.txt
@@ -241,6 +241,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
   if (X86_64 OR ARM64)
     set(USE_DIRECT_SGEMM true)
   endif()
+  set(USE_DIRECT_SSYMM false)
+  if (ARM64)
+    set(USE_DIRECT_SSYMM true)
+  endif()
   if (UC_TARGET_CORE MATCHES ARMV9SME)
     set (HAVE_SME true)
   endif ()
@@ -267,6 +271,14 @@
     endif ()
   endif()
 
+  if (USE_DIRECT_SSYMM)
+    if (ARM64)
+      set (SSYMMDIRECTKERNEL_ALPHA_BETA ssymm_direct_alpha_beta_arm64_sme1.c)
+      GenerateNamedObjects("${KERNELDIR}/${SSYMMDIRECTKERNEL_ALPHA_BETA}" "" "symm_direct_alpha_betaLU" false "" "" false SINGLE)
+      GenerateNamedObjects("${KERNELDIR}/${SSYMMDIRECTKERNEL_ALPHA_BETA}" "" "symm_direct_alpha_betaLL" false "" "" false SINGLE)
+    endif ()
+  endif()
+
   foreach (float_type SINGLE DOUBLE)
     string(SUBSTRING ${float_type} 0 1 float_char)
     GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3
index 79c88d76c7..5d0f7213cd 100644
--- a/kernel/Makefile.L3
+++ b/kernel/Makefile.L3
@@ -52,6 +52,7 @@ endif
 ifeq ($(ARCH), arm64)
 USE_TRMM = 1
 USE_DIRECT_SGEMM = 1
+USE_DIRECT_SSYMM = 1
 endif
 
 ifeq ($(ARCH), riscv64)
@@ -137,6 +138,17 @@ endif
 endif
 endif
 
+ifdef USE_DIRECT_SSYMM
+ifndef SSYMMDIRECTKERNEL_ALPHA_BETA
+ifeq ($(ARCH), arm64)
+ifeq ($(TARGET_CORE), ARMV9SME)
+HAVE_SME = 1
+endif
+SSYMMDIRECTKERNEL_ALPHA_BETA = ssymm_direct_alpha_beta_arm64_sme1.c
+endif
+endif
+endif
+
 ifeq ($(BUILD_BFLOAT16), 1)
 ifndef BGEMMKERNEL
 BGEMM_BETA = ../generic/gemm_beta.c
@@ -220,6 +232,14 @@ endif
 endif
 endif
 
+ifdef USE_DIRECT_SSYMM
+ifeq ($(ARCH), arm64)
+SKERNELOBJS += \
+	ssymm_direct_alpha_betaLU$(TSUFFIX).$(SUFFIX) \
+	ssymm_direct_alpha_betaLL$(TSUFFIX).$(SUFFIX)
+endif
+endif
+
 ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
 DKERNELOBJS += \
 	dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -982,6 +1002,15 @@ endif
 endif
 endif
 
+ifdef USE_DIRECT_SSYMM
+ifeq ($(ARCH), arm64)
+$(KDIR)ssymm_direct_alpha_betaLU$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DLEFT -DUPPER $< -o $@
+$(KDIR)ssymm_direct_alpha_betaLL$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DLEFT -DLOWER $< -o $@
+endif
+endif
+
 ifeq ($(BUILD_BFLOAT16), 1)
 $(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL)
 	$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@
diff --git a/kernel/arm64/ssymm_direct_alpha_beta_arm64_sme1.c b/kernel/arm64/ssymm_direct_alpha_beta_arm64_sme1.c
new file mode 100644
index 0000000000..8e3c0c0563
--- /dev/null
+++ b/kernel/arm64/ssymm_direct_alpha_beta_arm64_sme1.c
@@ -0,0 +1,225 @@
+/*
+   Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+   SPDX-License-Identifier: BSD-3-Clause-Clear
+*/
+
+#include "common.h"
+#include <stdint.h>
+#include <stdlib.h>
+#include <math.h>
+// #include "sme_abi.h"
+#if defined(HAVE_SME)
+
+#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
+#include <arm_sme.h>
+#endif
+
+/* Function prototypes */
+extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
+                const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
+
+extern void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
+                const float *ba, const float *restrict bb, const float* beta,\
+                float *restrict C);
+
+/* Function Definitions */
+static uint64_t sve_cntw() {
+  uint64_t cnt;
+  asm volatile(
+    "rdsvl %[res], #1\n"
+    "lsr %[res], %[res], #2\n"
+    : [res] "=r" (cnt) ::
+  );
+  return cnt;
+}
+
+#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
+
+__arm_new("za") __arm_locally_streaming
+static void ssymm_direct_sme1_preprocessLU(uint64_t nbr, uint64_t nbc,
+                                           const float *restrict a, float *restrict a_mod)
+{
+  // const uint64_t num_rows = nbr;
+  // const uint64_t num_cols = nbc;
+  const uint64_t svl = svcntw();
+  uint64_t row_batch = svl;
+
+  float *restrict pSrc;
+  float *restrict pDst;
+  for (uint64_t row_idx = 0; row_idx < nbr; row_idx += row_batch)
+  {
+    row_batch = MIN(row_batch, nbr - row_idx);
+
+    // Fill in the lower triangle and Transpose 1SVL x N panel of A
+    uint64_t col_batch = svl;
+
+    for (uint64_t col_idx = 0; col_idx < nbc; col_idx += col_batch)
+    {
+      svzero_za();
+
+      if (col_idx == row_idx)
+      {
+        pSrc = &a[(row_idx)*nbc + col_idx];
+        pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
+        // Load horizontal slices, filling lower elements
+        const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
+        for (int64_t row = row_batch - 1; row >= 0; row--)
+        {
+          svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
+          svld1_ver_za32(0, row, pg_row, &pSrc[row * nbc]);
+        }
+        // Save vertical slices
+        col_batch = MIN(col_batch, nbc - col_idx);
+        for (uint64_t col = 0; col < col_batch; col++)
+        {
+          svst1_ver_za32(0, col, svptrue_b32(), &pDst[col * svl]);
+        }
+      }
+      else if (col_idx > row_idx)
+      {
+        pSrc = &a[(row_idx)*nbc + col_idx];
+        pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
+        // Load horizontal slices
+        const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
+        for (uint64_t row = 0; row < row_batch; row++)
+        {
+          svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
+        }
+        // Save vertical slices
+        col_batch = MIN(col_batch, nbc - col_idx);
+        for (uint64_t col = 0; col < col_batch; col++)
+        {
+          svst1_ver_za32(0, col, svptrue_b32(), &pDst[col * svl]);
+        }
+      }
+      else if (col_idx < row_idx)
+      {
+        pSrc = &a[row_idx + col_idx * nbc];
+        pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
+        // Load horizontal slices
+        const svbool_t pg_row = svwhilelt_b32_u64(row_idx, nbc);
+        for (uint64_t row = 0; row < svl; row++)
+        {
+          svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
+        }
+        // Save vertical slices
+        col_batch = MIN(col_batch, nbc - col_idx);
+        for (uint64_t col = 0; col < svl; col++)
+        {
+          svst1_hor_za32(0, col, svptrue_b32(), &pDst[col * svl]);
+        }
+      }
+    }
+  }
+}
+
+//
+__arm_new("za") __arm_locally_streaming
+static void ssymm_direct_sme1_preprocessLL(uint64_t nbr, uint64_t nbc,
+                                           const float *restrict a, float *restrict a_mod)
+{
+  // const uint64_t num_rows = nbr;
+  const uint64_t svl = svcntw();
+  uint64_t row_batch = svl;
+
+  float *restrict pSrc;
+  float *restrict pDst;
+  for (uint64_t row_idx = 0; row_idx < nbr; row_idx += row_batch)
+  {
+    row_batch = MIN(row_batch, nbr - row_idx);
+
+    // Fill in the upper triangle and Transpose 1SVL x N panel of A
+    uint64_t col_batch = svl;
+
+    for (uint64_t col_idx = 0; col_idx < nbc; col_idx += col_batch)
+    {
+      svzero_za();
+
+      if (col_idx == row_idx)
+      {
+        pSrc = &a[(row_idx)*nbc + col_idx];
+        pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
+        // Load horizontal slices, filling upper elements
+        const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
+        for (uint64_t row = 0; row < row_batch; row++)
+        {
+          svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
+          svld1_ver_za32(0, row, pg_row, &pSrc[row * nbc]);
+        }
+        // Save vertical slices
+        col_batch = MIN(col_batch, nbc - col_idx);
+        for (uint64_t col = 0; col < col_batch; col++)
+        {
+          svst1_ver_za32(0, col, svptrue_b32(), &pDst[col * svl]);
+        }
+      }
+      else if (col_idx > row_idx)
+      {
+        pSrc = &a[row_idx + col_idx * nbc];
+        pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
+        // Load horizontal slices
+        const svbool_t pg_row = svptrue_b32();
+        for (uint64_t row = 0; row < row_batch; row++)
+        {
+          svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
+        }
+        // Save vertical slices
+        col_batch = MIN(col_batch, nbc - col_idx);
+        for (uint64_t col = 0; col < col_batch; col++)
+        {
+          svst1_hor_za32(0, col, svptrue_b32(), &pDst[col * svl]);
+        }
+      }
+      else if (col_idx < row_idx)
+      {
+        pSrc = &a[(row_idx)*nbc + col_idx];
+        pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
+        // Load horizontal slices
+        const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
+        for (uint64_t row = 0; row < row_batch; row++)
+        {
+          svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
+        }
+        // Save vertical slices
+        col_batch = MIN(col_batch, nbc - col_idx);
+        for (uint64_t col = 0; col < col_batch; col++)
+        {
+          svst1_ver_za32(0, col, svptrue_b32(), &pDst[col * svl]);
+        }
+      }
+    }
+  }
+}
+
+#endif
+
+//
+void CNAME(BLASLONG M, BLASLONG N, float alpha, float *__restrict A,
+           BLASLONG strideA, float *__restrict B, BLASLONG strideB,
+           float beta, float *__restrict R, BLASLONG strideR)
+{
+  uint64_t vl_elms = sve_cntw(); // vl_elem = 16
+  uint64_t m_mod = ceil((double)M / (double)vl_elms) * vl_elms;
+
+  /* Pre-process the left matrix to make it suitable for
+     matrix sum of outer-product calculation
+  */
+  float *A_mod = (float *)malloc(m_mod * M * sizeof(float));
+
+#if defined(UPPER)
+  ssymm_direct_sme1_preprocessLU(M, M, A, A_mod);
+#elif defined(LOWER)
+  ssymm_direct_sme1_preprocessLL(M, M, A, A_mod);
+#endif
+
+  /* Calculate C = alpha*A*B + beta*C */
+  sgemm_direct_alpha_beta_sme1_2VLx2VL(M, M, N, &alpha, A_mod, B, &beta, R);
+  free(A_mod);
+}
+
+#else
+
+void CNAME (BLASLONG M, BLASLONG N, float alpha, float * __restrict A,\
+            BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
+            float beta, float * __restrict R, BLASLONG strideR){}
+
+#endif
diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c
index ccfbab8c11..df455cd5d8 100644
--- a/kernel/setparam-ref.c
+++ b/kernel/setparam-ref.c
@@ -216,6 +216,8 @@ gotoblas_t TABLE_NAME = {
 #ifdef ARCH_ARM64
   sgemm_directTS,
   sgemm_direct_alpha_betaTS,
+  ssymm_direct_alpha_betaLUTS,
+  ssymm_direct_alpha_betaLLTS,
 #endif
 
   sgemm_kernelTS, sgemm_betaTS,
diff --git a/param.h b/param.h
index d0ee246e83..9ceb6b58da 100644
--- a/param.h
+++ b/param.h
@@ -3844,6 +3844,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout
 
 #if defined(ARMV9SME) /* ARMv9 SME */
 #define USE_SGEMM_KERNEL_DIRECT 1
+#define USE_SSYMM_KERNEL_DIRECT 1
 #endif /* ARMv9 SME */
 
 #if defined(ARMV5)
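Note on the LU/LL preprocess routines (illustrative, not part of the patch): setting aside the SME ZA tile transposes and the padding of the row count to a multiple of the vector length, their net effect on the matrix values is to materialize the full symmetric operand from the stored triangle of A before it is handed to sgemm_direct_alpha_beta_sme1_2VLx2VL. A plain scalar sketch of that mirroring step, using hypothetical names, is:

/* Scalar illustration only: build a full symmetric matrix from the stored
   triangle of a row-major n x n matrix. The real preprocess additionally
   transposes SVL x SVL tiles into the packed layout expected by the SME
   outer-product kernel and pads the row count up to a multiple of SVL. */
static void mirror_symmetric(int n, const float *a, float *full, int upper) {
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++) {
            /* For UPPER storage the stored half is j >= i, otherwise j <= i;
               the missing half is taken from the transposed position. */
            int use_stored = upper ? (j >= i) : (j <= i);
            full[i * n + j] = use_stored ? a[i * n + j] : a[j * n + i];
        }
}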
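For reference, a minimal caller that could exercise the new direct path is sketched below. This assumes an SME1-capable CPU and a build where USE_SSYMM_KERNEL_DIRECT is set (ARMV9SME target) or DYNAMIC_ARCH detects SME1 at runtime; the sizes and values are illustrative only. The shape matches the conditions checked in interface/symm.c: row-major order, Side = Left, Uplo = Upper or Lower, lda == m, and ldb == ldc == n. Any other combination falls through to the regular SYMM path.

#include <stdio.h>
#include <stdlib.h>
#include <cblas.h>

int main(void) {
    const int m = 512, n = 256;            /* A is m x m, B and C are m x n */
    float *A = malloc((size_t)m * m * sizeof(float));
    float *B = malloc((size_t)m * n * sizeof(float));
    float *C = malloc((size_t)m * n * sizeof(float));
    if (!A || !B || !C) return 1;

    for (int i = 0; i < m * m; i++) A[i] = 1.0f;   /* only the selected triangle is read */
    for (int i = 0; i < m * n; i++) B[i] = 2.0f;
    for (int i = 0; i < m * n; i++) C[i] = 0.0f;

    /* Row-major, Left, Upper, lda == m, ldb == ldc == n: the shape the new
       SSYMM_DIRECT_ALPHA_BETA_LU dispatch in interface/symm.c checks for. */
    cblas_ssymm(CblasRowMajor, CblasLeft, CblasUpper, m, n,
                1.0f, A, m, B, n, 0.0f, C, n);

    printf("C[0] = %f\n", C[0]);   /* expect m * 1.0 * 2.0 = 1024.0 */
    free(A); free(B); free(C);
    return 0;
}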