@@ -1252,8 +1252,8 @@ int main(int argc, char** argv)
12521252
12531253 // Allocate the reference on the host.
12541254 float * o_ref_h = (float *) malloc (o_size * sizeof (float ));
1255- float * softmax_sum_ref_h = (float *) malloc (2 * b * s * h * sizeof (float ));
1256- float * softmax_sum_h = (float *) malloc (2 * b * s * h * sizeof (float ));
1255+ float * softmax_stats_ref_h = (float *) malloc (2 * b * s * h * sizeof (float ));
1256+ float * softmax_stats_h = (float *) malloc (2 * b * s * h * sizeof (float ));
12571257
12581258 // The P matrix is stored as one big matrix of size S x B x H x S.
12591259 const size_t p_size = s * b * h * s;
@@ -1947,7 +1947,7 @@ int main(int argc, char** argv)
19471947
19481948 // Read the results.
19491949 FMHA_CHECK_CUDA (cuda_memcpy_d2h (o_ref_h, o_d, o_size, data_type));
1950- FMHA_CHECK_CUDA (cuda_memcpy_d2h (softmax_sum_ref_h , softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
1950+ FMHA_CHECK_CUDA (cuda_memcpy_d2h (softmax_stats_ref_h , softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
19511951 }
19521952
19531953 // Fill-in p/s/o with garbage data.
@@ -2033,7 +2033,7 @@ int main(int argc, char** argv)
20332033 std::vector<float > o_ref_trans_h (o_size);
20342034
20352035 FMHA_CHECK_CUDA (cuda_memcpy_d2h (o_h, o_d_view, o_view_size, output_dtype));
2036- FMHA_CHECK_CUDA (cuda_memcpy_d2h (softmax_sum_h , softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
2036+ FMHA_CHECK_CUDA (cuda_memcpy_d2h (softmax_stats_h , softmax_stats_d, 2 * b * s * h, DATA_TYPE_FP32));
20372037
20382038 if (interleaved)
20392039 {
@@ -2053,7 +2053,7 @@ int main(int argc, char** argv)
20532053 dv, epsilon, verbose, true );
20542054 if (save_softmax)
20552055 {
2056- auto errors = check_softmax_results (softmax_sum_h, softmax_sum_ref_h , b, s, h, seqlens, cu_seqlens);
2056+ auto errors = check_softmax_results (softmax_stats_h, softmax_stats_ref_h , b, s, h, seqlens, cu_seqlens);
20572057 status = status | ((errors.first + errors.second ) > 0 );
20582058 }
20592059 }
@@ -2149,8 +2149,8 @@ int main(int argc, char** argv)
21492149 free (s_h);
21502150 free (o_h);
21512151 free (o_ref_h);
2152- free (softmax_sum_h );
2153- free (softmax_sum_ref_h );
2152+ free (softmax_stats_h );
2153+ free (softmax_stats_ref_h );
21542154 free (contiguous_kv_h);
21552155 free (kv_cache_ptrs_h);
21562156 free (kv_cache_block_offsets_h);
0 commit comments