
Commit 4d806db

Fix bug of _prepare_4d_attention_mask (#27847)
* use _prepare_4d_attention_mask
* fix comment
1 parent 75336c1 commit 4d806db

1 file changed: +7 −3 lines changed

src/transformers/models/llama/modeling_llama.py

Lines changed: 7 additions & 3 deletions
@@ -29,7 +29,11 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_attention_mask,
+    _prepare_4d_causal_attention_mask,
+)
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13

@@ -78,9 +82,9 @@ def _get_unpad_data(attention_mask):

 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     warnings.warn(
-        "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask"
+        "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
     )
-    return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+    return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


 def _make_causal_mask(

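The fix itself is small: the deprecated `_expand_mask` wrapper had been delegating to `AttentionMaskConverter._prepare_4d_attention_mask`, while the helper exposed by `transformers.modeling_attn_mask_utils` is the module-level function `_prepare_4d_attention_mask`; the commit imports that function and points both the return statement and the deprecation message at it. For readers unfamiliar with what this kind of helper produces, below is a minimal, self-contained sketch of the transformation: expanding a 2D padding mask into the 4D additive bias applied to attention scores. The function name `expand_padding_mask` and the implementation details are illustrative assumptions, not the library's code.

import torch
from typing import Optional


def expand_padding_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    # Illustrative sketch (not the library's implementation): expand a 2D padding mask
    # of shape [bsz, src_len] (1 = attend, 0 = padding) into a 4D additive mask of shape
    # [bsz, 1, tgt_len, src_len] with 0 at allowed positions and dtype-min at padded ones.
    bsz, src_len = mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len

    # Broadcast the key-side mask across every query position and cast to the compute dtype.
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # Invert so 0 marks "keep" and 1 marks "block".
    inverted = 1.0 - expanded

    # Blocked positions get the most negative representable value, so softmax drives them to ~0.
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


# Usage: a batch of two length-4 sequences, the second padded in its last two positions.
pad_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
additive = expand_padding_mask(pad_mask, torch.float32)
print(additive.shape)     # torch.Size([2, 1, 4, 4])
print(additive[1, 0, 0])  # tensor([0., 0., -3.4028e+38, -3.4028e+38]) (approximately)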