Skip to content

Commit 016e2f1

Browse files
authored
Merge pull request #5499 from martin-frbg/issue5497
Add test for SHGEMM
2 parents c35b11a + 098a8d5 commit 016e2f1

File tree

3 files changed

+283
-3
lines changed

3 files changed

+283
-3
lines changed

test/CMakeLists.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,28 @@ foreach(test_bin ${OpenBLAS_Tests})
3636
target_link_libraries(${test_bin} ${OpenBLAS_LIBNAME})
3737
endforeach()
3838

39+
if (BUILD_BFLOAT16)
40+
add_executable(test_bgemm compare_sgemm_bgemm.c)
41+
target_compile_definitions(test_bgemm PUBLIC -DIBFLOAT16 -DOBFLOAT16)
42+
target_link_libraries(test_bgemm ${OpenBLAS_LIBNAME})
43+
add_executable(test_bgemv compare_sgemv_bgemv.c)
44+
target_compile_definitions(test_bgemv PUBLIC -DIBFLOAT16 -DOBFLOAT16)
45+
target_link_libraries(test_bgemv ${OpenBLAS_LIBNAME})
46+
add_executable(test_sbgemm compare_sgemm_sbgemm.c)
47+
target_compile_definitions(test_sbgemm PUBLIC -DIBFLOAT16)
48+
target_link_libraries(test_sbgemm ${OpenBLAS_LIBNAME})
49+
add_executable(test_sbgemv compare_sgemv_sbgemv.c)
50+
target_compile_definitions(test_sbgemv PUBLIC -DIBFLOAT16)
51+
target_link_libraries(test_sbgemv ${OpenBLAS_LIBNAME})
52+
endif()
53+
54+
if (BUILD_HFLOAT16)
55+
add_executable(test_shgemm compare_sgemm_shgemm.c)
56+
target_link_libraries(test_shgemm ${OpenBLAS_LIBNAME})
57+
add_executable(test_shgemv compare_sgemv_shgemv.c)
58+
target_link_libraries(test_shgemv ${OpenBLAS_LIBNAME})
59+
endif()
60+
3961
# $1 exec, $2 input, $3 output_result
4062
if(WIN32)
4163
FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/test_helper.ps1
@@ -94,3 +116,21 @@ add_test(NAME "${float_type}blas3_3m"
94116
endif()
95117
endif()
96118
endforeach()
119+
120+
if (BUILD_BFLOAT16)
121+
add_test(NAME "bgemm"
122+
COMMAND $<TARGET_FILE:test_bgemm>)
123+
add_test(NAME "bgemv"
124+
COMMAND $<TARGET_FILE:test_bgemv>)
125+
add_test(NAME "sbgemm"
126+
COMMAND $<TARGET_FILE:test_sbgemm>)
127+
add_test(NAME "sbgemv"
128+
COMMAND $<TARGET_FILE:test_sbgemv>)
129+
endif()
130+
131+
if (BUILD_HFLOAT16)
132+
add_test(NAME "shgemm"
133+
COMMAND $<TARGET_FILE:test_shgemm>)
134+
add_test(NAME "shgemv"
135+
COMMAND $<TARGET_FILE:test_shgemv>)
136+
endif()

test/Makefile

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ ifeq ($(BUILD_BFLOAT16),1)
234234
BF3= test_bgemm
235235
B3 = test_sbgemm
236236
endif
237+
ifeq ($(BUILD_HFLOAT16),1)
238+
H3 = test_shgemm
239+
endif
237240
ifeq ($(BUILD_SINGLE),1)
238241
S3=sblat3
239242
endif
@@ -257,9 +260,9 @@ endif
257260

258261

259262
ifeq ($(SUPPORT_GEMM3M),1)
260-
level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m
263+
level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) level3_3m
261264
else
262-
level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3)
265+
level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3)
263266
endif
264267

265268
ifneq ($(CROSS), 1)
@@ -454,6 +457,9 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME)
454457
endif
455458

