Skip to content

Commit 720c948

Browse files
[Bugfix] fix illegal memory access (#42)
* fix illegal memory access
  Signed-off-by: LucasWilkinson <[email protected]>
* fix off by one
  Signed-off-by: LucasWilkinson <[email protected]>
* typo
  Signed-off-by: LucasWilkinson <[email protected]>
---------
Signed-off-by: LucasWilkinson <[email protected]>
1 parent d4e0903 commit 720c948

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

csrc/flash_attn/src/flash_fwd_kernel.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
630630
Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));
631631

632632
if (block_table != nullptr) {
633-
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
634-
block_table, params.k_batch_stride, params.k_row_stride);
635-
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
636-
block_table, params.v_batch_stride, params.v_row_stride);
633+
auto final_block_size = binfo.actual_seqlen_k - (n_block_max - 1) * kBlockN;
634+
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max - 1, params.page_block_size,
635+
block_table, params.k_batch_stride, params.k_row_stride, final_block_size);
636+
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max - 1, params.page_block_size,
637+
block_table, params.v_batch_stride, params.v_row_stride, final_block_size);
637638
}
638639

639640
typename Kernel_traits::TiledMma tiled_mma;
@@ -790,9 +791,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
790791
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
791792
} else {
792793
if (n_block > n_block_copy_min) {
793-
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
794+
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block - 1, params.page_block_size,
794795
block_table, params.v_batch_stride, params.v_row_stride);
795-
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
796+
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block - 1, params.page_block_size,
796797
block_table, params.k_batch_stride, params.k_row_stride);
797798
}
798799
}
@@ -886,7 +887,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
886887
if (block_table == nullptr) {
887888
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
888889
} else {
889-
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
890+
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
890891
block_table, params.v_batch_stride, params.v_row_stride);
891892
}
892893
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
@@ -922,7 +923,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
922923
if (block_table == nullptr) {
923924
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
924925
} else {
925-
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
926+
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block - 1, params.page_block_size,
926927
block_table, params.k_batch_stride, params.k_row_stride);
927928
}
928929
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -962,7 +963,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
962963
if (block_table == nullptr) {
963964
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
964965
} else {
965-
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
966+
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
966967
block_table, params.v_batch_stride, params.v_row_stride);
967968
}
968969

@@ -984,7 +985,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
984985
if (block_table == nullptr) {
985986
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
986987
} else {
987-
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
988+
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block - 1, params.page_block_size,
988989
block_table, params.k_batch_stride, params.k_row_stride);
989990
}
990991
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);

csrc/flash_attn/src/utils.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,16 +297,33 @@ void cp_async_wait() {
297297
// assumes that the tensor has already been positioned at the correct head.
298298
template <typename Kernel_traits>
299299
__forceinline__ __device__
300-
int64_t resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
301-
const int* block_table, const int page_stride, const int row_stride) {
300+
int64_t resolve_thread_kv_page_slice_offset(
301+
const int tidx, const int n_block, const int page_block_size,
302+
const int* block_table, const int page_stride, const int row_stride,
303+
std::optional<int> partial_block_size = std::nullopt
304+
) {
302305
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
303306
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
304307
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
305308
constexpr int kBlockN = Kernel_traits::kBlockN;
306309

307310
const int64_t col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
308-
const int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
309-
const int64_t global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
311+
int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
312+
313+
if (partial_block_size) {
314+
// if we have a partial block, we need to adjust the row offset to avoid
315+
// reading off the end of the block_table
316+
// get the offset of the last row in the kBlockN we care about
317+
auto final_row_offset = std::max(*partial_block_size - 1, 0);
318+
// adjust the row offset to account for each thread loading multiple
319+
// rows
320+
auto final_thread_row_offset =
321+
ceil_div(final_row_offset, kGmemRowsPerThread) * kGmemRowsPerThread;
322+
block_row_offset = std::min(
323+
block_row_offset, int64_t(final_thread_row_offset));
324+
}
325+
326+
const int64_t global_row_offset = block_row_offset + n_block * kBlockN;
310327
const int64_t page_offset = global_row_offset % page_block_size;
311328
const int64_t virtual_page_idx = global_row_offset / page_block_size;
312329

0 commit comments

Comments
 (0)