7 changes: 6 additions & 1 deletion hopper/block.h
@@ -38,8 +38,13 @@ struct BlockMN {
// TODO: check off-by-1 error
if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1; }
// If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left)
// When CP is not enabled, tot_seqlen_k equals seqlen_k and cp_world_size is 1.
// cp_world_size is guaranteed to be greater than 0.
n_block_max = std::min(n_block_max,
cute::ceil_div(
cute::ceil_div(m_idx_max + seqlen_info.tot_seqlen_k - seqlen_q + window_size_right - seqlen_info.cp_rank,
seqlen_info.cp_world_size),
kBlockN));
}
// Now, only adjust n_block_min if split
int n_block_min = 0;
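The new upper bound assumes a round-robin CP layout: local K index l on rank r corresponds to global index l * cp_world_size + cp_rank. A minimal host-side sketch of the same arithmetic (Python; the helper names and the layout description are assumptions for illustration, not part of this diff):

def ceil_div(a, b):
    # Integer ceiling division, as in cute::ceil_div; assumes a, b >= 0.
    return -(-a // b)

def n_block_max_cp(m_idx_max, seqlen_q, tot_seqlen_k, window_size_right,
                   kBlockN, cp_world_size=1, cp_rank=0):
    # Exclusive global column bound implied by the causal/local mask.
    g = m_idx_max + tot_seqlen_k - seqlen_q + window_size_right
    # Columns this rank owns below that bound: the largest local index l with
    # l * cp_world_size + cp_rank < g is (g - 1 - cp_rank) // cp_world_size,
    # so the count is ceil_div(g - cp_rank, cp_world_size).
    local_cols = ceil_div(g - cp_rank, cp_world_size)
    # Round up to whole kBlockN tiles.
    return ceil_div(local_cols, kBlockN)

With cp_world_size == 1, cp_rank == 0, and tot_seqlen_k == seqlen_k, this reduces to the pre-CP bound ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN).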
5 changes: 5 additions & 0 deletions hopper/flash.h
@@ -161,6 +161,11 @@ struct Flash_fwd_params : public Qkv_params {

// The S extra matrix, (num_heads)
void *__restrict__ s_aux_ptr;

// CP (Context Parallelism) parameters
int cp_world_size;
int cp_rank;
int *__restrict__ cp_tot_seqused_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
21 changes: 19 additions & 2 deletions hopper/flash_api.cpp
@@ -707,7 +710,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin,
std::optional<const at::Tensor> &s_aux_, // (h)
int const cp_world_size, // context parallelism (CP) world size
int const cp_rank, // CP rank
std::optional<const at::Tensor> &cp_tot_seqused_k_ // (b), total seqused_k across the CP world
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -845,6 +848,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
CHECK_SHAPE(seqused_k, batch_size);
}
if (cp_tot_seqused_k_.has_value()) {
auto cp_tot_seqused_k = cp_tot_seqused_k_.value();
TORCH_CHECK(cp_tot_seqused_k.dtype() == torch::kInt32, "cp_tot_seqused_k must have dtype int32");
CHECK_DEVICE(cp_tot_seqused_k); CHECK_CONTIGUOUS(cp_tot_seqused_k);
CHECK_SHAPE(cp_tot_seqused_k, batch_size);
}

if (leftpad_k_.has_value()) {
auto leftpad_k = leftpad_k_.value();
@@ -1154,6 +1163,14 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
params.s_aux_ptr = nullptr;
}

params.cp_world_size = cp_world_size;
params.cp_rank = cp_rank;
params.cp_tot_seqused_k = cp_tot_seqused_k_.has_value() ?
static_cast<int *>(cp_tot_seqused_k_.value().data_ptr()) : nullptr;
TORCH_CHECK(cp_world_size > 0, "cp_world_size must be positive; the unified code path downstream requires it. Use 1 when CP is not enabled.");
TORCH_CHECK(cp_world_size != 1 || cp_rank == 0, "When context parallelism is disabled (cp_world_size == 1), cp_rank must be 0.");
TORCH_CHECK(cp_world_size == 1 || cp_tot_seqused_k_.has_value(), "cp_tot_seqused_k_ must be provided when context parallelism is enabled.");

#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
#endif
@@ -1670,4 +1687,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
}

#endif
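Taken together, the three new TORCH_CHECKs define the CP argument contract. A hedged Python mirror (validate_cp_args is a hypothetical name, not part of the API):

def validate_cp_args(cp_world_size, cp_rank, cp_tot_seqused_k):
    # Mirrors the TORCH_CHECKs above; CP disabled means cp_world_size == 1.
    assert cp_world_size > 0, "use cp_world_size=1 when CP is disabled"
    assert cp_world_size != 1 or cp_rank == 0, "cp_rank must be 0 when CP is disabled"
    assert cp_world_size == 1 or cp_tot_seqused_k is not None, \
        "cp_tot_seqused_k is required when CP is enabled"

validate_cp_args(1, 0, None)          # CP disabled: OK
validate_cp_args(4, 2, [8192, 8192])  # 4-way CP, rank 2, per-batch totals: OK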
12 changes: 9 additions & 3 deletions hopper/flash_api_torch_lib.cpp
@@ -52,7 +52,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
int num_splits,
std::optional<bool> pack_gqa_,
int const sm_margin,
std::optional<const at::Tensor> &s_aux_,
int const cp_world_size,
int const cp_rank,
std::optional<const at::Tensor> &cp_tot_seqused_k
);

// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -120,7 +123,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin,"
" Tensor? s_aux) -> Tensor[]");
" Tensor? s_aux,"
" int cp_world_size,"
" int cp_rank,"
" Tensor? cp_tot_seqused_k) -> Tensor[]");
ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

ops.def("get_scheduler_metadata("
@@ -151,4 +157,4 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME);
40 changes: 38 additions & 2 deletions hopper/flash_attn_interface.py
@@ -49,7 +49,10 @@ def _flash_attn_forward(
num_splits=1,
pack_gqa=None,
sm_margin=0,
s_aux=None,
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None):
q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -95,7 +98,10 @@ def _flash_attn_forward(
num_splits,
pack_gqa,
sm_margin,
s_aux,
cp_world_size,
cp_rank,
cp_tot_seqused_k,
)
return out, softmax_lse, *rest

@@ -260,6 +266,9 @@ def forward(
deterministic=False,
sm_margin=0,
s_aux=None,
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None,
):
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -285,6 +294,9 @@
pack_gqa=pack_gqa,
sm_margin=sm_margin,
s_aux=s_aux,
cp_world_size=cp_world_size,
cp_rank=cp_rank,
cp_tot_seqused_k=cp_tot_seqused_k,
)
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -351,6 +363,9 @@ def forward(
deterministic=False,
sm_margin=0,
s_aux=None,
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None,
):
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -380,6 +395,9 @@
pack_gqa=pack_gqa,
sm_margin=sm_margin,
s_aux=s_aux,
cp_world_size=cp_world_size,
cp_rank=cp_rank,
cp_tot_seqused_k=cp_tot_seqused_k,
)
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -497,6 +515,9 @@ def flash_attn_func(
deterministic=False,
sm_margin=0,
s_aux=None,
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -558,6 +579,9 @@
deterministic,
sm_margin,
s_aux,
cp_world_size,
cp_rank,
cp_tot_seqused_k,
)


@@ -582,6 +606,9 @@ def flash_attn_varlen_func(
deterministic=False,
sm_margin=0,
s_aux=None,
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None,
):
return FlashAttnVarlenFunc.apply(
q,
@@ -604,6 +631,9 @@
deterministic,
sm_margin,
s_aux,
cp_world_size,
cp_rank,
cp_tot_seqused_k,
)


@@ -642,6 +672,9 @@ def flash_attn_with_kvcache(
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
s_aux=None,
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -769,6 +802,9 @@ def flash_attn_with_kvcache(
pack_gqa=pack_gqa,
sm_margin=sm_margin,
s_aux=s_aux,
cp_world_size=cp_world_size,
cp_rank=cp_rank,
cp_tot_seqused_k=cp_tot_seqused_k,
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out
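An end-to-end usage sketch of the new keywords (the shapes, the striped sharding, and the import path are assumptions for illustration; combining the per-rank partial results across the CP group is left to the caller and is not part of this diff):

import torch
from flash_attn_interface import flash_attn_func  # import path may vary by install

# Assumed 4-way context parallelism over an 8192-token K/V, sharded round-robin
# so rank r holds global K/V positions r, r + 4, r + 8, ...
cp_world_size, cp_rank = 4, 1
batch, seqlen_q, nheads, headdim = 2, 1, 16, 128
tot_seqlen_k = 8192
local_seqlen_k = tot_seqlen_k // cp_world_size

q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, local_seqlen_k, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)
# Per-batch total K length across the whole CP world, int32 on device.
cp_tot_seqused_k = torch.full((batch,), tot_seqlen_k, device="cuda", dtype=torch.int32)

out = flash_attn_func(
    q, k, v,
    causal=True,
    cp_world_size=cp_world_size,
    cp_rank=cp_rank,
    cp_tot_seqused_k=cp_tot_seqused_k,
)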
9 changes: 7 additions & 2 deletions hopper/flash_fwd_kernel_sm90.h
@@ -347,7 +347,10 @@ class FlashAttnFwdSm90 {
get<0>(params.mainloop.shape_K_new),
params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
params.mainloop.seqlens_rotary,
params.mainloop.cp_world_size,
params.mainloop.cp_rank,
params.mainloop.cp_tot_seqused_k
};
if constexpr (AppendKV) {
bool tile_new_valid = mainloop.load_kv_new(
@@ -396,7 +399,9 @@ class FlashAttnFwdSm90 {
get<0>(params.mainloop.shape_K_new),
params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
params.mainloop.seqlens_rotary, params.mainloop.cp_world_size,
params.mainloop.cp_rank,
params.mainloop.cp_tot_seqused_k
};
if constexpr (AppendKV) {
bool tile_new_valid = mainloop.store_kv_new(
5 changes: 4 additions & 1 deletion hopper/flash_fwd_launch_template.h
@@ -129,7 +129,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
params.seqused_q, params.seqused_k,
params.leftpad_k, params.seqlens_rotary,
static_cast<ElementS const*>(params.s_aux_ptr),
params.cp_world_size, params.cp_rank, params.cp_tot_seqused_k
};
typename CollectiveEpilogue::Arguments epilogue_args {
static_cast<ElementOut*>(params.o_ptr),
@@ -156,6 +157,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
// params.num_m_blocks_ptr,
params.num_splits_dynamic_ptr,
params.cp_world_size,
params.cp_rank,
};

if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
3 changes: 3 additions & 0 deletions hopper/mainloop_fwd_sm80.hpp
@@ -215,6 +215,9 @@ struct CollectiveMainloopFwdSm80 {
int const* const leftpad_k = nullptr;
int const* const seqlens_rotary = nullptr;
ElementSAux const* const ptr_S_aux = nullptr;
int cp_world_size = 1;
int cp_rank = 0;
int const* const cp_tot_seqused_k = nullptr;
};

// Device side kernel params
27 changes: 23 additions & 4 deletions hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -412,6 +412,10 @@ struct CollectiveMainloopFwdSm90 {
int const* const leftpad_k = nullptr;
int const* const seqlens_rotary = nullptr;
ElementSAux const* const ptr_S_aux = nullptr;
// Context parallelism (CP) parameters
int const cp_world_size = 1;
int const cp_rank = 0;
int const* const cp_tot_seqused_k = nullptr;
};

// Device side kernel params
@@ -469,6 +473,9 @@ struct CollectiveMainloopFwdSm90 {
int const* const leftpad_k = nullptr;
int const* const seqlens_rotary = nullptr;
ElementSAux const* const ptr_S_aux = nullptr;
int cp_world_size = 1;
int cp_rank = 0;
int const* const cp_tot_seqused_k = nullptr;
};

static Params
@@ -584,7 +591,8 @@ struct CollectiveMainloopFwdSm90 {
args.kv_batch_idx,
args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
args.ptr_S_aux,
args.cp_world_size, args.cp_rank, args.cp_tot_seqused_k};
}

/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -1093,7 +1101,8 @@ struct CollectiveMainloopFwdSm90 {
// But we subtract n_offset for consistency in mask calculations
flash::Mask<kBlockM, kBlockN, PackGQA, TiledMmaQK> mask(
thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/,
params.qhead_per_khead_divmod,
params.cp_world_size, params.cp_rank, seqlen_info.tot_seqlen_k
);

float softcap_val = params.softcap_val;
@@ -1275,8 +1284,13 @@ struct CollectiveMainloopFwdSm90 {
auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
// If local, blocking (window_size_right + window_size_left)
// When CP is not enabled, tot_seqlen_k equals seqlen_k and cp_world_size is 1.
// cp_world_size is guaranteed to be greater than 0.
int const n_block_min_causal_local_mask =
std::max(n_block_min,
(m_idx_min + seqlen_info.tot_seqlen_k - seqlen_q + params.window_size_right) /
seqlen_info.cp_world_size /
kBlockN);
#pragma unroll 1
for (; n_block >= n_block_min_causal_local_mask; --n_block) {
fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/);
@@ -1285,10 +1299,15 @@

int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
// If local, blocking (m_idx_max - m_idx_min)
// When CP is not enabled, tot_seqlen_k equals seqlen_k and cp_world_size is 1.
// cp_world_size is guaranteed to be greater than 0.
int const n_block_min_before_local_mask = !Is_local
? n_block_min
: std::max(n_block_min,
cute::ceil_div(
cute::ceil_div(m_idx_max + seqlen_info.tot_seqlen_k - seqlen_q - params.window_size_left - seqlen_info.cp_rank,
seqlen_info.cp_world_size),
kBlockN));
auto no_mask_fn = [](auto& tSrS, int n_block) { };
#pragma unroll 1
for (; n_block >= n_block_min_before_local_mask; --n_block) {
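Both local-window bounds above follow the same global-to-local translation as the mask. A minimal sketch of the arithmetic (Python; helper names with a _cp suffix are hypothetical, and note that C++ integer division truncates toward zero while Python // floors, so this sketch assumes non-negative intermediate values):

def ceil_div(a, b):
    return -(-a // b)

def n_block_min_causal_local_mask_cp(m_idx_min, seqlen_q, tot_seqlen_k,
                                     window_size_right, kBlockN,
                                     cp_world_size, n_block_min=0):
    # Plain (floor) division is already conservative for a lower bound,
    # so cp_rank is not subtracted here.
    g = m_idx_min + tot_seqlen_k - seqlen_q + window_size_right
    return max(n_block_min, g // cp_world_size // kBlockN)

def n_block_min_before_local_mask_cp(m_idx_max, seqlen_q, tot_seqlen_k,
                                     window_size_left, kBlockN,
                                     cp_world_size, cp_rank, n_block_min=0):
    # ceil_div keeps this bound tight after subtracting this rank's offset.
    g = m_idx_max + tot_seqlen_k - seqlen_q - window_size_left - cp_rank
    return max(n_block_min, ceil_div(ceil_div(g, cp_world_size), kBlockN))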
21 changes: 19 additions & 2 deletions hopper/mask.h
@@ -23,18 +23,23 @@ struct Mask {
int const seqlen_q, seqlen_k;
int const window_size_left, window_size_right, sink_token_length;
cutlass::FastDivmod const qhead_per_khead_divmod;
int const cp_world_size, cp_rank, tot_seqlen_k;

CUTLASS_DEVICE
Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
const int window_size_left, const int window_size_right, const int sink_token_length,
cutlass::FastDivmod const &qhead_per_khead_divmod,
const int cp_world_size = 1, const int cp_rank = 0, const int tot_seqlen_k = 0)
: thread_idx(thread_idx)
, seqlen_q(seqlen_q)
, seqlen_k(seqlen_k)
, window_size_left(window_size_left)
, window_size_right(window_size_right)
, sink_token_length(sink_token_length)
, qhead_per_khead_divmod(qhead_per_khead_divmod)
, cp_world_size(cp_world_size)
, cp_rank(cp_rank)
, tot_seqlen_k(tot_seqlen_k)
{
};

@@ -94,7 +99,19 @@ struct Mask {
: __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit);
#pragma unroll
for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {
int col_idx = int(get<Col>(t0ScS_rowcol(_0{}, n)));
if (cp_world_size > 1) {
int local_k_idx = col_idx + get<Col>(tScS_rowcol(_0{}, _0{})) + n_block * kBlockN;
int abs_k_idx = local_k_idx * cp_world_size + cp_rank;
int k_limit = row_idx + tot_seqlen_k - seqlen_q;
if (abs_k_idx > k_limit || (Seqlenk_mask && abs_k_idx >= tot_seqlen_k)) {
tSrS_rowcol(m, n) = -INFINITY;
}
} else {
if (col_idx >= col_limit_right) {
tSrS_rowcol(m, n) = -INFINITY;
}
}
}
}
} else {
Expand Down
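For intuition, the CP branch above reduces to this per-element predicate (a Python sketch under the same striped-layout assumption; cp_masked is a hypothetical name):

def cp_masked(local_k_idx, row_idx, seqlen_q, tot_seqlen_k,
              cp_world_size, cp_rank, seqlenk_mask=True):
    # Recover the global K position of this rank's local column.
    abs_k_idx = local_k_idx * cp_world_size + cp_rank
    # Causal limit in global coordinates: query row row_idx is aligned to the
    # end of the global K sequence, as in the non-CP path.
    k_limit = row_idx + tot_seqlen_k - seqlen_q
    return abs_k_idx > k_limit or (seqlenk_mask and abs_k_idx >= tot_seqlen_k)

With cp_world_size == 1 and cp_rank == 0 this is the usual causal test col_idx > row_idx + seqlen_k - seqlen_q plus the seqlen-k bound.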