From fbc715a5c7fc125c568b4e53f11c75cca15c62f0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 16:55:47 +0000 Subject: [PATCH 01/25] [FA2] Bart-like models --- src/transformers/models/bart/modeling_bart.py | 243 +++++++++++++++++- .../models/llama/modeling_llama.py | 2 +- 2 files changed, 238 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 52dfa5e39229..78a0d5fb07d4 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +import torch.nn.functional as F from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutput, @@ -39,17 +40,34 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_available, logging, replace_return_docstrings, ) from .configuration_bart import BartConfig +if is_flash_attn_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "facebook/bart-base" _CONFIG_FOR_DOC = "BartConfig" + +def _get_unpad_data(padding_mask): + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + # Base model docstring _EXPECTED_OUTPUT_SHAPE = [1, 8, 768] @@ -138,7 +156,6 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): return super().forward(positions + self.offset) - class BartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -292,16 +309,215 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +class BartFlashAttention2(BartAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # BartFlashAttention2 attention does not support output_attentions + output_attentions = False + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # TODO: Bart does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." + ) + + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + padding_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if padding_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, padding_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + padding_mask = padding_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BartAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - ) + + if getattr(config, "_flash_attn_2_enabled", False): + self.self_attn = BartFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + else: + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -372,6 +588,20 @@ def __init__(self, config: BartConfig): dropout=config.attention_dropout, is_decoder=True, ) + if getattr(config, "_flash_attn_2_enabled", False): + self.self_attn = BartFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + else: + self.self_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -509,6 +739,7 @@ class BartPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 55753d5f75d9..aef4fdab9061 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -437,7 +437,7 @@ def forward( value_states = self.v_proj(hidden_states) # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dime x hidden_dim + # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) From ce43f2a5cf92b4c23a9a970cccb2249474f25e6b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 17:42:12 +0000 Subject: [PATCH 02/25] make tests work --- .../jax-projects/big_bird/bigbird_flax.py | 2 +- .../jax-projects/big_bird/train.py | 2 +- .../vqgan-clip/VQGAN_CLIP.py | 2 +- src/transformers/models/bart/modeling_bart.py | 13 +++- tests/test_modeling_common.py | 73 +++++++++++++++---- 5 files changed, 72 insertions(+), 20 deletions(-) diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index af5e11c83a6a..c171b88800ed 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -9,13 +9,13 @@ import jax.numpy as jnp import joblib import optax -import wandb from flax import jax_utils, struct, traverse_util from flax.serialization import from_bytes, to_bytes from flax.training import train_state from flax.training.common_utils import shard from tqdm.auto import tqdm +import wandb from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule diff --git a/examples/research_projects/jax-projects/big_bird/train.py b/examples/research_projects/jax-projects/big_bird/train.py index ce37b7f975bb..3840918d16ae 100644 --- a/examples/research_projects/jax-projects/big_bird/train.py +++ b/examples/research_projects/jax-projects/big_bird/train.py @@ -2,11 +2,11 @@ from dataclasses import replace import jax -import wandb from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step from datasets import load_dataset from flax import jax_utils +import wandb from transformers import BigBirdTokenizerFast diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py index 1bfbc4cd5c36..2a39955e347f 100644 --- a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py +++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py @@ -4,12 +4,12 @@ import imageio import torch import torchvision -import wandb from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan from loaders import load_vqgan from PIL import Image from torch import nn +import wandb from transformers import CLIPModel, CLIPTokenizerFast from utils import get_device, get_timestamp, show_pil diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 78a0d5fb07d4..6678cddd62c0 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -19,11 +19,11 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -import torch.nn.functional as F from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutput, @@ -68,6 +68,7 @@ def _get_unpad_data(padding_mask): max_seqlen_in_batch, ) + # Base model docstring _EXPECTED_OUTPUT_SHAPE = [1, 8, 768] @@ -156,6 +157,7 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): return super().forward(positions + self.offset) + class BartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -309,8 +311,10 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value + class BartFlashAttention2(BartAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -334,7 +338,7 @@ def forward( bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self._shape(self.q_proj(hidden_states), -1, bsz) * self.scaling # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -402,8 +406,8 @@ def forward( query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None @@ -501,6 +505,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5a239cf0fb3b..9af1355f4291 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2797,16 +2797,35 @@ def test_flash_attn_2_inference(self): dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[0, 1, 1, 1, 1]]).to(torch_device) - logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] - logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) - output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - logits_fa = output_fa.hidden_states[-1] + outputs = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + outputs_fa = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - logits = output.hidden_states[-1] + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)) @@ -2839,16 +2858,35 @@ def test_flash_attn_2_inference_padding_right(self): dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device) - logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] - logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) - output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - logits_fa = output_fa.hidden_states[-1] + outputs = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + outputs_fa = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) - logits = output.hidden_states[-1] + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)) @@ -2938,6 +2976,11 @@ def test_flash_attn_2_generate_use_cache(self): return config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # make sure that all models have at least 40 position ids + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = 40 + model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: @@ -2947,7 +2990,11 @@ def test_flash_attn_2_generate_use_cache(self): dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + # tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + tmpdirname, + torch_dtype=torch.float32, + use_flash_attention_2=False, + low_cpu_mem_usage=True, ).to(torch_device) # Just test that a large cache works as expected From 2b793529400ddaaedb3bd35c4f88926c76aad806 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 19:45:44 +0200 Subject: [PATCH 03/25] Apply suggestions from code review --- .../research_projects/jax-projects/big_bird/bigbird_flax.py | 3 +-- examples/research_projects/jax-projects/big_bird/train.py | 2 +- examples/research_projects/vqgan-clip/VQGAN_CLIP.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index c171b88800ed..324842fd15da 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -8,14 +8,13 @@ import jax import jax.numpy as jnp import joblib -import optax +import wandb from flax import jax_utils, struct, traverse_util from flax.serialization import from_bytes, to_bytes from flax.training import train_state from flax.training.common_utils import shard from tqdm.auto import tqdm -import wandb from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule diff --git a/examples/research_projects/jax-projects/big_bird/train.py b/examples/research_projects/jax-projects/big_bird/train.py index 3840918d16ae..ce37b7f975bb 100644 --- a/examples/research_projects/jax-projects/big_bird/train.py +++ b/examples/research_projects/jax-projects/big_bird/train.py @@ -2,11 +2,11 @@ from dataclasses import replace import jax +import wandb from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step from datasets import load_dataset from flax import jax_utils -import wandb from transformers import BigBirdTokenizerFast diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py index 2a39955e347f..1bfbc4cd5c36 100644 --- a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py +++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py @@ -4,12 +4,12 @@ import imageio import torch import torchvision +import wandb from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan from loaders import load_vqgan from PIL import Image from torch import nn -import wandb from transformers import CLIPModel, CLIPTokenizerFast from utils import get_device, get_timestamp, show_pil From a70f9cff885e42b121c0303850203464110a7465 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 19:46:24 +0200 Subject: [PATCH 04/25] Apply suggestions from code review --- tests/test_modeling_common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9af1355f4291..4cc8865cfd0d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2990,10 +2990,9 @@ def test_flash_attn_2_generate_use_cache(self): dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) model = model_class.from_pretrained( - # tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True tmpdirname, - torch_dtype=torch.float32, - use_flash_attention_2=False, + torch_dtype=torch.float16, + use_flash_attention_2=True, low_cpu_mem_usage=True, ).to(torch_device) From be48a5c2791bd5bcf5bbf760491a34d28742bc29 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 19:26:23 +0000 Subject: [PATCH 05/25] improve --- src/transformers/models/bart/modeling_bart.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 6678cddd62c0..52ffea153c47 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -57,6 +57,7 @@ _CONFIG_FOR_DOC = "BartConfig" +# Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(padding_mask): seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() @@ -199,6 +200,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -313,7 +315,12 @@ def forward( class BartFlashAttention2(BartAttention): - """Multi-headed attention from 'Attention Is All You Need' paper""" + """ + Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -414,6 +421,7 @@ def forward( return attn_output, attn_weights, past_key_value + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None ): @@ -467,6 +475,7 @@ def _flash_attention_forward( return attn_output + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape @@ -537,6 +546,7 @@ def forward( attention_mask: torch.FloatTensor, layer_head_mask: torch.FloatTensor, output_attentions: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: """ Args: @@ -555,6 +565,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + padding_mask=padding_mask, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -634,6 +645,8 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + padding_mask: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -665,6 +678,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + padding_mask=padding_mask, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -685,6 +699,7 @@ def forward( layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, + padding_mask=encoder_padding_mask, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -1063,7 +1078,10 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + padding_mask = attention_mask if 0 in attention_mask else None attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + else: + padding_mask = None encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1294,6 +1312,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale + if attention_mask is None: + padding_mask = None + else: + padding_mask = attention_mask if 0 in attention_mask else None + attention_mask = self._prepare_decoder_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) @@ -1301,7 +1324,10 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_padding_mask = encoder_attention_mask if 0 in encoder_attention_mask else None encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + else: + encoder_padding_mask = None # embed positions positions = self.embed_positions(input, past_key_values_length) @@ -1363,6 +1389,8 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, + padding_mask, + encoder_padding_mask, ) else: layer_outputs = decoder_layer( @@ -1377,6 +1405,8 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, + encoder_padding_mask=encoder_padding_mask, ) hidden_states = layer_outputs[0] From 8e344ee0ed7cf83b35f249c664e7f412a7ee78ca Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 21:27:08 +0200 Subject: [PATCH 06/25] Update examples/research_projects/jax-projects/big_bird/bigbird_flax.py --- examples/research_projects/jax-projects/big_bird/bigbird_flax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index 324842fd15da..af5e11c83a6a 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -8,6 +8,7 @@ import jax import jax.numpy as jnp import joblib +import optax import wandb from flax import jax_utils, struct, traverse_util from flax.serialization import from_bytes, to_bytes From 8e35ac92cb54ef502b6aac71bae16d589c4814f5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 19:31:02 +0000 Subject: [PATCH 07/25] improve further --- src/transformers/models/bart/modeling_bart.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 52ffea153c47..7ea88251e611 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -545,8 +545,8 @@ def forward( hidden_states: torch.FloatTensor, attention_mask: torch.FloatTensor, layer_head_mask: torch.FloatTensor, - output_attentions: Optional[bool] = False, padding_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: """ Args: @@ -643,10 +643,10 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = True, padding_mask: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1120,12 +1120,14 @@ def custom_forward(*inputs): hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), + padding_mask ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), + padding_mask=padding_mask, output_attentions=output_attentions, ) @@ -1403,10 +1405,10 @@ def custom_forward(*inputs): cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, padding_mask=padding_mask, encoder_padding_mask=encoder_padding_mask, + output_attentions=output_attentions, + use_cache=use_cache, ) hidden_states = layer_outputs[0] From a07044551c3c140ff58df4b73ece45981cfd506a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 20:19:56 +0000 Subject: [PATCH 08/25] fix bart --- src/transformers/models/bart/modeling_bart.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7ea88251e611..e1189c8554ed 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -623,12 +623,20 @@ def __init__(self, config: BartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BartAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) + if getattr(config, "_flash_attn_2_enabled", False): + self.encoder_attn = BartFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + else: + self.encoder_attn = BartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) From 341b7c930254b2f9185c6d878537553eb1cc66b7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 20:20:16 +0000 Subject: [PATCH 09/25] add FA to whisper --- .../models/whisper/modeling_whisper.py | 61 +++++++++++++------ 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 447d7275d557..11383800205a 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -464,16 +464,27 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper +class WhisperFlashAttention2(WhisperAttention): + pass + # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - ) + if getattr(config, "_flash_attn_2_enabled", False): + self.self_attn = WhisperFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + else: + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -539,23 +550,39 @@ def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) + if getattr(config, "_flash_attn_2_enabled", False): + self.self_attn = WhisperFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + else: + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WhisperAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) + if getattr(config, "_flash_attn_2_enabled", False): + self.encoder_attn = WhisperFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + else: + self.encoder_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) From cdf21900fa5629995a1928edbe4bbff94bd00457 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 20:21:36 +0000 Subject: [PATCH 10/25] make fix copies whisper --- .../models/whisper/modeling_whisper.py | 256 +++++++++++++++--- 1 file changed, 215 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 11383800205a..c7dc08bd7aad 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -351,6 +351,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -466,25 +467,214 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper class WhisperFlashAttention2(WhisperAttention): - pass + """ + Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # WhisperFlashAttention2 attention does not support output_attentions + output_attentions = False + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._shape(self.q_proj(hidden_states), -1, bsz) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # TODO: Whisper does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." + ) + + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + padding_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if padding_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, padding_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + padding_mask = padding_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self.self_attn = WhisperFlashAttention2( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - ) - else: - self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - ) + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -550,39 +740,23 @@ def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False): - self.self_attn = WhisperFlashAttention2( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) - else: - self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - if getattr(config, "_flash_attn_2_enabled", False): - self.encoder_attn = WhisperFlashAttention2( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) - else: - self.encoder_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) + self.encoder_attn = WhisperAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) From ddad165eabd491a758a5653bf9109f90067bc558 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 21:02:35 +0000 Subject: [PATCH 11/25] correct more --- src/transformers/models/bart/modeling_bart.py | 2 +- .../models/whisper/modeling_whisper.py | 80 ++++++++++++++----- 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index e1189c8554ed..8c8458217b61 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -345,7 +345,7 @@ def forward( bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self._shape(self.q_proj(hidden_states), -1, bsz) * self.scaling + query_states = self._shape(self.q_proj(hidden_states), -1, bsz) # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index c7dc08bd7aad..b44a3a95401b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -36,6 +36,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_available, logging, replace_return_docstrings, ) @@ -43,6 +44,10 @@ from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE +if is_flash_attn_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "WhisperConfig" @@ -55,6 +60,19 @@ ] +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(padding_mask): + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + # Copied from transformers.models.bart.modeling_bart.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ @@ -496,7 +514,7 @@ def forward( bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self._shape(self.q_proj(hidden_states), -1, bsz) * self.scaling + query_states = self._shape(self.q_proj(hidden_states), -1, bsz) # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -670,11 +688,20 @@ class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - ) + + if getattr(config, "_flash_attn_2_enabled", False) and False: + self.self_attn = WhisperFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + else: + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -740,23 +767,39 @@ def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) + if getattr(config, "_flash_attn_2_enabled", False): + self.self_attn = WhisperFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + else: + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WhisperAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) + if getattr(config, "_flash_attn_2_enabled", False) and False: + self.encoder_attn = WhisperFlashAttention2( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + else: + self.encoder_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) @@ -858,6 +901,7 @@ class WhisperPreTrainedModel(PreTrainedModel): main_input_name = "input_features" supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.init_std From 045f1838ca95a15a5fe457d6d40f04bd033dfcdf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 21:26:44 +0000 Subject: [PATCH 12/25] more --- src/transformers/models/bart/modeling_bart.py | 13 ++++++------- .../models/whisper/modeling_whisper.py | 16 ++++++++-------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 8c8458217b61..e3649357832e 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -320,8 +320,6 @@ class BartFlashAttention2(BartAttention): untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -409,11 +407,12 @@ def forward( key_states = key_states.to(torch.float16) value_states = value_states.to(torch.float16) + causal = self.is_decoder and not is_cross_attention attn_output = self._flash_attention_forward( - query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + query_states, key_states, value_states, padding_mask, q_len, causal=causal, dropout=dropout_rate ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) if not output_attentions: @@ -423,7 +422,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, padding_mask, query_length, causal=True, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -464,13 +463,13 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index b44a3a95401b..b15d05e45a2f 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -578,11 +578,12 @@ def forward( key_states = key_states.to(torch.float16) value_states = value_states.to(torch.float16) + causal = self.is_decoder and not is_cross_attention attn_output = self._flash_attention_forward( - query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + query_states, key_states, value_states, padding_mask, q_len, causal=causal, dropout=dropout_rate ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) if not output_attentions: @@ -590,9 +591,8 @@ def forward( return attn_output, attn_weights, past_key_value - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, padding_mask, query_length, causal=True, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -633,13 +633,13 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, ) return attn_output @@ -689,7 +689,7 @@ def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - if getattr(config, "_flash_attn_2_enabled", False) and False: + if getattr(config, "_flash_attn_2_enabled", False): self.self_attn = WhisperFlashAttention2( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, @@ -786,7 +786,7 @@ def __init__(self, config: WhisperConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - if getattr(config, "_flash_attn_2_enabled", False) and False: + if getattr(config, "_flash_attn_2_enabled", False): self.encoder_attn = WhisperFlashAttention2( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, From 5a82297a4c8b97e8801270ffa5086417847f33f8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 10 Oct 2023 21:28:00 +0000 Subject: [PATCH 13/25] improve bart --- src/transformers/models/bart/modeling_bart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index e3649357832e..dd4223a47f47 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -420,7 +420,6 @@ def forward( return attn_output, attn_weights, past_key_value - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( self, query_states, key_states, value_states, padding_mask, query_length, causal=True, dropout=0.0, softmax_scale=None ): From 39f820d9822c02309faa7f0ab1cc3e77fe4f0446 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 14:41:27 +0200 Subject: [PATCH 14/25] improve flash attention --- src/transformers/models/bart/modeling_bart.py | 267 ++++++++++++------ .../models/llama/modeling_llama.py | 8 +- 2 files changed, 183 insertions(+), 92 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 309717310e70..79ab86edb249 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -90,6 +90,143 @@ def _get_unpad_data(padding_mask): ] +# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter +class AttnMaskConverter: + """ + A utility attention mask class that allows: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + + return expanded_4d_mask + + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. @@ -106,37 +243,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class BartLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -183,6 +289,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_decoder self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -200,7 +307,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -331,7 +437,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # BartFlashAttention2 attention does not support output_attentions output_attentions = False @@ -407,9 +512,8 @@ def forward( key_states = key_states.to(torch.float16) value_states = value_states.to(torch.float16) - causal = self.is_decoder and not is_cross_attention attn_output = self._flash_attention_forward( - query_states, key_states, value_states, padding_mask, q_len, causal=causal, dropout=dropout_rate + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate ) attn_output = attn_output.reshape(bsz, q_len, -1) @@ -420,8 +524,9 @@ def forward( return attn_output, attn_weights, past_key_value + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, causal=True, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -434,7 +539,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -443,10 +548,10 @@ def _flash_attention_forward( The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -462,13 +567,13 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=causal, + causal=self.is_causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True ) return attn_output @@ -986,6 +1091,8 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.max_source_positions = config.max_position_embeddings self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.attn_mask_converter = AttnMaskConverter(is_causal=False) + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) if embed_tokens is not None: @@ -1081,13 +1188,15 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - padding_mask = attention_mask if 0 in attention_mask else None - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: - padding_mask = None + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = self.causal_attn_mask_converter.to_4d( + attention_mask, input.shape[1], dtype=inputs_embeds.dtype + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1126,14 +1235,12 @@ def custom_forward(*inputs): hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), - padding_mask ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - padding_mask=padding_mask, output_attentions=output_attentions, ) @@ -1174,6 +1281,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No if embed_tokens is not None: self.embed_tokens.weight = embed_tokens.weight + self.causal_attn_mask_converter = AttnMaskConverter(is_causal=True) + self.attn_mask_converter = AttnMaskConverter(is_causal=False) + self.embed_positions = BartLearnedPositionalEmbedding( config.max_position_embeddings, config.d_model, @@ -1191,29 +1301,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -1320,22 +1407,30 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - if attention_mask is None: - padding_mask = None + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: - padding_mask = attention_mask if 0 in attention_mask else None - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + key_value_length = input_shape[1] + past_key_values_length + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = self.causal_attn_mask_converter.to_4d( + attention_mask, input_shape[1], key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = self.causal_attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_padding_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - else: - encoder_padding_mask = None + if getattr(self.config, "_flash_attn_2_enabled", False): + encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask is not None and 0 in encoder_attention_mask) else None + else: + if encoder_attention_mask is not None: + encoder_attention_mask = self.attn_mask_converter.to_4d( + encoder_attention_mask, encoder_attention_mask.shape[1], dtype=inputs_embeds.dtype + ) # embed positions positions = self.embed_positions(input, past_key_values_length) @@ -1397,8 +1492,6 @@ def custom_forward(*inputs): head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, - padding_mask, - encoder_padding_mask, ) else: layer_outputs = decoder_layer( @@ -1411,8 +1504,6 @@ def custom_forward(*inputs): cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), past_key_value=past_key_value, - padding_mask=padding_mask, - encoder_padding_mask=encoder_padding_mask, output_attentions=output_attentions, use_cache=use_cache, ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3c1015b9e22c..45dbd7cef4bf 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -122,7 +122,7 @@ def to_4d( self, attention_mask_2d: torch.Tensor, query_length: int, - key_value_length: int, + key_value_length: Optional[int] = None, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -131,12 +131,14 @@ def to_4d( causal, a causal mask will be added. """ input_shape = (attention_mask_2d.shape[0], query_length) - past_key_values_length = key_value_length - query_length # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] causal_4d_mask = None if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError("This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask.") + past_key_values_length = key_value_length - query_length causal_4d_mask = self._make_causal_mask( input_shape, @@ -913,8 +915,6 @@ def __init__(self, config: LlamaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - # create attention mask cache that trickles down to each attention layer - # so that the attention_mask cache can be shared among layers self.attn_mask_converter = AttnMaskConverter(is_causal=True) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) From c4517f34cb60c9957b9552f83c134e0d497d3c98 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 14:49:52 +0200 Subject: [PATCH 15/25] fix more --- src/transformers/models/bart/modeling_bart.py | 37 ++++++++----------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 79ab86edb249..d3ac9bfec137 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -58,9 +58,9 @@ # Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero( attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -149,7 +149,7 @@ def to_4d( self, attention_mask_2d: torch.Tensor, query_length: int, - key_value_length: int, + key_value_length: Optional[int] = None, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -158,12 +158,14 @@ def to_4d( causal, a causal mask will be added. """ input_shape = (attention_mask_2d.shape[0], query_length) - past_key_values_length = key_value_length - query_length # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] causal_4d_mask = None if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError("This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask.") + past_key_values_length = key_value_length - query_length causal_4d_mask = self._make_causal_mask( input_shape, @@ -274,6 +276,7 @@ def __init__( num_heads: int, dropout: float = 0.0, is_decoder: bool = False, + is_causal: bool = False, bias: bool = True, ): super().__init__() @@ -289,7 +292,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.is_causal = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -579,8 +582,8 @@ def _flash_attention_forward( return attn_output # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( @@ -605,8 +608,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. - padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -648,7 +651,6 @@ def forward( hidden_states: torch.FloatTensor, attention_mask: torch.FloatTensor, layer_head_mask: torch.FloatTensor, - padding_mask: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: """ @@ -668,7 +670,6 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, - padding_mask=padding_mask, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -701,18 +702,13 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BartAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) if getattr(config, "_flash_attn_2_enabled", False): self.self_attn = BartFlashAttention2( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, ) else: self.self_attn = BartAttention( @@ -720,6 +716,7 @@ def __init__(self, config: BartConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -754,8 +751,6 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - padding_mask: Optional[torch.Tensor] = None, - encoder_padding_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -789,7 +784,6 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, - padding_mask=padding_mask, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -810,7 +804,6 @@ def forward( layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, - padding_mask=encoder_padding_mask, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states From f195c15a705efd43823e888141c6d9df2f20621a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 16:57:15 +0200 Subject: [PATCH 16/25] fix attn mask bug --- src/transformers/models/bart/modeling_bart.py | 35 +++++++++++-------- .../models/whisper/modeling_whisper.py | 5 +-- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index d3ac9bfec137..7dad0c11ce43 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -40,14 +40,14 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_available, + is_flash_attn_2_available, logging, replace_return_docstrings, ) from .configuration_bart import BartConfig -if is_flash_attn_available(): +if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -500,20 +500,27 @@ def forward( # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. + # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to" - " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - " float16." + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." ) - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate @@ -1187,7 +1194,7 @@ def forward( else: # 4d mask is passed through the layers if attention_mask is not None: - attention_mask = self.causal_attn_mask_converter.to_4d( + attention_mask = self.attn_mask_converter.to_4d( attention_mask, input.shape[1], dtype=inputs_embeds.dtype ) @@ -1404,15 +1411,15 @@ def forward( # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: - key_value_length = input_shape[1] + past_key_values_length + key_value_length = input_shape[-1] + past_key_values_length # 4d mask is passed through the layers if attention_mask is not None: attention_mask = self.causal_attn_mask_converter.to_4d( - attention_mask, input_shape[1], key_value_length, dtype=inputs_embeds.dtype + attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype ) else: attention_mask = self.causal_attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device ) # expand encoder attention mask @@ -1422,7 +1429,7 @@ def forward( else: if encoder_attention_mask is not None: encoder_attention_mask = self.attn_mask_converter.to_4d( - encoder_attention_mask, encoder_attention_mask.shape[1], dtype=inputs_embeds.dtype + encoder_attention_mask, input_shape[-1], encoder_attention_mask.shape[1], dtype=inputs_embeds.dtype ) # embed positions diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 825f4ce5a539..3d38903070d4 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -37,7 +37,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_available, + is_flash_attn_2_available, logging, replace_return_docstrings, ) @@ -45,10 +45,11 @@ from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE -if is_flash_attn_available(): +if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "WhisperConfig" From c4488b93e4e81a882fea7dca464154d2dbaf9a3a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 16:05:14 +0000 Subject: [PATCH 17/25] Fix all --- src/transformers/models/bart/modeling_bart.py | 11 +- .../models/whisper/modeling_whisper.py | 201 ++++++++++++++++-- 2 files changed, 193 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7dad0c11ce43..09eb131b423c 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -24,6 +24,8 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from build.lib.transformers.configuration_utils import PretrainedConfig + from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutput, @@ -278,6 +280,7 @@ def __init__( is_decoder: bool = False, is_causal: bool = False, bias: bool = True, + config: Optional[PretrainedConfig] = None, ): super().__init__() self.embed_dim = embed_dim @@ -298,6 +301,7 @@ def __init__( self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.config = config def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -638,12 +642,14 @@ def __init__(self, config: BartConfig): embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) else: self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -716,6 +722,7 @@ def __init__(self, config: BartConfig): dropout=config.attention_dropout, is_decoder=True, is_causal=True, + config=config, ) else: self.self_attn = BartAttention( @@ -724,6 +731,7 @@ def __init__(self, config: BartConfig): dropout=config.attention_dropout, is_decoder=True, is_causal=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -736,6 +744,7 @@ def __init__(self, config: BartConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) else: self.encoder_attn = BartAttention( @@ -743,6 +752,7 @@ def __init__(self, config: BartConfig): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -2256,7 +2266,6 @@ def forward( >>> list(logits.shape) == expected_shape True ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 3d38903070d4..7944ebb096f6 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -24,6 +24,8 @@ from torch import nn from torch.nn import CrossEntropyLoss +from build.lib.transformers.configuration_utils import PretrainedConfig + from ...activations import ACT2FN from ...generation.logits_process import WhisperTimeStampLogitsProcessor from ...modeling_outputs import ( @@ -62,10 +64,149 @@ ] +# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter +class AttnMaskConverter: + """ + A utility attention mask class that allows: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: Optional[int] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError("This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask.") + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + + return expanded_4d_mask + + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + # Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -349,13 +490,16 @@ def __init__( num_heads: int, dropout: float = 0.0, is_decoder: bool = False, + is_causal: bool = False, bias: bool = True, + config: Optional[PretrainedConfig] = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads + self.config = config if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -364,6 +508,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.is_causal = is_causal self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -383,7 +528,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -516,7 +660,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # WhisperFlashAttention2 attention does not support output_attentions output_attentions = False @@ -594,7 +737,7 @@ def forward( causal = self.is_decoder and not is_cross_attention attn_output = self._flash_attention_forward( - query_states, key_states, value_states, padding_mask, q_len, causal=causal, dropout=dropout_rate + query_states, key_states, value_states, attention_mask, q_len, causal=causal, dropout=dropout_rate ) attn_output = attn_output.reshape(bsz, q_len, -1) @@ -606,7 +749,7 @@ def forward( return attn_output, attn_weights, past_key_value def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, causal=True, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, causal=True, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -619,7 +762,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -628,10 +771,10 @@ def _flash_attention_forward( The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -647,7 +790,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=causal, + causal=self.is_causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) @@ -659,8 +802,8 @@ def _flash_attention_forward( return attn_output # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( @@ -685,8 +828,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. - padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -708,12 +851,14 @@ def __init__(self, config: WhisperConfig): embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config ) else: self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -787,6 +932,8 @@ def __init__(self, config: WhisperConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config ) else: self.self_attn = WhisperAttention( @@ -794,6 +941,8 @@ def __init__(self, config: WhisperConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + is_causal=True, + config=config ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1094,6 +1243,8 @@ def __init__(self, config: WhisperConfig): self.max_source_positions = config.max_source_positions self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.attn_mask_converter = AttnMaskConverter(is_causal=False) + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) @@ -1244,6 +1395,9 @@ def __init__(self, config: WhisperConfig): self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.causal_attn_mask_converter = AttnMaskConverter(is_causal=True) + self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) @@ -1377,9 +1531,20 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + key_value_length = input_shape[-1] + past_key_values_length + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = self.causal_attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = self.causal_attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) # embed positions if input_ids is not None: From 2f99cea333738401fbddd731592175aea8b5631c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 16:07:40 +0000 Subject: [PATCH 18/25] rename to converter --- src/transformers/models/bart/modeling_bart.py | 12 ++++----- .../models/whisper/modeling_whisper.py | 26 +------------------ 2 files changed, 6 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 09eb131b423c..b78dca820f39 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -24,8 +24,6 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from build.lib.transformers.configuration_utils import PretrainedConfig - from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutput, @@ -280,7 +278,7 @@ def __init__( is_decoder: bool = False, is_causal: bool = False, bias: bool = True, - config: Optional[PretrainedConfig] = None, + config: Optional[BartConfig] = None, ): super().__init__() self.embed_dim = embed_dim @@ -908,18 +906,18 @@ def dummy_inputs(self): return dummy_inputs -class PretrainedBartModel(BartPreTrainedModel): +class BartModel(BartPreTrainedModel): def __init_subclass__(self): warnings.warn( - "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + "The class `BartModel` has been depreciated, please use `BartPreTrainedModel` instead.", FutureWarning, ) -class BartPretrainedModel(BartPreTrainedModel): +class BartModel(BartPreTrainedModel): def __init_subclass__(self): warnings.warn( - "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.", + "The class `BartModel` has been depreciated, please use `BartPreTrainedModel` instead.", FutureWarning, ) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 7944ebb096f6..7bb042990d89 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -24,8 +24,6 @@ from torch import nn from torch.nn import CrossEntropyLoss -from build.lib.transformers.configuration_utils import PretrainedConfig - from ...activations import ACT2FN from ...generation.logits_process import WhisperTimeStampLogitsProcessor from ...modeling_outputs import ( @@ -492,7 +490,7 @@ def __init__( is_decoder: bool = False, is_causal: bool = False, bias: bool = True, - config: Optional[PretrainedConfig] = None, + config: Optional[WhisperConfig] = None, ): super().__init__() self.embed_dim = embed_dim @@ -1414,28 +1412,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, From d8ae461fa14c70c7994bca85a099ff31ae3abbcf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 16:26:34 +0000 Subject: [PATCH 19/25] fix whisper --- .../models/whisper/modeling_whisper.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 7bb042990d89..ffcbb46bac1b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -642,11 +642,10 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper class WhisperFlashAttention2(WhisperAttention): """ - Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays + Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -659,7 +658,7 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # WhisperFlashAttention2 attention does not support output_attentions + # BartFlashAttention2 attention does not support output_attentions output_attentions = False # if key_value_states are provided this layer is used as a cross-attention layer @@ -711,31 +710,37 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - # TODO: Whisper does not have dropout in the config?? + # TODO: Bart does not have dropout in the config?? # It is recommended to use dropout with FA according to the docs # when training. dropout_rate = 0.0 # if not self.training else self.attn_dropout # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. + # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to" - " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - " float16." + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." ) - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) - causal = self.is_decoder and not is_cross_attention attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, causal=causal, dropout=dropout_rate + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate ) attn_output = attn_output.reshape(bsz, q_len, -1) @@ -746,8 +751,9 @@ def forward( return attn_output, attn_weights, past_key_value + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, causal=True, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -794,7 +800,7 @@ def _flash_attention_forward( attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True ) return attn_output From c4fa0c9e0fdb95b997f1863ee0092c723f8fa2b4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 19:44:36 +0200 Subject: [PATCH 20/25] Apply suggestions from code review --- src/transformers/models/whisper/modeling_whisper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index ffcbb46bac1b..69da083c0f06 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -933,7 +933,7 @@ def __init__(self, config: WhisperConfig): if getattr(config, "_flash_attn_2_enabled", False): self.self_attn = WhisperFlashAttention2( embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, is_causal=True, @@ -942,7 +942,7 @@ def __init__(self, config: WhisperConfig): else: self.self_attn = WhisperAttention( embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, is_causal=True, @@ -956,14 +956,14 @@ def __init__(self, config: WhisperConfig): if getattr(config, "_flash_attn_2_enabled", False): self.encoder_attn = WhisperFlashAttention2( embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, ) else: self.encoder_attn = WhisperAttention( embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, ) From 1ee2b6d9c0f1419d7102eb9630c68e39499da40c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 17:45:07 +0000 Subject: [PATCH 21/25] add spec decoding --- src/transformers/__init__.py | 2 + src/transformers/generation/utils.py | 3 +- src/transformers/models/whisper/__init__.py | 2 + .../models/whisper/modeling_whisper.py | 242 +++++++++++++++++- 4 files changed, 246 insertions(+), 3 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8001bb33c814..6ac9858b9c55 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3077,6 +3077,7 @@ "WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST", "WhisperForAudioClassification", "WhisperForConditionalGeneration", + "WhisperForCausalLM" "WhisperModel", "WhisperPreTrainedModel", ] @@ -6820,6 +6821,7 @@ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST, WhisperForAudioClassification, WhisperForConditionalGeneration, + WhisperForCausalLM, WhisperModel, WhisperPreTrainedModel, ) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1c412f8185dc..5bc7982495c2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4457,6 +4457,7 @@ def assisted_decoding( assistant_model_outputs = assistant_model( assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], + encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: if assistant_model.config.is_encoder_decoder: @@ -4465,7 +4466,7 @@ def assisted_decoding( encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: - assistant_model_outputs = assistant_model(candidate_input_ids) + assistant_model_outputs = assistant_model(candidate_input_ids, encoder_outputs=model_kwargs["assistant_encoder_outputs"]) # 1.2. greedily select the next candidate token model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py index cd962478e34d..df3dadcaacd3 100644 --- a/src/transformers/models/whisper/__init__.py +++ b/src/transformers/models/whisper/__init__.py @@ -50,6 +50,7 @@ "WhisperModel", "WhisperPreTrainedModel", "WhisperForAudioClassification", + "WhisperForCausalLM", ] try: @@ -103,6 +104,7 @@ WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST, WhisperForAudioClassification, WhisperForConditionalGeneration, + WhisperForCausalLM, WhisperModel, WhisperPreTrainedModel, ) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index ffcbb46bac1b..96b98d12c4bd 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -32,6 +32,7 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, SequenceClassifierOutput, + CausalLMOutputWithCrossAttentions, ) from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -991,8 +992,6 @@ def forward( `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of @@ -2401,3 +2400,242 @@ def forward( hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) + + +class WhisperDecoderWrapper(WhisperPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + config.is_encoder_decoder = False + self.decoder = WhisperDecoder(config) + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +@add_start_docstrings( + """ + Whisper decoder with with a language modeling head on top (linear layer with weights tied to the input embeddings). + """, + WHISPER_START_DOCSTRING, +) +class WhisperForCausalLM(WhisperPreTrainedModel): + _tied_weights_keys = ["proj_out.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = WhisperDecoderWrapper(config) + + self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.proj_out + + def set_output_embeddings(self, new_embeddings): + self.proj_out = new_embeddings + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BartForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)): + encoder_outputs = encoder_outputs[0] + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_outputs, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + use_cache=None, + encoder_outputs=None, + attention_mask=None, + **kwargs, + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "input_ids": input_ids, + "use_cache": use_cache, + "attention_mask": None, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past From 53f77008400c0ea6bded5af4b268b9870735e39e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 19:52:10 +0200 Subject: [PATCH 22/25] Update src/transformers/models/bart/modeling_bart.py --- src/transformers/models/bart/modeling_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index b78dca820f39..784f58e2baf6 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -585,7 +585,7 @@ def _flash_attention_forward( attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal ) return attn_output From d16fde8d52149e22f1dd68d624401d2452ae5c4c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 24 Oct 2023 17:53:02 +0000 Subject: [PATCH 23/25] correct more --- src/transformers/models/whisper/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d69e067bc1e1..e90add14e63e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -801,7 +801,7 @@ def _flash_attention_forward( attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal ) return attn_output From b6d2b652ec39057395bca73c4fd6a91dd3cd0c55 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 27 Oct 2023 15:30:24 +0000 Subject: [PATCH 24/25] Add all --- src/transformers/generation/utils.py | 139 +++++++++++++----- .../models/whisper/modeling_whisper.py | 31 +++- 2 files changed, 131 insertions(+), 39 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5bc7982495c2..dabf801e2d90 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1621,8 +1621,6 @@ def generate( "num_return_sequences has to be 1 when doing assisted generate, " f"but is {generation_config.num_return_sequences}." ) - if batch_size > 1: - raise ValueError("assisted generate is only supported for batch_size = 1") if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") @@ -4407,6 +4405,7 @@ def assisted_decoding( # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + batch_size = input_ids.shape[0] # other auxiliary variables max_len = stopping_criteria[0].max_length @@ -4421,6 +4420,13 @@ def assisted_decoding( ) this_peer_finished = False # used by synced_gpus only + + # 2 * max_len to give us room to potentially left cut + position_ids = torch.arange(2 * max_len, device=input_ids.device, dtype=torch.long)[None, :].broadcast_to(batch_size, 2 * max_len) if batch_size > 1 else None + attention_mask = torch.ones_like(position_ids) if position_ids is not None else None + n_matches = None + eos_tokens_mask = None + while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -4432,6 +4438,31 @@ def assisted_decoding( if this_peer_finished_flag.item() == 0.0: break + # Rotate everything for bsz > 1 + if n_matches is not None and position_ids is not None: + # compute by how much everything can be rotated + shift = unfinished_sequences * (n_matches.max() - n_matches) + (1 - unfinished_sequences) * (eos_tokens_mask.sum(-1) - 1) + + for i in range(batch_size): + if shift[i] > 0: + input_ids[i][shift[i]:] = input_ids[i][:-shift[i]].clone() + input_ids[i][:shift[i]] = self.config.pad_token_id + + position_ids = position_ids.add(-shift[:, None]).clamp(min=0) + attention_mask[:, :-1] = position_ids[:, 1:] > 0 + + left_cut = (1 - attention_mask).sum(-1).min() + + if left_cut > 0: + position_ids = position_ids[:, left_cut:] + attention_mask = attention_mask[:, left_cut:] + input_ids = input_ids[:, left_cut:] + + model_kwargs["past_key_values"] = _crop_past_key_values(self, model_kwargs["past_key_values"], left_cut=left_cut) + model_kwargs["assistant_past_key_values"] = _crop_past_key_values( + assistant_model, model_kwargs["assistant_past_key_values"], left_cut=left_cut + ) # the assistant does not have the token after the last match, hence the -1 + # Assistant: main logic start cur_len = input_ids.shape[-1] @@ -4439,23 +4470,30 @@ def assisted_decoding( # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we # need access to the assistant cache to secure strong speedups. candidate_input_ids = input_ids - for _ in range(int(num_assistant_tokens)): + for assist_idx in range(int(num_assistant_tokens)): # 1.1. use the assistant model to obtain the next candidate logits if "assistant_past_key_values" in model_kwargs: prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) new_token_len = candidate_input_ids.shape[1] - prev_seq_len assist_inputs = candidate_input_ids[:, -new_token_len:] + + assist_position_ids = position_ids[:, cur_len - new_token_len + assist_idx:cur_len + assist_idx] if position_ids is not None else None + assist_attention_mask = attention_mask[:, :cur_len + assist_idx] if attention_mask is not None else None # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 if assistant_model.config.is_encoder_decoder: assistant_model_outputs = assistant_model( decoder_input_ids=assist_inputs, + decoder_position_ids=assist_position_ids, + decoder_attention_mask=assist_attention_mask, past_key_values=model_kwargs["assistant_past_key_values"], encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: assistant_model_outputs = assistant_model( assist_inputs, + position_ids=assist_position_ids, + attention_mask=assist_attention_mask, past_key_values=model_kwargs["assistant_past_key_values"], encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) @@ -4474,6 +4512,10 @@ def assisted_decoding( assistant_model_outputs.logits[:, -1, :] = logits_processor( candidate_input_ids, assistant_model_outputs.logits[:, -1, :] ) + + if stopping_criteria(candidate_input_ids, assistant_model_outputs.logits[:, -1, :]): + break + new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) @@ -4483,10 +4525,10 @@ def assisted_decoding( last_assistant_token_is_eos = ( ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() ) - if last_assistant_token_is_eos: + if torch.logical_or(~unfinished_sequences.bool(), last_assistant_token_is_eos).all(): break else: - last_assistant_token_is_eos = False + last_assistant_token_is_eos = torch.zeros((1, 1), device=candidate_input_ids.device, dtype=torch.int8) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] @@ -4499,6 +4541,13 @@ def assisted_decoding( candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1]) candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + if self.config.is_encoder_decoder: + candidate_kwargs["decoder_position_ids"] = position_ids[:, :cur_len + candidate_length] if position_ids is not None else None + candidate_kwargs["decoder_attention_mask"] = attention_mask[:, :cur_len + candidate_length] if attention_mask is not None else None + else: + candidate_kwargs["position_ids"] = position_ids[:, :cur_len + candidate_length] if position_ids is not None else None + candidate_kwargs["attention_mask"] = attention_mask[:, :cur_len + candidate_length] if attention_mask is not None else None + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) # 2.2. Run a forward pass on the candidate sequence @@ -4527,7 +4576,7 @@ def assisted_decoding( # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep # the assistant forecasted tokens until the first mismatch, or until the max length is reached. candidate_new_tokens = candidate_input_ids[:, -candidate_length:] - n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() + n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum(-1) # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated # by the model after the last candidate match is also valid, as it is generated from a correct sequence. @@ -4535,12 +4584,29 @@ def assisted_decoding( # is no match. # 5.1. Ensure we don't generate beyond max_len or an EOS token - if last_assistant_token_is_eos and n_matches == candidate_length: - n_matches -= 1 - n_matches = min(n_matches, max_len - cur_len - 1) - + n_matches -= last_assistant_token_is_eos.int() * (n_matches == candidate_length).int() + # make sure than already finished sequences always match until longest "still active" sequence + n_matches = torch.clamp(n_matches, max=max_len - cur_len - 1) + # make sure that finished sentences cannot slow down valid tokens + n_matches = unfinished_sequences * n_matches + (1 - unfinished_sequences) * (unfinished_sequences * n_matches).max() + # 5.2. Get the valid continuation, after the matching tokens - valid_tokens = selected_tokens[:, : n_matches + 1] + valid_tokens = selected_tokens[:, : n_matches.max() + 1] + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + eos_tokens = valid_tokens.eq(eos_token_id_tensor) + finished_seq_mask = ~(unfinished_sequences.bool()[:, None].broadcast_to(valid_tokens.shape)) + eos_tokens_mask = torch.logical_or(eos_tokens.cumsum(-1).bool(), finished_seq_mask) + valid_tokens = torch.where(eos_tokens_mask, eos_token_id_tensor, valid_tokens) + + # check which sentence has finished + unfinished_sequences = (1 - eos_tokens_mask.gather(1, n_matches[:, None]).squeeze(-1).int()) + + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + input_ids = torch.cat((input_ids, valid_tokens), dim=-1) if streamer is not None: streamer.put(valid_tokens.cpu()) @@ -4548,19 +4614,19 @@ def assisted_decoding( # 5.3. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 - outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size, n_matches) model_kwargs["assistant_past_key_values"] = _crop_past_key_values( - assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1 + assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1, n_matches ) # the assistant does not have the token after the last match, hence the -1 # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the # cost of forecasting incorrect assistant tokens. if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": - if n_matches == int(num_assistant_tokens): - num_assistant_tokens += 2.0 + if n_matches.min() == int(num_assistant_tokens): + num_assistant_tokens += 2 else: - num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0) + num_assistant_tokens = max(1, num_assistant_tokens - 1) # Assistant: main logic end if synced_gpus and this_peer_finished: @@ -4570,7 +4636,7 @@ def assisted_decoding( # Assistant: modified to append one tuple element per token, as in the other generation methods. if return_dict_in_generate: if output_scores: - scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) + scores += tuple(new_logits[:, i, :] for i in range(n_matches.max() + 1)) if "past_key_values" not in model_kwargs: added_len = new_cur_len @@ -4611,19 +4677,6 @@ def assisted_decoding( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - input_ids[:, -1] - .tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - ) - - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True - # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True @@ -4656,15 +4709,32 @@ def assisted_decoding( return input_ids -def _crop_past_key_values(model, past_key_values, maximum_length): +def _crop_past_key_values(model, past_key_values, maximum_length=None, n_matches=None, left_cut=None): """Crops the past key values up to a certain maximum length.""" new_past = [] if model.config.is_encoder_decoder: for idx in range(len(past_key_values)): + if left_cut is None: + k_cache = past_key_values[idx][0][:, :, :maximum_length, :] + v_cache = past_key_values[idx][1][:, :, :maximum_length, :] + else: + k_cache = past_key_values[idx][0][:, :, left_cut:, :] + v_cache = past_key_values[idx][1][:, :, left_cut:, :] + + if n_matches is not None: + for batch_idx in range(len(n_matches)): + num_roll_left = n_matches.max() - n_matches[batch_idx] + if num_roll_left > 0: + # TODO(PVP) - check mem usage + # k_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=k_cache.device), k_cache[batch_idx][:, :-num_roll_left].clone()) + # v_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=v_cache.device), v_cache[batch_idx][:, :-num_roll_left].clone()) + k_cache[batch_idx][:, num_roll_left:] = k_cache[batch_idx][:, :-num_roll_left].clone() + v_cache[batch_idx][:, num_roll_left:] = v_cache[batch_idx][:, :-num_roll_left].clone() + new_past.append( ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], + k_cache, + v_cache, past_key_values[idx][2], past_key_values[idx][3], ) @@ -4721,6 +4791,9 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at cur_len += 1 added_len -= cur_len + if torch.is_tensor(added_len): + added_len = added_len.max().item() + for i in range(added_len): new_tuple = () for layer in new_outputs: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e90add14e63e..f37a455a75b5 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -476,8 +476,11 @@ class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) - def forward(self, input_ids, past_key_values_length=0): - return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] + def forward(self, input_ids, past_key_values_length=0, position_ids=None): + if position_ids is None: + return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] + else: + return self.weight[position_ids] class WhisperAttention(nn.Module): @@ -1426,6 +1429,7 @@ def forward( cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -1529,9 +1533,9 @@ def forward( # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1719,6 +1723,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1777,6 +1782,7 @@ def forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1850,6 +1856,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1904,6 +1911,7 @@ def forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -2201,6 +2209,8 @@ def prepare_inputs_for_generation( use_cache=None, encoder_outputs=None, attention_mask=None, + decoder_position_ids=None, + decoder_attention_mask=None, **kwargs, ): if past_key_values is not None: @@ -2215,12 +2225,16 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]: + decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "use_cache": use_cache, - "decoder_attention_mask": None, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, } @staticmethod @@ -2608,6 +2622,7 @@ def prepare_inputs_for_generation( past_key_values=None, use_cache=None, encoder_outputs=None, + position_ids=None, attention_mask=None, **kwargs, ): @@ -2623,12 +2638,16 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] + if position_ids is not None and position_ids.shape[1] > position_ids.shape[1]: + position_ids = position_ids[:, remove_prefix_length:] + return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "input_ids": input_ids, "use_cache": use_cache, - "attention_mask": None, + "attention_mask": attention_mask, + "position_ids": position_ids, } @staticmethod From b944f25c7fb1f5638d80cef8219588898281c035 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 27 Oct 2023 16:18:19 +0000 Subject: [PATCH 25/25] get spec dec batch working --- src/transformers/generation/utils.py | 42 ++++++++++++------- .../models/whisper/modeling_whisper.py | 9 +++- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index dabf801e2d90..363c6cd52186 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4721,15 +4721,15 @@ def _crop_past_key_values(model, past_key_values, maximum_length=None, n_matches k_cache = past_key_values[idx][0][:, :, left_cut:, :] v_cache = past_key_values[idx][1][:, :, left_cut:, :] - if n_matches is not None: - for batch_idx in range(len(n_matches)): - num_roll_left = n_matches.max() - n_matches[batch_idx] - if num_roll_left > 0: - # TODO(PVP) - check mem usage - # k_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=k_cache.device), k_cache[batch_idx][:, :-num_roll_left].clone()) - # v_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=v_cache.device), v_cache[batch_idx][:, :-num_roll_left].clone()) - k_cache[batch_idx][:, num_roll_left:] = k_cache[batch_idx][:, :-num_roll_left].clone() - v_cache[batch_idx][:, num_roll_left:] = v_cache[batch_idx][:, :-num_roll_left].clone() + if n_matches is not None: + for batch_idx in range(len(n_matches)): + num_roll_left = n_matches.max() - n_matches[batch_idx] + if num_roll_left > 0: + # TODO(PVP) - check mem usage + # k_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=k_cache.device), k_cache[batch_idx][:, :-num_roll_left].clone()) + # v_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=v_cache.device), v_cache[batch_idx][:, :-num_roll_left].clone()) + k_cache[batch_idx][:, num_roll_left:] = k_cache[batch_idx][:, :-num_roll_left].clone() + v_cache[batch_idx][:, num_roll_left:] = v_cache[batch_idx][:, :-num_roll_left].clone() new_past.append( ( @@ -4764,12 +4764,24 @@ def _crop_past_key_values(model, past_key_values, maximum_length=None, n_matches past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] else: for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], - ) - ) + if left_cut is None: + k_cache = past_key_values[idx][0][:, :, :maximum_length, :] + v_cache = past_key_values[idx][1][:, :, :maximum_length, :] + else: + k_cache = past_key_values[idx][0][:, :, left_cut:, :] + v_cache = past_key_values[idx][1][:, :, left_cut:, :] + + if n_matches is not None: + for batch_idx in range(len(n_matches)): + num_roll_left = n_matches.max() - n_matches[batch_idx] + if num_roll_left > 0: + # TODO(PVP) - check mem usage + # k_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=k_cache.device), k_cache[batch_idx][:, :-num_roll_left].clone()) + # v_cache[batch_idx].index_copy_(1, torch.arange(num_roll_left, maximum_length, device=v_cache.device), v_cache[batch_idx][:, :-num_roll_left].clone()) + k_cache[batch_idx][:, num_roll_left:] = k_cache[batch_idx][:, :-num_roll_left].clone() + v_cache[batch_idx][:, num_roll_left:] = v_cache[batch_idx][:, :-num_roll_left].clone() + + new_past.append((k_cache, v_cache)) past_key_values = tuple(new_past) return past_key_values diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f37a455a75b5..dbc7c6caa994 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1537,7 +1537,10 @@ def forward( else: positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids) - hidden_states = inputs_embeds + positions + try: + hidden_states = inputs_embeds + positions + except: + import ipdb; ipdb.set_trace() hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) if self.gradient_checkpointing and self.training: @@ -2480,6 +2483,7 @@ def forward( cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -2589,6 +2593,7 @@ def forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + position_ids=position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -2639,7 +2644,9 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] if position_ids is not None and position_ids.shape[1] > position_ids.shape[1]: + print("pos len", position_ids.shape) position_ids = position_ids[:, remove_prefix_length:] + print("pos len", position_ids.shape) return { "encoder_outputs": encoder_outputs,