@@ -630,10 +630,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
630630 Tensor tVsV = make_tensor (tVsV_.data (), reshape_thread_tile (tVsV_.layout ()));
631631
632632 if (block_table != nullptr ) {
633- tKgK.data () = gK .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size ,
634- block_table, params.k_batch_stride , params.k_row_stride );
635- tVgV.data () = gV .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size ,
636- block_table, params.v_batch_stride , params.v_row_stride );
633+ auto final_block_size = binfo.actual_seqlen_k - (n_block_max - 1 ) * kBlockN ;
634+ tKgK.data () = gK .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max - 1 , params.page_block_size ,
635+ block_table, params.k_batch_stride , params.k_row_stride , final_block_size);
636+ tVgV.data () = gV .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max - 1 , params.page_block_size ,
637+ block_table, params.v_batch_stride , params.v_row_stride , final_block_size);
637638 }
638639
639640 typename Kernel_traits::TiledMma tiled_mma;
@@ -790,9 +791,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
790791 tKgK.data () = tKgK.data () + (-int (kBlockN * params.k_row_stride ));
791792 } else {
792793 if (n_block > n_block_copy_min) {
793- tVgV.data () = gV .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size ,
794+ tVgV.data () = gV .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block - 1 , params.page_block_size ,
794795 block_table, params.v_batch_stride , params.v_row_stride );
795- tKgK.data () = gK .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size ,
796+ tKgK.data () = gK .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block - 1 , params.page_block_size ,
796797 block_table, params.k_batch_stride , params.k_row_stride );
797798 }
798799 }
@@ -886,7 +887,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
886887 if (block_table == nullptr ) {
887888 tVgV.data () = tVgV.data () + (-int (kBlockN * params.v_row_stride ));
888889 } else {
889- tVgV.data () = gV .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1 , params.page_block_size ,
890+ tVgV.data () = gV .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size ,
890891 block_table, params.v_batch_stride , params.v_row_stride );
891892 }
892893 FLASH_NAMESPACE::copy</* Is_even_MN=*/ true , Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
@@ -922,7 +923,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
922923 if (block_table == nullptr ) {
923924 tKgK.data () = tKgK.data () + (-int (kBlockN * params.k_row_stride ));
924925 } else {
925- tKgK.data () = gK .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size ,
926+ tKgK.data () = gK .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block - 1 , params.page_block_size ,
926927 block_table, params.k_batch_stride , params.k_row_stride );
927928 }
928929 FLASH_NAMESPACE::copy</* Is_even_MN=*/ true , Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -962,7 +963,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
962963 if (block_table == nullptr ) {
963964 tVgV.data () = tVgV.data () + (-int (kBlockN * params.v_row_stride ));
964965 } else {
965- tVgV.data () = gV .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1 , params.page_block_size ,
966+ tVgV.data () = gV .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size ,
966967 block_table, params.v_batch_stride , params.v_row_stride );
967968 }
968969
@@ -984,7 +985,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
984985 if (block_table == nullptr ) {
985986 tKgK.data () = tKgK.data () + (-int (kBlockN * params.k_row_stride ));
986987 } else {
987- tKgK.data () = gK .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size ,
988+ tKgK.data () = gK .data () + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block - 1 , params.page_block_size ,
988989 block_table, params.k_batch_stride , params.k_row_stride );
989990 }
990991 FLASH_NAMESPACE::copy</* Is_even_MN=*/ true , Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
0 commit comments