diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 5312b0dd9cd0..1b64c657333b 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -545,7 +545,7 @@ def _flash_attention_forward(
     max_length_q: Optional[int] = None,
     max_length_k: Optional[int] = None,
     target_dtype: Optional[torch.dtype] = None,
-    implementation: Optional[str] = None,
+    attn_implementation: Optional[str] = None,
     **kwargs,
 ):
     """
@@ -564,11 +564,11 @@ def _flash_attention_forward(
         attention_mask (`torch.Tensor`, *optional*):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
-        implementation (`str`, *optional*):
+        attn_implementation (`str`, *optional*):
            The attention implementation to use. If None, will default to the one based on the environment.
     """
     (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
-        implementation
+        attn_implementation
     )
 
     # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
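
For downstream callers, the only change is the keyword name: code that previously passed `implementation=...` to `_flash_attention_forward` now passes `attn_implementation=...`. Below is a minimal call-site sketch; the tensor shapes, dtypes, the `"flash_attention_2"` value, and the other keyword arguments are illustrative assumptions based on the function's existing interface in `modeling_flash_attention_utils.py`, not part of this diff.

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

# Illustrative inputs in the (batch, seq_len, num_heads, head_dim) layout the flash path expects.
batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

out = _flash_attention_forward(
    q, k, v,
    attention_mask=None,   # no padding in this sketch
    query_length=seq_len,
    is_causal=True,
    attn_implementation="flash_attention_2",  # previously: implementation="flash_attention_2"
)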