Skip to content

Commit bd4bfcf

Browse files
committed
update vllm interface
Signed-off-by: Ming Yang <[email protected]>
1 parent dc64529 commit bd4bfcf

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

hopper/flash_api_torch_lib.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
5454
int const sm_margin,
5555
std::optional<const at::Tensor> &s_aux_,
5656
int const cp_world_size,
57 -     int const cp_rank,
57 +     int const cp_rank
5858
);
5959

6060
// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -124,7 +124,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
124124
" int sm_margin,"
125125
" Tensor? s_aux,"
126126
" int cp_world_size,"
127 -     " int cp_rank") -> Tensor[]");
127 +     " int cp_rank) -> Tensor[]");
128128
ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
129129

130130
ops.def("get_scheduler_metadata("

vllm_flash_attn/flash_attn_interface.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ def flash_attn_varlen_func(
146146
# Version selector
147147
fa_version: int = DEFAULT_FA_VERSION,
148148
s_aux=None,
149 +     cp_world_size=1,
150 +     cp_rank=0,
149151
):
150152
"""dropout_p should be set to 0.0 during evaluation
151153
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -279,7 +281,9 @@ def flash_attn_varlen_func(
279281
num_splits,
280282
None, # pack_gqa
281283
0, # sm_margin
282 -     s_aux # s_aux
284 +     s_aux, # s_aux
285 +     cp_world_size,
286 +     cp_rank,
283287
)
284288
else:
285289
raise ValueError(f"Unsupported FA version: {fa_version}")
@@ -316,6 +320,8 @@ def flash_attn_with_kvcache(
316320
# Version selector
317321
fa_version: int = DEFAULT_FA_VERSION,
318322
s_aux=None,
323 +     cp_world_size=1,
324 +     cp_rank=0,
319325
):
320326
"""
321327
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from

0 commit comments

Comments (0)