@@ -29,7 +29,11 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+    _prepare_4d_attention_mask,
+    _prepare_4d_causal_attention_mask,
+)
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
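The first hunk turns the single-line import into a parenthesized multi-line import and pulls in the module-level `_prepare_4d_attention_mask` helper alongside the existing `AttentionMaskConverter` and `_prepare_4d_causal_attention_mask`. For orientation, here is a minimal, self-contained sketch (my own reconstruction, not the library source) of the 2D-to-4D expansion that helper performs: a padding mask of shape `(batch, src_len)` becomes an additive mask of shape `(batch, 1, tgt_len, src_len)` whose masked positions hold the dtype minimum.

from typing import Optional

import torch


def expand_2d_mask_to_4d(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    # mask: (batch, src_len) with 1 = attend, 0 = padding.
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    # Broadcast to (batch, 1, tgt_len, src_len) and cast to the compute dtype.
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # Invert so attended positions contribute 0.0 and padded positions the dtype
    # minimum; the result can be added to attention scores before the softmax.
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


# e.g. expand_2d_mask_to_4d(torch.tensor([[1, 1, 0]]), torch.float32).shape
# -> torch.Size([1, 1, 3, 3])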
@@ -78,9 +82,9 @@ def _get_unpad_data(attention_mask):
 
 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     warnings.warn(
-        "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask"
+        "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
     )
-    return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+    return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
 
 
 def _make_causal_mask(
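The second hunk keeps `_expand_mask` as a thin backward-compatibility shim: it still warns, but now delegates to the module-level `_prepare_4d_attention_mask` rather than the `AttentionMaskConverter` class method. A generic sketch of the pattern (the names here are illustrative, not taken from the diff):

import warnings


def _new_helper(x):
    # Stand-in for the relocated implementation.
    return x


def _old_helper(x):
    # Deprecated alias: keep the old import path working for a release or two,
    # point callers at the replacement, then delete the shim.
    warnings.warn(
        "`_old_helper` is deprecated and will be removed in a future release; "
        "use `_new_helper` instead.",
        FutureWarning,
    )
    return _new_helper(x)


Delegating instead of duplicating keeps a single implementation authoritative, so behavior cannot drift between the old and new entry points before the shim is removed.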