diff --git a/test/Makefile b/test/Makefile index f29bd35471..62585b29c2 100644 --- a/test/Makefile +++ b/test/Makefile @@ -234,6 +234,9 @@ ifeq ($(BUILD_BFLOAT16),1) BF3= test_bgemm B3 = test_sbgemm endif +ifeq ($(BUILD_HFLOAT16),1) +H3 = test_shgemm +endif ifeq ($(BUILD_SINGLE),1) S3=sblat3 endif @@ -257,9 +260,9 @@ endif ifeq ($(SUPPORT_GEMM3M),1) -level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m +level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) level3_3m else -level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) +level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) endif ifneq ($(CROSS), 1) @@ -454,6 +457,9 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME) endif ifeq ($(BUILD_HFLOAT16),1) +test_shgemm : compare_sgemm_shgemm.c test_helpers.h ../$(LIBNAME) + $(CC) $(CLDFLAGS) -DIHFLOAT16 -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif @@ -475,7 +481,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemm test_shgemv sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c new file mode 100644 index 0000000000..11b6f39f59 --- /dev/null +++ b/test/compare_sgemm_shgemm.c @@ -0,0 +1,145 @@ +/*************************************************************************** +Copyright (c) 2020,2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ +#include +#include +#include "../common.h" + +#include "test_helpers.h" + +#define SGEMM BLASFUNC(sgemm) +#define SHGEMM BLASFUNC(shgemm) +#define SHGEMM_LARGEST 256 + +int +main (int argc, char *argv[]) +{ + blasint m, n, k; + int i, j, l; + blasint x, y; + int ret = 0; + int loop = SHGEMM_LARGEST; + char transA = 'N', transB = 'N'; + float alpha = 1.0, beta = 0.0; + + for (x = 0; x <= loop; x++) + { + if ((x > 100) && (x != SHGEMM_LARGEST)) continue; + m = k = n = x; + float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); + float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); + float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); + hfloat16 *AA = (hfloat16 *)malloc_safe(m * k * sizeof(hfloat16)); + hfloat16 *BB = (hfloat16 *)malloc_safe(k * n * sizeof(hfloat16)); + float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || + (DD == NULL) || (CC == NULL)) + return 1; + + for (j = 0; j < m; j++) + { + for (i = 0; i < k; i++) + { + A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + AA[j * k + i] = (hfloat16) A[j * k + i]; + } + } + for (j = 0; j < n; j++) + { + for (i = 0; i < k; i++) + { + B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + BB[j * k + i] = (hfloat16) A[j * k + i]; + } + } + for (y = 0; y < 4; y++) + { + if ((y == 0) || (y == 2)) { + transA = 'N'; + } else { + transA = 'T'; + } + if ((y == 0) || (y == 1)) { + transB = 'N'; + } else { + transB = 'T'; + } + + memset(CC, 0, m * n * sizeof(FLOAT)); + memset(DD, 0, m * n * sizeof(FLOAT)); + memset(C, 0, m * n * sizeof(FLOAT)); + + SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, + &m, B, &k, &beta, C, &m); + SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, + &m, BB, &k, &beta, CC, &m); + + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + { + for (l = 0; l < k; l++) + if (transA == 'N' && transB == 'N') + { + DD[i * m + j] += + (float) AA[l * m + j] * (float)BB[l + k * i]; + } else if (transA == 'T' && transB == 'N') + { + DD[i * m + j] += + (float)AA[k * j + l] * (float)BB[l + k * i]; + } else if (transA == 'N' && transB == 'T') + { + DD[i * m + j] += + (float)AA[l * m + j] * (float)BB[i + l * n]; + } else if (transA == 'T' && transB == 'T') + { + DD[i * m + j] += + (float)AA[k * j + l] * (float)BB[i + l * n]; + } + if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) { + ret++; + } + if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) { + ret++; + } + } + } + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(DD); + free(CC); + } + + if (ret != 0) { + fprintf(stderr, "SHGEMM FAILURES: %d\n", ret); + return 1; + } + + return ret; +}