common_level3.h: 13 additions, 0 deletions
@@ -59,6 +59,19 @@ void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K,
float beta,
float * R, BLASLONG strideR);

void ssymm_direct_alpha_betaLU(BLASLONG M, BLASLONG N,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB,
float beta,
float * R, BLASLONG strideR);
void ssymm_direct_alpha_betaLL(BLASLONG M, BLASLONG N,
float alpha,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB,
float beta,
float * R, BLASLONG strideR);

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);

int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
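For context, a minimal scalar sketch of the operation the new LU entry point performs: R = alpha*A*B + beta*R, where A is an M x M symmetric matrix with only its upper triangle stored (row-major) and B, R are dense M x N row-major. This illustrates the semantics only, not the SME implementation, and the helper name ssymm_LU_reference is hypothetical.

#include <stddef.h>

/* Illustrative scalar reference for the semantics of ssymm_direct_alpha_betaLU:
   R = alpha * A * B + beta * R, with A an M x M symmetric matrix whose upper
   triangle is stored row-major with leading dimension strideA, and B, R dense
   M x N row-major with leading dimensions strideB and strideR. */
static void ssymm_LU_reference(size_t M, size_t N, float alpha,
                               const float *A, size_t strideA,
                               const float *B, size_t strideB,
                               float beta, float *R, size_t strideR)
{
    for (size_t i = 0; i < M; i++) {
        for (size_t j = 0; j < N; j++) {
            float acc = 0.0f;
            for (size_t k = 0; k < M; k++) {
                /* A(i,k): read from the stored upper triangle, mirroring below the diagonal. */
                float a_ik = (i <= k) ? A[i * strideA + k] : A[k * strideA + i];
                acc += a_ik * B[k * strideB + j];
            }
            R[i * strideR + j] = alpha * acc + beta * R[i * strideR + j];
        }
    }
}

The LL variant is identical except that the lower triangle of A is the stored one, so the index test flips to (i >= k).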
common_param.h: 2 additions, 0 deletions
@@ -257,6 +257,8 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
#ifdef ARCH_ARM64
void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
void (*ssymm_direct_alpha_betaLU) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
void (*ssymm_direct_alpha_betaLL) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
#endif


common_s.h: 4 additions, 0 deletions
@@ -50,6 +50,8 @@
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
#define SGEMM_DIRECT sgemm_direct
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta
#define SSYMM_DIRECT_ALPHA_BETA_LU ssymm_direct_alpha_betaLU
#define SSYMM_DIRECT_ALPHA_BETA_LL ssymm_direct_alpha_betaLL

#define SGEMM_ONCOPY sgemm_oncopy
#define SGEMM_OTCOPY sgemm_otcopy
@@ -220,6 +222,8 @@
#define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
#define SGEMM_DIRECT gotoblas -> sgemm_direct
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
#define SSYMM_DIRECT_ALPHA_BETA_LU gotoblas -> ssymm_direct_alpha_betaLU
#define SSYMM_DIRECT_ALPHA_BETA_LL gotoblas -> ssymm_direct_alpha_betaLL
#endif

#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
interface/symm.c: 18 additions, 0 deletions
@@ -371,6 +371,24 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_SIDE Side, enum CBLAS_UPLO Uplo,
return;
}

#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_ARM64) && (defined(USE_SSYMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
#if defined(DYNAMIC_ARCH)
if (support_sme1())
#endif
if (args.m == 0 || args.n == 0) return;
if (order == CblasRowMajor && m == lda && n == ldb && n == ldc)
{
if (Side == CblasLeft && Uplo == CblasUpper) {
SSYMM_DIRECT_ALPHA_BETA_LU(m, n, alpha, a, lda, b, ldb, beta, c, ldc); return;
}
else if (Side == CblasLeft && Uplo == CblasLower) {
SSYMM_DIRECT_ALPHA_BETA_LL(m, n, alpha, a, lda, b, ldb, beta, c, ldc); return;
}
}
#endif
#endif

#endif

if (args.m == 0 || args.n == 0) return;
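A minimal caller-side sketch of a CBLAS call that satisfies the conditions checked above (row-major, Side = CblasLeft, Uplo = CblasUpper, lda == m, ldb == ldc == n) and would therefore take the new direct path on an SME-capable ARM64 build (ARMV9SME target, or DYNAMIC_ARCH with SME1 detected); the matrix sizes and contents are placeholders.

#include <cblas.h>

/* m x m symmetric A (upper triangle referenced), m x n B and C,
   all row-major and contiguous so that lda == m and ldb == ldc == n. */
enum { m = 512, n = 256 };
static float A[m * m], B[m * n], C[m * n];

void run_ssymm_direct_example(void)
{
    /* ... fill A, B and C ... */
    cblas_ssymm(CblasRowMajor, CblasLeft, CblasUpper,
                m, n,
                1.0f,           /* alpha        */
                A, m,           /* A, lda == m  */
                B, n,           /* B, ldb == n  */
                0.0f,           /* beta         */
                C, n);          /* C, ldc == n  */
}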
kernel/CMakeLists.txt: 12 additions, 0 deletions
@@ -241,6 +241,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
if (X86_64 OR ARM64)
set(USE_DIRECT_SGEMM true)
endif()
set(USE_DIRECT_SSYMM false)
if (ARM64)
set(USE_DIRECT_SSYMM true)
endif()
if (UC_TARGET_CORE MATCHES ARMV9SME)
set (HAVE_SME true)
endif ()
@@ -267,6 +271,14 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
endif ()
endif()

if (USE_DIRECT_SSYMM)
if (ARM64)
set (SSYMMDIRECTKERNEL_ALPHA_BETA ssymm_direct_alpha_beta_arm64_sme1.c)
GenerateNamedObjects("${KERNELDIR}/${SSYMMDIRECTKERNEL_ALPHA_BETA}" "" "symm_direct_alpha_betaLU" false "" "" false SINGLE)
GenerateNamedObjects("${KERNELDIR}/${SSYMMDIRECTKERNEL_ALPHA_BETA}" "" "symm_direct_alpha_betaLL" false "" "" false SINGLE)
endif ()
endif()

foreach (float_type SINGLE DOUBLE)
string(SUBSTRING ${float_type} 0 1 float_char)
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
kernel/Makefile.L3: 29 additions, 0 deletions
@@ -52,6 +52,7 @@ endif
ifeq ($(ARCH), arm64)
USE_TRMM = 1
USE_DIRECT_SGEMM = 1
USE_DIRECT_SSYMM = 1
endif

ifeq ($(ARCH), riscv64)
@@ -137,6 +138,17 @@ endif
endif
endif

ifdef USE_DIRECT_SSYMM
ifndef SSYMMDIRECTKERNEL_ALPHA_BETA
ifeq ($(ARCH), arm64)
ifeq ($(TARGET_CORE), ARMV9SME)
HAVE_SME = 1
endif
SSYMMDIRECTKERNEL_ALPHA_BETA = ssymm_direct_alpha_beta_arm64_sme1.c
endif
endif
endif

ifeq ($(BUILD_BFLOAT16), 1)
ifndef BGEMMKERNEL
BGEMM_BETA = ../generic/gemm_beta.c
@@ -220,6 +232,14 @@ endif
endif
endif

ifdef USE_DIRECT_SSYMM
ifeq ($(ARCH), arm64)
SKERNELOBJS += \
ssymm_direct_alpha_betaLU$(TSUFFIX).$(SUFFIX) \
ssymm_direct_alpha_betaLL$(TSUFFIX).$(SUFFIX)
endif
endif

ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
DKERNELOBJS += \
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -982,6 +1002,15 @@ endif
endif
endif

ifdef USE_DIRECT_SSYMM
ifeq ($(ARCH), arm64)
$(KDIR)ssymm_direct_alpha_betaLU$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DLEFT -DUPPER $< -o $@
$(KDIR)ssymm_direct_alpha_betaLL$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYMMDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DLEFT -DLOWER $< -o $@
endif
endif

ifeq ($(BUILD_BFLOAT16), 1)
$(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL)
$(CC) $(CFLAGS) -c -DBFLOAT16 -DBGEMM -UDOUBLE -UCOMPLEX $< -o $@
kernel/arm64/ssymm_direct_alpha_beta_arm64_sme1.c: 225 additions, 0 deletions
@@ -0,0 +1,225 @@
/*
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
SPDX-License-Identifier: BSD-3-Clause-Clear
*/

#include "common.h"
#include <stdlib.h>
#include <inttypes.h>
#include <math.h>
// #include "sme_abi.h"
#if defined(HAVE_SME)

#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
#include <arm_sme.h>
#endif

/* Function prototypes */
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");

extern void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
const float *ba, const float *restrict bb, const float* beta,\
float *restrict C);
/* Function Definitions */
static uint64_t sve_cntw() {
uint64_t cnt;
asm volatile(
"rdsvl %[res], #1\n"
"lsr %[res], %[res], #2\n"
: [res] "=r" (cnt) ::
);
return cnt;
}

#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16

__arm_new("za") __arm_locally_streaming
static void ssymm_direct_sme1_preprocessLU(uint64_t nbr, uint64_t nbc,
const float *restrict a, float *restrict a_mod)
{
// const uint64_t num_rows = nbr;
// const uint64_t num_cols = nbc;
const uint64_t svl = svcntw();
uint64_t row_batch = svl;

float *restrict pSrc;
float *restrict pDst;
for (uint64_t row_idx = 0; row_idx < nbr; row_idx += row_batch)
{
row_batch = MIN(row_batch, nbr - row_idx);

// Fill in the lower triangle and Transpose 1SVL x N panel of A
uint64_t col_batch = svl;

for (uint64_t col_idx = 0; col_idx < nbc; col_idx += col_batch)
{
svzero_za();

if (col_idx == row_idx)
{
pSrc = &a[(row_idx)*nbc + col_idx];
pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
// Load horizontal slices, filling lower elements
const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
for (int64_t row = row_batch - 1; row >= 0; row--)
{
svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
svld1_ver_za32(0, row, pg_row, &pSrc[row * nbc]);
}
// Save vertical slices
col_batch = MIN(col_batch, nbc - col_idx);
for (uint64_t col = 0; col < col_batch; col++)
{
svst1_ver_za32(0, col, svptrue_b32(), &pDst[col * svl]);
}
}
else if (col_idx > row_idx)
{
pSrc = &a[(row_idx)*nbc + col_idx];
pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
// Load horizontal slices
const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
for (uint64_t row = 0; row < row_batch; row++)
{
svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
}
// Save vertical slices
col_batch = MIN(col_batch, nbc - col_idx);
for (uint64_t col = 0; col < col_batch; col++)
{
svst1_ver_za32(0, col, svptrue_b32(), &pDst[col * svl]);
}
}
else if (col_idx < row_idx)
{
pSrc = &a[row_idx + col_idx * nbc];
pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
// Load horizontal slices
const svbool_t pg_row = svwhilelt_b32_u64(row_idx, nbc);
for (uint64_t row = 0; row < svl; row++)
{
svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
}
// Save vertical slices
col_batch = MIN(col_batch, nbc - col_idx);
for (uint64_t col = 0; col < svl; col++)
{
svst1_hor_za32(0, col, svptrue_b32(), &pDst[col * svl]);
}
}
}
}
}

//
__arm_new("za") __arm_locally_streaming
static void ssymm_direct_sme1_preprocessLL(uint64_t nbr, uint64_t nbc,
const float *restrict a, float *restrict a_mod)
{
// const uint64_t num_rows = nbr;
const uint64_t svl = svcntw();
uint64_t row_batch = svl;

float *restrict pSrc;
float *restrict pDst;
for (uint64_t row_idx = 0; row_idx < nbr; row_idx += row_batch)
{
row_batch = MIN(row_batch, nbr - row_idx);

// Fill in the upper triangle and Transpose 1SVL x N panel of A
uint64_t col_batch = svl;

for (uint64_t col_idx = 0; col_idx < nbc; col_idx += col_batch)
{
svzero_za();

if (col_idx == row_idx)
{
pSrc = &a[(row_idx)*nbc + col_idx];
pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
// Load horizontal slices, filling upper elements
const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
for (uint64_t row = 0; row < row_batch; row++)
{
svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
svld1_ver_za32(0, row, pg_row, &pSrc[row * nbc]);
}
// Save vertical slices
col_batch = MIN(col_batch, nbc - col_idx);
for (uint64_t col = 0; col < col_batch; col++)
{
svst1_ver_za32(0, col, svptrue_b32(), &pDst[col * svl]);
}
}
else if (col_idx > row_idx)
{
pSrc = &a[row_idx + col_idx * nbc];
pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
// Load horizontal slices
const svbool_t pg_row = svptrue_b32();
for (uint64_t row = 0; row < row_batch; row++)
{
svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
}
// Save vertical slices
col_batch = MIN(col_batch, nbc - col_idx);
for (uint64_t col = 0; col < col_batch; col++)
{
svst1_hor_za32(0, col, svptrue_b32(), &pDst[col * svl]);
}
}
else if (col_idx < row_idx)
{
pSrc = &a[(row_idx)*nbc + col_idx];
pDst = &a_mod[(col_idx)*svl + row_idx * nbc];
// Load horizontal slices
const svbool_t pg_row = svwhilelt_b32_u64(col_idx, nbc);
for (uint64_t row = 0; row < row_batch; row++)
{
svld1_hor_za32(0, row, pg_row, &pSrc[row * nbc]);
}
// Save vertical slices
col_batch = MIN(col_batch, nbc - col_idx);
for (uint64_t col = 0; col < col_batch; col++)
{
svst1_ver_za32(0, col, svptrue_b32(), &pDst[col * svl]);
}
}
}
}
}

#endif

//
void CNAME(BLASLONG M, BLASLONG N, float alpha, float *__restrict A,
BLASLONG strideA, float *__restrict B, BLASLONG strideB,
float beta, float *__restrict R, BLASLONG strideR)
{
uint64_t vl_elms = sve_cntw(); // vl_elem = 16
uint64_t m_mod = ceil((double)M / (double)vl_elms) * vl_elms;

/* Pre-process the left matrix to make it suitable for
matrix sum of outer-product calculation
*/
float *A_mod = (float *)malloc(m_mod * M * sizeof(float));

#if defined(UPPER)
ssymm_direct_sme1_preprocessLU(M, M, A, A_mod);
#elif defined(LOWER)
ssymm_direct_sme1_preprocessLL(M, M, A, A_mod);
#endif

/* Calculate C = alpha*A*B + beta*C */
sgemm_direct_alpha_beta_sme1_2VLx2VL(M, M, N, &alpha, A_mod, B, &beta, R);
free(A_mod);
}

#else

void CNAME (BLASLONG M, BLASLONG N, float alpha, float * __restrict A,\
BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
float beta, float * __restrict R, BLASLONG strideR){}

#endif
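Conceptually, the LU preprocessing above expands the stored upper triangle of A into a fully populated symmetric matrix while re-tiling it into SVL-wide transposed panels (with M padded up to a multiple of the vector length) for the 2VLx2VL outer-product kernel. The following is a minimal scalar sketch of just the symmetrization step, ignoring the padding and the panel layout; the helper name symmetrize_from_upper is hypothetical.

#include <stddef.h>

/* Illustrative scalar view of what ssymm_direct_sme1_preprocessLU achieves
   logically: expand an M x M row-major matrix whose upper triangle is stored
   into a fully populated symmetric copy. The real SME1 routine additionally
   emits the result in an SVL-blocked, transposed panel layout. */
static void symmetrize_from_upper(size_t M, const float *a, float *a_full)
{
    for (size_t i = 0; i < M; i++) {
        for (size_t j = 0; j < M; j++) {
            a_full[i * M + j] = (i <= j) ? a[i * M + j] : a[j * M + i];
        }
    }
}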
kernel/setparam-ref.c: 2 additions, 0 deletions
@@ -216,6 +216,8 @@ gotoblas_t TABLE_NAME = {
#ifdef ARCH_ARM64
sgemm_directTS,
sgemm_direct_alpha_betaTS,
ssymm_direct_alpha_betaLUTS,
ssymm_direct_alpha_betaLLTS,
#endif

sgemm_kernelTS, sgemm_betaTS,
param.h: 1 addition, 0 deletions
@@ -3844,6 +3844,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout

#if defined(ARMV9SME) /* ARMv9 SME */
#define USE_SGEMM_KERNEL_DIRECT 1
#define USE_SSYMM_KERNEL_DIRECT 1
#endif /* ARMv9 SME */

#if defined(ARMV5)