26 | 26 | BarkEosPrioritizerLogitsProcessor, |
27 | 27 | SuppressTokensLogitsProcessor, |
28 | 28 | ) |
| 29 | +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask |
29 | 30 | from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput |
30 | 31 | from ...modeling_utils import PreTrainedModel, get_parameter_device |
31 | 32 | from ...utils import ( |
32 | 33 | add_start_docstrings, |
33 | 34 | add_start_docstrings_to_model_forward, |
34 | 35 | is_accelerate_available, |
| 36 | + is_flash_attn_2_available, |
35 | 37 | logging, |
36 | 38 | ) |
37 | 39 | from ..auto import AutoModel |
|
49 | 51 | ) |
50 | 52 |
|
51 | 53 |
|
| 54 | +if is_flash_attn_2_available(): |
| 55 | + from flash_attn import flash_attn_func, flash_attn_varlen_func |
| 56 | + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa |
| 57 | + |
| 58 | + |
52 | 59 | logger = logging.get_logger(__name__) |
53 | 60 |
|
54 | 61 |
|
|
62 | 69 | ] |
63 | 70 |
|
64 | 71 |
|
| 72 | +# Copied from transformers.models.llama.modeling_llama._get_unpad_data |
| 73 | +def _get_unpad_data(attention_mask): |
| 74 | + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| 75 | + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| 76 | + max_seqlen_in_batch = seqlens_in_batch.max().item() |
| 77 | + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
| 78 | + return ( |
| 79 | + indices, |
| 80 | + cu_seqlens, |
| 81 | + max_seqlen_in_batch, |
| 82 | + ) |
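For intuition, a minimal sketch of what `_get_unpad_data` computes for a toy padding mask (values shown as comments; illustrative only, not part of the patch):

import torch

# Toy mask: batch of 2, seq_len 4; the second sequence starts with one pad token.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 1, 1, 1]])

indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
# indices    -> tensor([0, 1, 2, 3, 5, 6, 7])          flat positions of the real tokens
# cu_seqlens -> tensor([0, 4, 7], dtype=torch.int32)   cumulative boundaries fed to flash_attn_varlen_func
# max_seqlen -> 4                                      length of the longest real sequence in the batch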
| 83 | + |
| 84 | + |
65 | 85 | class BarkSelfAttention(nn.Module): |
66 | 86 | # adapted from GPTNeoSelfAttention and Bark code |
67 | 87 | # BarkSelfAttention can have two attention types, i.e. full attention or causal attention
@@ -187,6 +207,177 @@ def forward( |
187 | 207 | return outputs |
188 | 208 |
|
189 | 209 |
|
| 210 | +class BarkSelfFlashAttention2(BarkSelfAttention): |
| 211 | + """ |
| 212 | + Bark flash attention module. This module inherits from `BarkSelfAttention`, as the weights of the module stay
| 213 | + untouched. The only change is in the forward pass, which needs to call the public API of flash attention
| 214 | + correctly and deal with padding tokens in case the input contains any of them.
| 215 | + """ |
| 216 | + |
| 217 | + def _split_heads(self, tensor, num_heads, attn_head_size): |
| 218 | + """ |
| 219 | + Splits hidden_size dim into attn_head_size and num_heads |
| 220 | + """ |
| 221 | + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) |
| 222 | + tensor = tensor.view(new_shape) |
| 223 | + # Flash attention requires the input to have the shape |
| 224 | + # batch_size x seq_length x num_heads x head_dim - i.e. (batch, seq_length, head, head_features)
| 225 | + return tensor |
| 226 | + |
| 227 | + def _merge_heads(self, tensor, num_heads, attn_head_size): |
| 228 | + """ |
| 229 | + Merges attn_head_size dim and num_attn_heads dim into hidden_size |
| 230 | + """ |
| 231 | + # re-assemble all head outputs side by side |
| 232 | + # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size) |
| 233 | + tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,)) |
| 234 | + return tensor |
| 235 | + |
| 236 | + def forward( |
| 237 | + self, |
| 238 | + hidden_states, |
| 239 | + attention_mask=None, |
| 240 | + past_key_values=None, |
| 241 | + head_mask=None, |
| 242 | + use_cache=False, |
| 243 | + output_attentions=False, |
| 244 | + ): |
| 245 | + batch_size, query_len, _ = hidden_states.size() |
| 246 | + |
| 247 | + # calculate query, key, values for all heads in batch and move head forward to be the batch dim |
| 248 | + query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2) |
| 249 | + |
| 250 | + query = self._split_heads(query, self.num_heads, self.head_dim) |
| 251 | + key = self._split_heads(key, self.num_heads, self.head_dim) |
| 252 | + value = self._split_heads(value, self.num_heads, self.head_dim) |
| 253 | + |
| 254 | + if past_key_values is not None: |
| 255 | + # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features) |
| 256 | + past_key = past_key_values[0].transpose(1, 2) |
| 257 | + past_value = past_key_values[1].transpose(1, 2) |
| 258 | + # and merge on seq_length |
| 259 | + key = torch.cat((past_key, key), dim=1) |
| 260 | + value = torch.cat((past_value, value), dim=1) |
| 261 | + |
| 262 | + if use_cache is True: |
| 263 | + # (batch, head, seq_length, head_features) |
| 264 | + present = (key.transpose(1, 2), value.transpose(1, 2)) |
| 265 | + else: |
| 266 | + present = None |
| 267 | + |
| 268 | + attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout) |
| 269 | + |
| 270 | + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) |
| 271 | + attn_output = self.out_proj(attn_output) |
| 272 | + attn_output = self.resid_dropout(attn_output) |
| 273 | + |
| 274 | + outputs = (attn_output, present) |
| 275 | + if output_attentions: |
| 276 | + attn_weights = None |
| 277 | + outputs += (attn_weights,) |
| 278 | + |
| 279 | + return outputs |
| 280 | + |
| 281 | + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward |
| 282 | + def _flash_attention_forward( |
| 283 | + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None |
| 284 | + ): |
| 285 | + """ |
| 286 | + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
| 287 | + first unpads the input, then computes the attention scores and re-pads the final attention output.
| 288 | +
|
| 289 | + Args: |
| 290 | + query_states (`torch.Tensor`): |
| 291 | + Input query states to be passed to Flash Attention API |
| 292 | + key_states (`torch.Tensor`): |
| 293 | + Input key states to be passed to Flash Attention API |
| 294 | + value_states (`torch.Tensor`): |
| 295 | + Input value states to be passed to Flash Attention API |
| 296 | + attention_mask (`torch.Tensor`): |
| 297 | + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the |
| 298 | + position of padding tokens and 1 for the position of non-padding tokens. |
| 299 | + dropout (`float`, *optional*):
| 300 | + Attention dropout probability
| 301 | + softmax_scale (`float`, *optional*):
| 302 | + The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
| 303 | + """ |
| 304 | + # Contains at least one padding token in the sequence |
| 305 | + if attention_mask is not None: |
| 306 | + batch_size = query_states.shape[0] |
| 307 | + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( |
| 308 | + query_states, key_states, value_states, attention_mask, query_length |
| 309 | + ) |
| 310 | + |
| 311 | + cu_seqlens_q, cu_seqlens_k = cu_seq_lens |
| 312 | + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens |
| 313 | + |
| 314 | + attn_output_unpad = flash_attn_varlen_func( |
| 315 | + query_states, |
| 316 | + key_states, |
| 317 | + value_states, |
| 318 | + cu_seqlens_q=cu_seqlens_q, |
| 319 | + cu_seqlens_k=cu_seqlens_k, |
| 320 | + max_seqlen_q=max_seqlen_in_batch_q, |
| 321 | + max_seqlen_k=max_seqlen_in_batch_k, |
| 322 | + dropout_p=dropout, |
| 323 | + softmax_scale=softmax_scale, |
| 324 | + causal=self.is_causal, |
| 325 | + ) |
| 326 | + |
| 327 | + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) |
| 328 | + else: |
| 329 | + attn_output = flash_attn_func( |
| 330 | + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal |
| 331 | + ) |
| 332 | + |
| 333 | + return attn_output |
| 334 | + |
| 335 | + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input |
| 336 | + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): |
| 337 | + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) |
| 338 | + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape |
| 339 | + |
| 340 | + key_layer = index_first_axis( |
| 341 | + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k |
| 342 | + ) |
| 343 | + value_layer = index_first_axis( |
| 344 | + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k |
| 345 | + ) |
| 346 | + if query_length == kv_seq_len: |
| 347 | + query_layer = index_first_axis( |
| 348 | + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k |
| 349 | + ) |
| 350 | + cu_seqlens_q = cu_seqlens_k |
| 351 | + max_seqlen_in_batch_q = max_seqlen_in_batch_k |
| 352 | + indices_q = indices_k |
| 353 | + elif query_length == 1: |
| 354 | + max_seqlen_in_batch_q = 1 |
| 355 | + cu_seqlens_q = torch.arange( |
| 356 | + batch_size + 1, dtype=torch.int32, device=query_layer.device |
| 357 | + ) # There is a memcpy here, that is very bad. |
| 358 | + indices_q = cu_seqlens_q[:-1] |
| 359 | + query_layer = query_layer.squeeze(1) |
| 360 | + else: |
| 361 | + # The -q_len: slice assumes left padding. |
| 362 | + attention_mask = attention_mask[:, -query_length:] |
| 363 | + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) |
| 364 | + |
| 365 | + return ( |
| 366 | + query_layer, |
| 367 | + key_layer, |
| 368 | + value_layer, |
| 369 | + indices_q, |
| 370 | + (cu_seqlens_q, cu_seqlens_k), |
| 371 | + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), |
| 372 | + ) |
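A plain-torch sketch of the layout conventions used by this class (illustrative only, no flash-attn dependency). Flash attention consumes `(batch, seq_len, num_heads, head_dim)` tensors, which is why `_split_heads` skips the transpose done by the default attention and why cached keys/values are transposed around the cache in `forward`. The unpad path then packs the real tokens of every sequence into one `(total_tokens, num_heads, head_dim)` tensor, which `pad_input` later scatters back into the padded layout:

import torch

batch, seq_len, num_heads, head_dim = 2, 4, 3, 8
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 1, 1, 1]])
hidden = torch.randn(batch, seq_len, num_heads, head_dim)                # flash-attention layout, no transpose

# unpad: gather the rows of real tokens (what index_first_axis does in _upad_input)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
packed = hidden.reshape(batch * seq_len, num_heads, head_dim)[indices]   # (7, num_heads, head_dim)

# pad: scatter the packed rows back into the padded layout (what pad_input does after the attention call)
restored = hidden.new_zeros(batch * seq_len, num_heads, head_dim)
restored[indices] = packed
restored = restored.view(batch, seq_len, num_heads, head_dim)
assert torch.equal(restored[attention_mask.bool()], hidden[attention_mask.bool()])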
| 373 | + |
| 374 | + |
| 375 | +BARK_ATTENTION_CLASSES = { |
| 376 | + "default": BarkSelfAttention, |
| 377 | + "flash_attention_2": BarkSelfFlashAttention2, |
| 378 | +} |
| 379 | + |
| 380 | + |
190 | 381 | class BarkLayerNorm(nn.Module): |
191 | 382 | """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False.""" |
192 | 383 |
|
@@ -229,7 +420,8 @@ def __init__(self, config, is_causal=False): |
229 | 420 | self.layernorm_1 = nn.LayerNorm(config.hidden_size) |
230 | 421 | self.layernorm_2 = nn.LayerNorm(config.hidden_size) |
231 | 422 |
|
232 | | - self.attn = BarkSelfAttention(config, is_causal=is_causal) |
| 423 | + attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default" |
| 424 | + self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal) |
233 | 425 |
|
234 | 426 | self.mlp = BarkMLP(config) |
235 | 427 |
|
@@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel): |
277 | 469 |
|
278 | 470 | config_class = BarkConfig |
279 | 471 | supports_gradient_checkpointing = False |
| 472 | + _supports_flash_attn_2 = True |
280 | 473 |
|
281 | 474 | def _init_weights(self, module): |
282 | 475 | """Initialize the weights.""" |
@@ -596,21 +789,13 @@ def forward( |
596 | 789 | if attention_mask is not None: |
597 | 790 | if batch_size <= 0: |
598 | 791 | raise ValueError("batch_size has to be defined and > 0") |
599 | | - attention_mask = attention_mask.view(batch_size, -1) |
600 | | - # We create a 3D attention mask from a 2D tensor mask. |
601 | | - # Sizes are [batch_size, 1, 1, to_seq_length] |
602 | | - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] |
603 | | - # this attention mask is more simple than the triangular masking of causal attention |
604 | | - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. |
605 | | - attention_mask = attention_mask[:, None, None, :] |
606 | | - |
607 | | - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for |
608 | | - # masked positions, this operation will create a tensor which is 0.0 for |
609 | | - # positions we want to attend and the dtype's smallest value for masked positions. |
610 | | - # Since we are adding it to the raw scores before the softmax, this is |
611 | | - # effectively the same as removing these entirely. |
612 | | - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility |
613 | | - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min |
| 792 | + if getattr(self.config, "_flash_attn_2_enabled", False): |
| 793 | + attention_mask = attention_mask if 0 in attention_mask else None |
| 794 | + else: |
| 795 | + attention_mask = attention_mask.view(batch_size, -1) |
| 796 | + # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] |
| 797 | + # from_seq_length is 1 to easily broadcast |
| 798 | + attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1) |
614 | 799 |
|
615 | 800 | # Prepare head mask if needed |
616 | 801 | # 1.0 in head_mask indicate we keep the head |
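A quick illustration of the mask handling in the hunk above (a sketch; `_prepare_4d_attention_mask` is the helper imported from `...modeling_attn_mask_utils` at the top of this diff). The eager path expands the 2D padding mask into the additive 4D mask that gets summed with the attention scores, while the flash-attention path forwards the 2D mask only when it actually contains padding:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

mask_2d = torch.tensor([[1, 1, 1, 0]])                                  # one padded position
mask_4d = _prepare_4d_attention_mask(mask_2d, torch.float16, tgt_len=1)
# mask_4d.shape == (1, 1, 1, 4); 0.0 at attended positions, torch.finfo(torch.float16).min at the padded one

# flash-attention path: keep the 2D mask only if something is actually masked
mask_for_flash = mask_2d if 0 in mask_2d else None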
@@ -1233,10 +1418,12 @@ def forward( |
1233 | 1418 | if attention_mask is not None: |
1234 | 1419 | if batch_size <= 0: |
1235 | 1420 | raise ValueError("batch_size has to be defined and > 0") |
1236 | | - attention_mask = attention_mask.view(batch_size, -1) |
1237 | | - attention_mask = attention_mask[:, None, None, :] |
1238 | | - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility |
1239 | | - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min |
| 1421 | + if getattr(self.config, "_flash_attn_2_enabled", False): |
| 1422 | + attention_mask = attention_mask if 0 in attention_mask else None |
| 1423 | + else: |
| 1424 | + # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] |
| 1425 | + # from_seq_length is 1 to easily broadcast |
| 1426 | + attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1) |
1240 | 1427 |
|
1241 | 1428 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) |
1242 | 1429 |
|
@@ -1669,3 +1856,32 @@ def generate( |
1669 | 1856 | return audio, output_lengths |
1670 | 1857 |
|
1671 | 1858 | return audio |
| 1859 | + |
| 1860 | + @classmethod |
| 1861 | + def _check_and_enable_flash_attn_2( |
| 1862 | + cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None |
| 1863 | + ): |
| 1864 | + """ |
| 1865 | + The base `_check_and_enable_flash_attn_2` does not propagate flash attention enabling to the model
| 1866 | + sub-configurations. We override the original method to make sure that the Bark sub-models use Flash Attention
| 1867 | + when it is requested.
| 1868 | +
|
| 1869 | + If you don't know about Flash Attention, check out the official repository of flash attention: |
| 1870 | + https://github.com/Dao-AILab/flash-attention |
| 1871 | +
|
| 1872 | + To use Flash Attention 1.0, go through the `BetterTransformer` API instead; have a look at this
| 1873 | + specific section of the documentation to learn more about it: |
| 1874 | + https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models |
| 1875 | +
|
| 1876 | + The method checks if the current setup is compatible with Flash Attention, as it requires the model to be in
| 1877 | + half precision and not run on CPU.
| 1878 | +
|
| 1879 | + If all checks pass, the method will set the `_flash_attn_2_enabled` attribute on the config so that the model
| 1880 | + can initialize the correct attention module.
| 1881 | + """ |
| 1882 | + config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map) |
| 1883 | + |
| 1884 | + config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) |
| 1885 | + config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) |
| 1886 | + config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False) |
| 1887 | + return config |
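A usage sketch for the feature as a whole (illustrative, not part of the diff). At the time of this change the public switch was the `use_flash_attention_2` argument of `from_pretrained`; later releases expose the same behaviour through `attn_implementation="flash_attention_2"`. Flash attention needs a CUDA device, half precision and the `flash-attn` package, per the checks above:

import torch
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained(
    "suno/bark-small",              # illustrative checkpoint
    torch_dtype=torch.float16,
    use_flash_attention_2=True,     # propagated to the semantic/coarse/fine sub-configs by the override above
).to("cuda")

inputs = processor("Hello, this is a test.").to("cuda")
audio = model.generate(**inputs)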