@ylacombe
Contributor

@ylacombe ylacombe commented Nov 8, 2023

What does this PR do?

Following a recent series of PRs and issues to improve Bark, this PR adds Flash Attention 2 (FA2) support to Bark. The Bark self-attention class supports both causal and non-causal attention, but otherwise the changes are minimal.

I've also taken the opportunity to switch to _prepare_4d_attention_mask instead of manually creating the 4d attention mask.
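
For context, here is a minimal sketch of what that switch looks like (illustrative only, not the exact Bark code; the 2D attention_mask below is a made-up example):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# 2D padding mask of shape (batch, seq_len): 1 = attend, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Expands the 2D mask to the additive (batch, 1, tgt_len, src_len) form,
# with 0.0 at attended positions and dtype-min at masked ones, instead of
# building that tensor by hand in the modeling code.
expanded_mask = _prepare_4d_attention_mask(attention_mask, dtype=torch.float32)
print(expanded_mask.shape)  # torch.Size([1, 1, 4, 4])
```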

Benchmarks are currently running to measure the speed/memory gains!

cc @sanchit-gandhi and @amyeroberts

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Nov 8, 2023

The documentation is not available anymore as the PR was closed or merged.

Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Very clean - thanks @ylacombe for adding this! Keen to see what kind of performance gain we get from this

cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
):
"""
If you don't know about Flash Attention, check out the official repository of flash attention:
Contributor

Could be worth explaining quickly why we override this method in the docstring!
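
For reference, a hedged sketch of what such an explanation could look like, based on the commit message "allow Bark to pass flash attention to sub-models"; the exact wording is an assumption, not the merged docstring:

```python
"""
Overridden for Bark because the model is composed of sub-models (semantic, coarse, fine), each
carrying its own sub-config: the base PreTrainedModel._check_and_enable_flash_attn_2 only enables
Flash Attention 2 on the top-level config, so this override also propagates it to the sub-configs.

If you don't know about Flash Attention, check out the official repository of flash attention:
...
"""
```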


dummy_attention_mask = inputs_dict.get("attention_mask", None)

if dummy_attention_mask is not None:
Contributor

What's the motivation behind overriding the attention mask here?

Contributor Author

Making sure that at least one of the input ids is masked!
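
A hedged illustration of the kind of override being discussed (the exact slicing in the shared test suite may differ; inputs_dict comes from the test's prepared inputs):

```python
dummy_attention_mask = inputs_dict.get("attention_mask", None)

if dummy_attention_mask is not None:
    dummy_attention_mask = dummy_attention_mask.clone()
    # Force at least one masked position so the flash-attention path is
    # exercised with real padding rather than an all-ones mask.
    dummy_attention_mask[:, :1] = 0
    dummy_attention_mask[:, 1:] = 1
```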


logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
Contributor

We know bark is not encoder-decoder -> could we simplify the tests to reflect this?

Contributor Author

nice catch!

else outputs_fa.decoder_hidden_states[-1]
)

assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
Contributor

Pretty high tolerance! We've compared the audio outputs qualitatively with / without flash attention and they match?

Contributor Author

Thanks for the quick review! I've actually copied and modified a test from the general suite, so I haven't changed anything -> the tolerance and the attention mask overriding are the same as in the original test!

Contributor

I had the same comment on tolerance for FA2 tests :D 0.04 was agreed as being acceptable
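
For what it's worth, a hedged sketch of the decoder-only simplification suggested above (variable names follow the snippet under review; this is not the merged test code):

```python
# Bark's sub-models are decoder-only, so the encoder-decoder branch can be dropped:
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]

# 4e-2 is the tolerance agreed on for the shared FA2 equivalence tests.
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
```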

@ylacombe
Contributor Author

ylacombe commented Nov 8, 2023

Thanks for the quick review, I've addressed your comments 🤗

Contributor

@amyeroberts amyeroberts left a comment

Very nice - thanks for adding!


@ylacombe
Contributor Author

ylacombe commented Nov 8, 2023

Merging! Thanks for the quick reviews!

@ylacombe ylacombe merged commit a5bee89 into huggingface:main Nov 8, 2023
Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks for adding this! I usually also request adding a section in the README, and updating the list of models that support flash attention here and in the README, like this change.

else:
present = None

attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
Collaborator

Here self.dropout is a module, not a float. The docstring of _flash_attention_forward does not match and is not restrictive enough.

Collaborator

It might work but I'd rather we standardize!
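
A hedged sketch of the standardization being asked for, assuming self.dropout is an nn.Dropout module whose probability lives in .p (attribute names here are assumptions, not the merged fix):

```python
# _flash_attention_forward expects a float dropout probability, so pass the
# module's probability rather than the module itself, and disable dropout
# outside of training.
attn_dropout = self.dropout.p if self.training else 0.0

attn_output = self._flash_attention_forward(
    query, key, value, attention_mask, query_len, dropout=attn_dropout
)
```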


@ylacombe ylacombe mentioned this pull request Nov 9, 2023
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* change handmade attention mask to _prepare_4d_attention_mask

* add flashattention2 support in Bark

* add flashattention2 tests on BarkSemanticModel

* make style

* fix flashattention and tests + make style

* fix memory leak and allow Bark to pass flash attention to sub-models

* make style

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <[email protected]>

* remove unecessary code from tests + justify overriding

* Update tests/models/bark/test_modeling_bark.py

Co-authored-by: amyeroberts <[email protected]>

* make style

---------

Co-authored-by: Sanchit Gandhi <[email protected]>
Co-authored-by: amyeroberts <[email protected]>