Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cpp/kernels/xqa/mha_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ CUBIN_EXPORT __global__
#ifdef NDEBUG
#if !OPTIMIZE_FOR_LATENCY
__launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 3 : 2)
#else
__launch_bounds__(128 * 3)
#endif
#else
__launch_bounds__(128 * 3, 1)
Expand Down Expand Up @@ -1088,6 +1090,23 @@ CUBIN_EXPORT __global__
}
}
smem.gemm1WarpGrpBar.arrive_and_wait();
#else
if (blockIdx.y == 1 && threadIdx.x == 0)
{
printf("rowMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowMax[idxXBuf][i]);
}
printf("\n");
printf("rowSum:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowSum[idxXBuf][i]);
}
printf("\n");
}
smem.gemm1WarpGrpBar.arrive_and_wait();
#endif
#endif

Expand Down
8 changes: 7 additions & 1 deletion cpp/kernels/xqa/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
#include <cuda_fp8.h>

inline constexpr float log2e = 1.4426950408889634; // std::log2(M_E)
inline constexpr float safeInitRowMax = -1e+30F;
// we used an optimization where exp(x-rowMax) is computed as:
/* bias = rowMax * log2e // shared for the whole row
exp(x-rowMax) = exp2f(x * log2e - bias)
*/
// But this optimization is not numerically stable when (x * log2e - bias) is computed with FMA and x is too large. For
// this reason, don't set safeInitRowMax with a huge absolute value.
inline constexpr float safeInitRowMax = -1e+5F;
inline constexpr int32_t kBAD_PAGE_INDEX = -1;
__constant__ constexpr float kE4M3_MAX = 448.F;

Expand Down