diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu
index da44fba60c4..5b14f37aea0 100644
--- a/cpp/kernels/xqa/mha_sm90.cu
+++ b/cpp/kernels/xqa/mha_sm90.cu
@@ -632,6 +632,8 @@ CUBIN_EXPORT __global__
 #ifdef NDEBUG
 #if !OPTIMIZE_FOR_LATENCY
     __launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 3 : 2)
+#else
+    __launch_bounds__(128 * 3)
 #endif
 #else
     __launch_bounds__(128 * 3, 1)
@@ -1088,6 +1090,23 @@ CUBIN_EXPORT __global__
             }
         }
         smem.gemm1WarpGrpBar.arrive_and_wait();
+#else
+        if (blockIdx.y == 1 && threadIdx.x == 0)
+        {
+            printf("rowMax:\n");
+            for (int i = 0; i < ctaNbQHeads; i++)
+            {
+                printf("%f, ", smem.xRowMax[idxXBuf][i]);
+            }
+            printf("\n");
+            printf("rowSum:\n");
+            for (int i = 0; i < ctaNbQHeads; i++)
+            {
+                printf("%f, ", smem.xRowSum[idxXBuf][i]);
+            }
+            printf("\n");
+        }
+        smem.gemm1WarpGrpBar.arrive_and_wait();
 #endif
 #endif
 
diff --git a/cpp/kernels/xqa/utils.cuh b/cpp/kernels/xqa/utils.cuh
index 2ec5b40995a..6f74a830ed2 100644
--- a/cpp/kernels/xqa/utils.cuh
+++ b/cpp/kernels/xqa/utils.cuh
@@ -30,7 +30,13 @@
 #include <cuda_fp8.h>
 
 inline constexpr float log2e = 1.4426950408889634; // std::log2(M_E)
-inline constexpr float safeInitRowMax = -1e+30F;
+// we used an optimization where exp(x-rowMax) is computed as:
+/* bias = rowMax * log2e // shared for the whole row
+   exp(x-rowMax) = exp2f(x * log2e - bias)
+*/
+// But this optimization is not numerically stable when (x * log2e - bias) is computed with FMA and x is too large. For
+// this reason, don't set safeInitRowMax with a huge absolute value.
+inline constexpr float safeInitRowMax = -1e+5F;
 inline constexpr int32_t kBAD_PAGE_INDEX = -1;
 
 __constant__ constexpr float kE4M3_MAX = 448.F;