456459
ifeq ($(BUILD_HFLOAT16),1)
460+
test_shgemm : compare_sgemm_shgemm.c test_helpers.h ../$(LIBNAME)
461+
$(CC) $(CLDFLAGS) -DIHFLOAT16 -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
462+
457463
test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME)
458464
$(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
459465
endif
@@ -475,7 +481,7 @@ clean:
475481
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
476482
sblat1 dblat1 cblat1 zblat1 \
477483
sblat2 dblat2 cblat2 zblat2 \
478-
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \
484+
test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemm test_shgemv sblat3 dblat3 cblat3 zblat3 \
479485
sblat1p dblat1p cblat1p zblat1p \
480486
sblat2p dblat2p cblat2p zblat2p \
481487
sblat3p dblat3p cblat3p zblat3p \

test/compare_sgemm_shgemm.c

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
/***************************************************************************
2+
Copyright (c) 2020,2025 The OpenBLAS Project
3+
All rights reserved.
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are
6+
met:
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
2. Redistributions in binary form must reproduce the above copyright
10+
notice, this list of conditions and the following disclaimer in
11+
the documentation and/or other materials provided with the
12+
distribution.
13+
3. Neither the name of the OpenBLAS project nor the names of
14+
its contributors may be used to endorse or promote products
15+
derived from this software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25+
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
*****************************************************************************/
27+
#include <stdio.h>
28+
#include <stdint.h>
29+
#include "../common.h"
30+
31+
#include "test_helpers.h"
32+
33+
#define SGEMM BLASFUNC(sgemm)
34+
#define SHGEMM BLASFUNC(shgemm)
35+
#define SHGEMM_LARGEST 256
36+
37+
int
38+
main (int argc, char *argv[])
39+
{
40+
blasint m, n, k;
41+
int i, j, l;
42+
blasint x, y;
43+
int ret = 0;
44+
int rret = 0;
45+
int loop = SHGEMM_LARGEST;
46+
char transA = 'N', transB = 'N';
47+
float alpha = 1.0, beta = 0.0;
48+
int xvals[6]={3,24,55,71,SHGEMM_LARGEST/2,SHGEMM_LARGEST};
49+
50+
for (x = 0; x <= loop; x++)
51+
{
52+
if ((x > 100) && (x != SHGEMM_LARGEST)) continue;
53+
m = k = n = x;
54+
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
55+
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
56+
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
57+
_Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16));
58+
_Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16));
59+
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
60+
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
61+
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
62+
(DD == NULL) || (CC == NULL))
63+
return 1;
64+
65+
for (j = 0; j < m; j++)
66+
{
67+
for (i = 0; i < k; i++)
68+
{
69+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
70+
AA[j * k + i] = (_Float16) A[j * k + i];
71+
}
72+
}
73+
for (j = 0; j < n; j++)
74+
{
75+
for (i = 0; i < k; i++)
76+
{
77+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
78+
BB[j * k + i] = (_Float16) B[j * k + i];
79+
}
80+
}
81+
for (y = 0; y < 4; y++)
82+
{
83+
if ((y == 0) || (y == 2)) {
84+
transA = 'N';
85+
} else {
86+
transA = 'T';
87+
}
88+
if ((y == 0) || (y == 1)) {
89+
transB = 'N';
90+
} else {
91+
transB = 'T';
92+
}
93+
94+
memset(CC, 0, m * n * sizeof(FLOAT));
95+
memset(DD, 0, m * n * sizeof(FLOAT));
96+
memset(C, 0, m * n * sizeof(FLOAT));
97+
98+
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
99+
&m, B, &k, &beta, C, &m);
100+
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA,
101+
&m, (_Float16*)BB, &k, &beta, CC, &m);
102+
103+
for (i = 0; i < n; i++)
104+
for (j = 0; j < m; j++)
105+
{
106+
for (l = 0; l < k; l++)
107+
if (transA == 'N' && transB == 'N')
108+
{
109+
DD[i * m + j] +=
110+
(float) AA[l * m + j] * (float)BB[l + k * i];
111+
} else if (transA == 'T' && transB == 'N')
112+
{
113+
DD[i * m + j] +=
114+
(float)AA[k * j + l] * (float)BB[l + k * i];
115+
} else if (transA == 'N' && transB == 'T')
116+
{
117+
DD[i * m + j] +=
118+
(float)AA[l * m + j] * (float)BB[i + l * n];
119+
} else if (transA == 'T' && transB == 'T')
120+
{
121+
DD[i * m + j] +=
122+
(float)AA[k * j + l] * (float)BB[i + l * n];
123+
}
124+
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
125+
fprintf(stderr,"CC %f C %f \n",(float)CC[i*m+j],C[i*m+j]);
126+
ret++;
127+
}
128+
if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) {
129+
fprintf(stderr,"CC %f DD %f \n",(float)CC[i*m+j],(float)DD[i*m+j]);
130+
ret++;
131+
}
132+
}
133+
}
134+
free(A);
135+
free(B);
136+
free(C);
137+
free(AA);
138+
free(BB);
139+
free(DD);
140+
free(CC);
141+
}
142+
if (ret != 0) {
143+
fprintf(stderr, "SHGEMM FAILURES: %d!!!\n", ret);
144+
return 1;
145+
}
146+
147+
148+
for (loop = 0; loop<6; loop++) {
149+
x=xvals[loop];
150+
for (alpha=0.;alpha<=1.;alpha+=0.5)
151+
{
152+
for (beta = 0.0; beta <=1.; beta+=0.5) {
153+
154+
m = k = n = x;
155+
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
156+
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
157+
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
158+
_Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16));
159+
_Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16));
160+
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
161+
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
162+
(CC == NULL))
163+
return 1;
164+
165+
for (j = 0; j < m; j++)
166+
{
167+
for (i = 0; i < k; i++)
168+
{
169+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
170+
AA[j * k + i] = (_Float16) A[j * k + i];
171+
}
172+
}
173+
for (j = 0; j < n; j++)
174+
{
175+
for (i = 0; i < k; i++)
176+
{
177+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
178+
BB[j * k + i] = (_Float16) B[j * k + i];
179+
}
180+
}
181+
182+
for (y = 0; y < 4; y++)
183+
{
184+
if ((y == 0) || (y == 2)) {
185+
transA = 'N';
186+
} else {
187+
transA = 'T';
188+
}
189+
if ((y == 0) || (y == 1)) {
190+
transB = 'N';
191+
} else {
192+
transB = 'T';
193+
}
194+
195+
memset(CC, 0, m * n * sizeof(FLOAT));
196+
memset(C, 0, m * n * sizeof(FLOAT));
197+
198+
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
199+
&m, B, &k, &beta, C, &m);
200+
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA,
201+
&m, (_Float16*)BB, &k, &beta, CC, &m);
202+
203+
for (i = 0; i < n; i++)
204+
for (j = 0; j < m; j++)
205+
{
206+
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
207+
ret++;
208+
}
209+
}
210+
}
211+
free(A);
212+
free(B);
213+
free(C);
214+
free(AA);
215+
free(BB);
216+
free(CC);
217+
218+
if (ret != 0) {
219+
/*
220+
* fprintf(stderr, "SHGEMM FAILURES FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret);
221+
*/
222+
rret++;
223+
ret=0;
224+
/* } else {
225+
fprintf(stderr, "SHGEMM SUCCEEDED FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret);
226+
*/
227+
}
228+
}
229+
230+
}
231+
}
232+
if (rret > 0) return(1);
233+
return(0);
234+
}

0 commit comments

Comments
 (0)