161 changes: 81 additions & 80 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -24,6 +24,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutput,
@@ -249,6 +250,7 @@ def __init__(
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
layer_idx: Optional[int] = None,
config: Optional[WhisperConfig] = None,
):
super().__init__()
@@ -267,6 +269,14 @@ def __init__(
self.is_decoder = is_decoder
self.is_causal = is_causal

if layer_idx is None and is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.layer_idx = layer_idx

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
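A note for readers of this diff: `DynamicCache` keeps one key/value entry per layer in parallel `key_cache`/`value_cache` lists, which is why every decoder attention now needs its `layer_idx`. A minimal sketch, with illustrative shapes:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
for layer_idx in range(2):
    # (batch, num_heads, seq_len, head_dim) -- shapes are illustrative
    k = torch.randn(1, 8, 4, 64)
    v = torch.randn(1, 8, 4, 64)
    cache.update(k, v, layer_idx)

print(len(cache.key_cache))  # 2 -- cache.key_cache[i] belongs to layer i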
@@ -281,7 +291,7 @@ def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
@@ -302,37 +312,27 @@ def forward(
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
and len(past_key_value.key_cache) >= self.layer_idx + 1
and past_key_value.key_cache[self.layer_idx].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
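A hedged sketch of the `Cache.update` contract the branches above rely on: the first call for a layer stores and returns the projected states unchanged, while later calls append along the sequence dimension. Shapes are assumed:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k0, v0 = torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64)
k_all, v_all = cache.update(k0, v0, layer_idx=0)  # prefill: stored and returned as-is

k1, v1 = torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64)  # one new decoded token
k_all, v_all = cache.update(k1, v1, layer_idx=0)  # concatenated along the sequence dim
print(k_all.shape)  # torch.Size([1, 8, 5, 64])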
@@ -415,14 +415,11 @@ def __init__(self, *args, **kwargs):
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
@@ -445,40 +442,30 @@ def forward(
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
and len(past_key_value.key_cache) >= self.layer_idx + 1
and past_key_value.key_cache[self.layer_idx].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# In PEFT, we usually cast the layer norms to float32 for training stability,
# so the input hidden states get silently cast to float32. Hence, we need
@@ -624,7 +611,7 @@ def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
@@ -659,37 +646,27 @@ def forward(
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
and len(past_key_value.key_cache) >= self.layer_idx + 1
and past_key_value.key_cache[self.layer_idx].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

query_states = self._shape(query_states, tgt_len, bsz)

# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
@@ -801,7 +778,7 @@ def forward(

# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER
Comment from the PR author:
I'll propagate changes to all other MBart derived modules when we're happy with the design

class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
def __init__(self, config: WhisperConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model

@@ -811,6 +788,7 @@ def __init__(self, config: WhisperConfig):
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
layer_idx=layer_idx,
config=config,
)
self.dropout = config.dropout
@@ -823,6 +801,7 @@ def __init__(self, config: WhisperConfig):
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
layer_idx=layer_idx,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@@ -864,9 +843,9 @@ def forward(
hidden_states = self.self_attn_layer_norm(hidden_states)

# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
# decoder uni-directional self-attention cached key/value states are at position 0
self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
Comment from the PR author:
The difficulty here comes from the fact that we're dealing with two sets of past key-values per decoder layer: one from the self-attention and one from the cross-attention. The current solution uses a separate cache for each.
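A small sketch of the layout this comment describes, using illustrative names rather than the final API:

from transformers.cache_utils import DynamicCache

# One pair of caches shared by all decoder layers:
past_key_values = (DynamicCache(), DynamicCache())

# Inside each decoder layer: slot 0 grows by one position per generated token,
# slot 1 is filled once from the encoder output and afterwards only re-read.
self_attn_past_key_value = past_key_values[0]
cross_attn_past_key_value = past_key_values[1]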

# add present self-attn cache to position 0 of the present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
@@ -884,8 +863,8 @@ def forward(
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)

# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
# cross_attn cached key/value states are at position 1 of the present_key_value tuple
cross_attn_past_key_value = past_key_value[1] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
@@ -897,8 +876,8 @@ def forward(
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states

# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# add cross-attn cache to position 1 of the present_key_value tuple
present_key_value = (present_key_value, cross_attn_present_key_value)

# Fully Connected
residual = hidden_states
@@ -928,6 +907,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True

def _init_weights(self, module):
std = self.config.init_std
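As a hedged aside, this flag is what the generation utilities consult before handing the model `Cache` objects instead of legacy tuples; the lookup below assumes it resolves through `WhisperPreTrainedModel`:

from transformers import WhisperForConditionalGeneration

assert WhisperForConditionalGeneration._supports_cache_class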
@@ -1257,7 +1237,9 @@ def __init__(self, config: WhisperConfig):
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)

self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layers = nn.ModuleList(
[WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"

@@ -1365,7 +1347,20 @@ def forward(
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
past_key_values_length = 0
if use_cache:
use_legacy_cache = not (past_key_values is not None and isinstance(past_key_values[0], Cache))
if use_legacy_cache:
if past_key_values is None:
self_attn = cross_attn = None
else:
self_attn = [key_values[:2] for key_values in past_key_values]
cross_attn = [key_values[2:] for key_values in past_key_values]
past_key_values = (
DynamicCache.from_legacy_cache(self_attn),
DynamicCache.from_legacy_cache(cross_attn),
)
past_key_values_length = past_key_values[0].get_seq_length()
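To make the conversion above concrete, here is a sketch of the legacy per-layer 4-tuples and their split into the two caches; layer count, sequence length, and shapes are illustrative:

import torch
from transformers.cache_utils import DynamicCache

num_layers, past_len = 2, 3
legacy = tuple(
    tuple(torch.randn(1, 8, past_len, 64) for _ in range(4))  # (self_k, self_v, cross_k, cross_v)
    for _ in range(num_layers)
)
self_attn = [kv[:2] for kv in legacy]
cross_attn = [kv[2:] for kv in legacy]
past_key_values = (
    DynamicCache.from_legacy_cache(self_attn),
    DynamicCache.from_legacy_cache(cross_attn),
)
print(past_key_values[0].get_seq_length())  # 3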

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -1407,7 +1402,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None

# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
@@ -1425,8 +1420,6 @@ def forward(
if dropout_probability < self.layerdrop:
continue

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
@@ -1449,14 +1442,14 @@ def forward(
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
next_decoder_cache = layer_outputs[3 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -1469,7 +1462,13 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = next_decoder_cache if use_cache else None
next_cache = None
if use_cache and use_legacy_cache:
next_cache = ()
for self_attn, cross_attn in zip(
next_decoder_cache[0].to_legacy_cache(), next_decoder_cache[1].to_legacy_cache()
):
next_cache += (self_attn + cross_attn,)
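And the inverse direction, matching the round-trip block above: each layer's self- and cross-attention entries are re-fused into a single legacy tuple (same assumptions as the previous sketch):

self_cache, cross_cache = past_key_values
next_cache = ()
for self_kv, cross_kv in zip(self_cache.to_legacy_cache(), cross_cache.to_legacy_cache()):
    next_cache += (self_kv + cross_kv,)  # back to (self_k, self_v, cross_k, cross_v)
print(len(next_cache), len(next_cache[0]))  # 2 4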
if not return_dict:
return tuple(
v
@@ -1806,9 +1805,11 @@ def prepare_inputs_for_generation(
decoder_position_ids = None
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)

if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
if isinstance(past_key_values[0], Cache):
past_length = past_key_values[0].get_seq_length()
else:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
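For reference, the branch this hunk truncates typically continues as below, following the usual transformers convention; this is a hedged sketch (restating the condition above), not the literal PR lines:

if decoder_input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    # some generation methods already pass only the last input ID
    remove_prefix_length = decoder_input_ids.shape[1] - 1
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]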