
Commit a5bee89

Authored by ylacombe, with sanchit-gandhi and amyeroberts
Add Flash Attention 2 support to Bark (#27364)
* change handmade attention mask to _prepare_4d_attention_mask
* add flashattention2 support in Bark
* add flashattention2 tests on BarkSemanticModel
* make style
* fix flashattention and tests + make style
* fix memory leak and allow Bark to pass flash attention to sub-models
* make style
* Apply suggestions from code review
  Co-authored-by: Sanchit Gandhi <[email protected]>
* remove unnecessary code from tests + justify overriding
* Update tests/models/bark/test_modeling_bark.py
  Co-authored-by: amyeroberts <[email protected]>
* make style

---------

Co-authored-by: Sanchit Gandhi <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
1 parent ef71673 commit a5bee89
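For context, a minimal usage sketch (editorial, not part of this commit) of how Bark could be loaded with Flash Attention 2 after this change. It assumes a CUDA GPU, the flash-attn package, and the `use_flash_attention_2` loading flag exposed by transformers at the time; the checkpoint name and prompt are placeholders.

import torch
from transformers import AutoProcessor, BarkModel

# "suno/bark-small" is used here only as an example checkpoint.
processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained(
    "suno/bark-small",
    torch_dtype=torch.float16,      # Flash Attention 2 requires half precision
    use_flash_attention_2=True,     # routes the Bark sub-models to BarkSelfFlashAttention2
).to("cuda")

inputs = processor("Hello, this is a test.", return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
audio = model.generate(**inputs)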

File tree

2 files changed (+355, -20 lines)


src/transformers/models/bark/modeling_bark.py

Lines changed: 236 additions & 20 deletions
@@ -26,12 +26,14 @@
     BarkEosPrioritizerLogitsProcessor,
     SuppressTokensLogitsProcessor,
 )
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel, get_parameter_device
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_accelerate_available,
+    is_flash_attn_2_available,
     logging,
 )
 from ..auto import AutoModel
@@ -49,6 +51,11 @@
 )


+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
 logger = logging.get_logger(__name__)


@@ -62,6 +69,19 @@
 ]


+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 class BarkSelfAttention(nn.Module):
     # adapted from GPTNeoSelfAttention and Bark code
     # BarkSelfAttention can have two attention type, i.e full attention or causal attention
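As a side note (not part of the diff), the `_get_unpad_data` helper added above derives the bookkeeping that flash-attn's varlen kernel needs from a 2D padding mask. A standalone toy trace of the same operations:

import torch
import torch.nn.functional as F

# Toy padding mask: 1 = real token, 0 = padding (two sequences of lengths 3 and 2).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 2, 4, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                          # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])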
@@ -187,6 +207,177 @@ def forward(
         return outputs


+class BarkSelfFlashAttention2(BarkSelfAttention):
+    """
+    Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features)
+        return tensor
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        # re-assemble all head outputs side by side
+        # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
+        tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
+        return tensor
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        past_key_values=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        batch_size, query_len, _ = hidden_states.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if past_key_values is not None:
+            # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
+            past_key = past_key_values[0].transpose(1, 2)
+            past_value = past_key_values[1].transpose(1, 2)
+            # and merge on seq_length
+            key = torch.cat((past_key, key), dim=1)
+            value = torch.cat((past_value, value), dim=1)
+
+        if use_cache is True:
+            # (batch, head, seq_length, head_features)
+            present = (key.transpose(1, 2), value.transpose(1, 2))
+        else:
+            present = None
+
+        attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            attn_weights = None
+            outputs += (attn_weights,)
+
+        return outputs
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=self.is_causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+            )
+
+        return attn_output
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+BARK_ATTENTION_CLASSES = {
+    "default": BarkSelfAttention,
+    "flash_attention_2": BarkSelfFlashAttention2,
+}
+
+
 class BarkLayerNorm(nn.Module):
     """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""

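One detail worth flagging (editorial note, not part of the diff): unlike the parent class, the `_split_heads` override above keeps the head dimension in position 2, because `flash_attn_func` and `flash_attn_varlen_func` expect inputs of shape `(batch, seq_len, num_heads, head_dim)`. A shape-only sketch with made-up sizes:

import torch

batch, seq_len, num_heads, head_dim = 2, 8, 4, 16
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

fa2_layout = hidden.view(batch, seq_len, num_heads, head_dim)  # layout used by BarkSelfFlashAttention2
default_layout = fa2_layout.transpose(1, 2)                    # (batch, num_heads, seq_len, head_dim), used by BarkSelfAttention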
@@ -229,7 +420,8 @@ def __init__(self, config, is_causal=False):
         self.layernorm_1 = nn.LayerNorm(config.hidden_size)
         self.layernorm_2 = nn.LayerNorm(config.hidden_size)

-        self.attn = BarkSelfAttention(config, is_causal=is_causal)
+        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
+        self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)

         self.mlp = BarkMLP(config)

@@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel):

     config_class = BarkConfig
     supports_gradient_checkpointing = False
+    _supports_flash_attn_2 = True

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -596,21 +789,13 @@ def forward(
         if attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            # We create a 3D attention mask from a 2D tensor mask.
-            # Sizes are [batch_size, 1, 1, to_seq_length]
-            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-            # this attention mask is more simple than the triangular masking of causal attention
-            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, None, None, :]
-
-            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-            # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and the dtype's smallest value for masked positions.
-            # Since we are adding it to the raw scores before the softmax, this is
-            # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                attention_mask = attention_mask.view(batch_size, -1)
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
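For reference (not part of the diff), `_prepare_4d_attention_mask` from `transformers.modeling_attn_mask_utils` produces the same kind of additive mask as the removed hand-written expansion. A small illustration, assuming the function's public behaviour in transformers at the time:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

mask_2d = torch.tensor([[1, 1, 0]])                                      # (batch_size, to_seq_length)
mask_4d = _prepare_4d_attention_mask(mask_2d, torch.float16, tgt_len=1)  # shape (1, 1, 1, 3)
# Attended positions become 0.0 and padded positions the dtype minimum,
# mirroring the removed `(1.0 - mask) * torch.finfo(dtype).min` expansion.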
@@ -1233,10 +1418,12 @@ def forward(
         if attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

         head_mask = self.get_head_mask(head_mask, self.config.num_layers)

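Likewise (editorial note), the `attention_mask if 0 in attention_mask else None` shortcut used in both Flash Attention branches drops the mask entirely when no padding is present, so the dense `flash_attn_func` path can run instead of the unpad/varlen path. A toy check:

import torch

full = torch.ones(2, 4, dtype=torch.long)   # no padding anywhere
padded = torch.tensor([[1, 1, 1, 0],
                       [1, 1, 0, 0]])

print(full if 0 in full else None)      # None -> dense flash_attn_func path
print(padded if 0 in padded else None)  # mask kept -> unpad + flash_attn_varlen_func path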
@@ -1669,3 +1856,32 @@ def generate(
             return audio, output_lengths

         return audio
+
+    @classmethod
+    def _check_and_enable_flash_attn_2(
+        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
+    ):
+        """
+        `_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model
+        sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention
+        if necessary.
+
+        If you don't know about Flash Attention, check out the official repository of flash attention:
+        https://github.com/Dao-AILab/flash-attention
+
+        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
+        specific section of the documentation to learn more about it:
+        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+
+        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
+        half precision and not ran on CPU.
+
+        If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
+        can initialize the correct attention module
+        """
+        config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)
+
+        config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        return config
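Finally, a hedged sanity-check sketch (editorial, not part of this commit) of the sub-config propagation performed by the override above. It assumes a machine where Flash Attention 2 can actually be enabled (CUDA GPU, flash-attn installed) and reuses the example checkpoint from earlier:

import torch
from transformers import BarkModel

model = BarkModel.from_pretrained(
    "suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True
)
# Each Bark sub-config should now carry the flag, so every sub-model builds BarkSelfFlashAttention2 layers.
assert getattr(model.config.semantic_config, "_flash_attn_2_enabled", False)
assert getattr(model.config.coarse_acoustics_config, "_flash_attn_2_enabled", False)
assert getattr(model.config.fine_acoustics_config, "_flash_attn_2_enabled", False)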
