diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index 3533108b3316..2d4a68fe3e7b 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -1715,7 +1715,7 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, dim3 block(64, _WvPrGrp); \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \ + wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ s_b, __wvPrGrp, Otp_in, CuCount); \ } else { \