[Whisper] Use Attention Cache #28931
Closed
Changes from all commits · 4 commits
@@ -24,6 +24,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
    BaseModelOutput,

@@ -249,6 +250,7 @@ def __init__(
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        layer_idx: Optional[int] = None,
        config: Optional[WhisperConfig] = None,
    ):
        super().__init__()

@@ -267,6 +269,14 @@ def __init__(
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        if layer_idx is None and is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.layer_idx = layer_idx

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

@@ -281,7 +291,7 @@ def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        past_key_value: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,

@@ -302,37 +312,27 @@ def forward(
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
            and len(past_key_value.key_cache) >= self.layer_idx + 1
            and past_key_value.key_cache[self.layer_idx].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
            key_states = past_key_value.key_cache[self.layer_idx]
            value_states = past_key_value.value_cache[self.layer_idx]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)

@@ -415,14 +415,11 @@ def __init__(self, *args, **kwargs):
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        past_key_value: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,

@@ -445,40 +442,30 @@ def forward(
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
            and len(past_key_value.key_cache) >= self.layer_idx + 1
            and past_key_value.key_cache[self.layer_idx].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0].transpose(1, 2)
            value_states = past_key_value[1].transpose(1, 2)
            key_states = past_key_value.key_cache[self.layer_idx]
            value_states = past_key_value.value_cache[self.layer_idx]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        else:
            # self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need

@@ -624,7 +611,7 @@ def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        past_key_value: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,

@@ -659,37 +646,27 @@ def forward(
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
            and len(past_key_value.key_cache) >= self.layer_idx + 1
            and past_key_value.key_cache[self.layer_idx].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
            key_states = past_key_value.key_cache[self.layer_idx]
            value_states = past_key_value.value_cache[self.layer_idx]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz)

        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,

@@ -801,7 +778,7 @@ def forward(

# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER
class WhisperDecoderLayer(nn.Module):
    def __init__(self, config: WhisperConfig):
    def __init__(self, config: WhisperConfig, layer_idx: int = None):
        super().__init__()
        self.embed_dim = config.d_model

@@ -811,6 +788,7 @@ def __init__(self, config: WhisperConfig):
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            layer_idx=layer_idx,
            config=config,
        )
        self.dropout = config.dropout

@@ -823,6 +801,7 @@ def __init__(self, config: WhisperConfig):
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            layer_idx=layer_idx,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

@@ -864,9 +843,9 @@ def forward(
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        # decoder uni-directional self-attention cached key/values states are at position 0
        self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
Contributor (Author)
The difficulty comes here from the fact that we're dealing with two sets of past key-values per decoder layer: one from the self-attention, and one from the cross-attention. The current solution uses a separate cache for each.
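To make that layout concrete, here is a minimal runnable sketch of how the two caches behave. Only `DynamicCache` and its `update` / `key_cache` API come from the diff above; the tensor shapes and variable names are invented for illustration.

```python
# Illustration only (not part of the diff): one DynamicCache per attention type.
import torch
from transformers.cache_utils import DynamicCache

self_attn_cache = DynamicCache()   # grows by one time step per generated token
cross_attn_cache = DynamicCache()  # written once per layer, then only re-read
past_key_values = (self_attn_cache, cross_attn_cache)

layer_idx = 0  # index threaded down from the decoder to the attention module

# Self-attention path: append this step's key/value states
# (toy shape: bsz=1, heads=2, seq=1, head_dim=4).
k_step, v_step = torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4)
keys, values = past_key_values[0].update(k_step, v_step, layer_idx)
print(keys.shape)  # torch.Size([1, 2, 1, 4]); dim 2 grows on later decoding steps

# Cross-attention path: store the encoder-derived states once per layer...
k_enc, v_enc = torch.randn(1, 2, 1500, 4), torch.randn(1, 2, 1500, 4)
past_key_values[1].update(k_enc, v_enc, layer_idx)

# ...and reuse them on every later step, as the first branch of the attention
# forward does via key_cache / value_cache.
reused_keys = past_key_values[1].key_cache[layer_idx]
```

Keeping the cross-attention states in their own cache means the encoder output is projected only once per layer, while the self-attention cache keeps growing during decoding.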
        # add present self-attn cache to position 0 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,

@@ -884,8 +863,8 @@ def forward(
        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
        # cross_attn cached key/values tuple is at position 1 of present_key_value tuple
        cross_attn_past_key_value = past_key_value[1] if past_key_value is not None else None
        hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
            hidden_states=hidden_states,
            key_value_states=encoder_hidden_states,

@@ -897,8 +876,8 @@ def forward(
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # add cross-attn to positions 3,4 of present_key_value tuple
        present_key_value = present_key_value + cross_attn_present_key_value
        # add cross-attn to position 1 of present_key_value tuple
        present_key_value = (present_key_value, cross_attn_present_key_value)

        # Fully Connected
        residual = hidden_states

@@ -928,6 +907,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
    _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.init_std

@@ -1257,7 +1237,9 @@ def __init__(self, config: WhisperConfig):
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
        self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)

        self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.layers = nn.ModuleList(
            [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

@@ -1365,7 +1347,20 @@ def forward(
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        past_key_values_length = 0
        if use_cache:
            use_legacy_cache = not (past_key_values is not None and isinstance(past_key_values[0], Cache))
            if use_legacy_cache:
                if past_key_values is None:
                    self_attn = cross_attn = None
                else:
                    self_attn = [key_values[:2] for key_values in past_key_values]
                    cross_attn = [key_values[2:] for key_values in past_key_values]
                past_key_values = (
                    DynamicCache.from_legacy_cache(self_attn),
                    DynamicCache.from_legacy_cache(cross_attn),
                )
            past_key_values_length = past_key_values[0].get_seq_length()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

@@ -1407,7 +1402,7 @@ def forward(
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None
        next_decoder_cache = None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):

@@ -1425,8 +1420,6 @@ def forward(
            if dropout_probability < self.layerdrop:
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,

@@ -1449,14 +1442,14 @@ def forward(
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
                next_decoder_cache = layer_outputs[3 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

@@ -1469,7 +1462,13 @@ def forward(
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        next_cache = None
        if use_cache and use_legacy_cache:
            next_cache = ()
            for self_attn, cross_attn in zip(
                next_decoder_cache[0].to_legacy_cache(), next_decoder_cache[1].to_legacy_cache()
            ):
                next_cache += (self_attn + cross_attn,)
        if not return_dict:
            return tuple(
                v

@@ -1806,9 +1805,11 @@ def prepare_inputs_for_generation(
        decoder_position_ids = None
        if decoder_attention_mask is not None:
            decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)

        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            if isinstance(past_key_values[0], Cache):
                past_length = past_key_values[0].get_seq_length()
            else:
                past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
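As an aside for reviewers, the legacy ↔ `Cache` round trip that `WhisperDecoder.forward` performs above can be reproduced in isolation. The snippet below is a self-contained sketch with toy tensors (layer count, sequence lengths, and shapes are invented), not code taken from the PR.

```python
# Sketch of the legacy <-> Cache conversion done in WhisperDecoder.forward.
# Each legacy per-layer entry is a 4-tuple:
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value).
import torch
from transformers.cache_utils import DynamicCache

num_layers, bsz, heads, head_dim = 2, 1, 2, 4
legacy = tuple(
    (
        torch.randn(bsz, heads, 3, head_dim),     # self-attn keys (3 tokens decoded so far)
        torch.randn(bsz, heads, 3, head_dim),     # self-attn values
        torch.randn(bsz, heads, 1500, head_dim),  # cross-attn keys (encoder length)
        torch.randn(bsz, heads, 1500, head_dim),  # cross-attn values
    )
    for _ in range(num_layers)
)

# Split into two DynamicCache objects, mirroring the from_legacy_cache calls above.
self_attn = DynamicCache.from_legacy_cache([kv[:2] for kv in legacy])
cross_attn = DynamicCache.from_legacy_cache([kv[2:] for kv in legacy])
past_key_values = (self_attn, cross_attn)
assert past_key_values[0].get_seq_length() == 3  # past_key_values_length

# And back: re-interleave into the legacy 4-tuple-per-layer layout for the output.
next_cache = tuple(
    s + c for s, c in zip(self_attn.to_legacy_cache(), cross_attn.to_legacy_cache())
)
assert len(next_cache[0]) == 4
```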
I'll propagate changes to all other MBart-derived modules when we're happy with the design.
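For reference, the `layer_idx` plumbing that would need to be mirrored in those models reduces to the pattern below. The `Sketch*` classes are placeholders rather than the real Whisper/MBart modules, with everything except the index handling stripped out.

```python
# Stripped-down sketch of the layer_idx plumbing added by this PR; class names
# are hypothetical and the bodies omit all the actual attention math.
from typing import Optional

import torch.nn as nn


class SketchAttention(nn.Module):
    """Stand-in for WhisperAttention: only the cache-related plumbing is shown."""

    def __init__(self, is_decoder: bool = False, layer_idx: Optional[int] = None):
        super().__init__()
        # A decoder attention module needs its own index so it can call
        # cache.update(key_states, value_states, self.layer_idx).
        self.is_decoder = is_decoder
        self.layer_idx = layer_idx


class SketchDecoderLayer(nn.Module):
    def __init__(self, layer_idx: Optional[int] = None):
        super().__init__()
        # The same index goes to both attentions, because each one talks to
        # its own DynamicCache (self-attention vs. cross-attention).
        self.self_attn = SketchAttention(is_decoder=True, layer_idx=layer_idx)
        self.encoder_attn = SketchAttention(is_decoder=True, layer_idx=layer_idx)


class SketchDecoder(nn.Module):
    def __init__(self, num_layers: int):
        super().__init__()
        # Enumerate the layers so every layer knows its position in the stack.
        self.layers = nn.ModuleList(
            [SketchDecoderLayer(layer_idx) for layer_idx in range(num_layers)]
        )
```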