Commit 3bf3ae4

modify the kwargs inside _flash_attention_forward

1 parent ab4b7c1 · commit 3bf3ae4
File tree: 2 files changed, +3 −3 lines

src/transformers/integrations/flash_attention.py (1 addition, 1 deletion)

@@ -76,7 +76,7 @@ def flash_attention_forward(
         softcap=softcap,
         use_top_left_mask=_use_top_left_mask,
         target_dtype=target_dtype,
-        implementation=module.config._attn_implementation,
+        attn_implementation=module.config._attn_implementation,
         layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None,
         **kwargs,
     )

src/transformers/modeling_flash_attention_utils.py (2 additions, 2 deletions)

@@ -545,7 +545,7 @@ def _flash_attention_forward(
     max_length_q: Optional[int] = None,
     max_length_k: Optional[int] = None,
     target_dtype: Optional[torch.dtype] = None,
-    implementation: Optional[str] = None,
+    attn_implementation: Optional[str] = None,
     **kwargs,
 ):
     """
@@ -568,7 +568,7 @@ def _flash_attention_forward(
         The attention implementation to use. If None, will default to the one based on the environment.
     """
     (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
-        implementation
+        attn_implementation
     )

     # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
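For callers that invoke _flash_attention_forward directly, the effect of this commit is that the implementation string is now passed as attn_implementation instead of implementation, and it is forwarded unchanged to lazy_import_flash_attention. The sketch below is illustrative only: it assumes flash-attn and a CUDA device are available, and the leading query/key/value, attention-mask, query-length, and is_causal arguments come from the library's public signature rather than from this diff, so the shapes and the "flash_attention_2" value are assumptions, not part of the change.

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

# Hypothetical shapes (batch, seq_len, num_heads, head_dim); fp16 on CUDA, as flash attention requires.
batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Before this commit the keyword was `implementation`; after it, the same string
# (e.g. module.config._attn_implementation at the integration call site) is passed
# as `attn_implementation`.
out = _flash_attention_forward(
    q,
    k,
    v,
    attention_mask=None,
    query_length=seq_len,
    is_causal=True,
    attn_implementation="flash_attention_2",  # renamed keyword; value assumed for illustration
)
print(out.shape)  # expected (batch, seq_len, num_heads, head_dim)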
