
Conversation

@huseinzol05
Contributor

@huseinzol05 huseinzol05 commented May 11, 2024

Static Cache for Whisper

This enables using torch.compile for Whisper generation for faster decoding; example: https://gist.github.com/huseinzol05/9aff34ec1427ee8c92240cb4f3cc0c88

With the compiled static cache, generation reaches 186.26 it/s, while the non-compiled version reaches 150.20 it/s.

Still a work in progress:

  1. The current fork only works with the static cache; it needs to follow the same caching steps as Llama.
  2. There are many conditions that need to be fulfilled first.
  3. It only works on PyTorch 2.4.0.dev20240508+cu121 (not yet released as stable), which is needed for reduce-overhead torch.compile of a custom function.
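
For readers who want the overall shape of the workflow, here is a minimal sketch of static-cache + torch.compile generation for Whisper. It assumes a transformers build in which Whisper's `generate` supports `cache_implementation="static"` (the follow-up work mentioned at the end of this thread); this fork's own API differs and is shown in the gist linked above. Checkpoint name and sizes are illustrative.

```python
# Minimal sketch (not this PR's exact API): Whisper generation with a static
# decoder cache and a compiled forward pass.
import numpy as np
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

device = "cuda"
processor = AutoProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small", torch_dtype=torch.float16
).to(device)

# Pre-allocate decoder k/v buffers so their shapes stay fixed across decoding steps.
model.generation_config.cache_implementation = "static"
model.generation_config.max_new_tokens = 128

# Static shapes are what allow fullgraph / reduce-overhead compilation to pay off.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# 30 s of 16 kHz audio (dummy here), converted to log-mel input features.
audio = np.zeros(16000 * 30, dtype=np.float32)
features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
features = features.to(device, dtype=torch.float16)

# The first calls are slow (compilation / graph capture); later calls are fast.
for _ in range(3):
    generated = model.generate(features)
print(processor.batch_decode(generated, skip_special_tokens=True))
```

The first few calls are slow because of graph capture; see the warmup note further down in the thread.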

@mobicham
Contributor

Thank you very much @huseinzol05 for the work.
Here's a version with HQQ 4-bit using the torchao backend. As expected there's a good speed-up with the static cache and fullgraph compilation: https://gist.github.com/mobicham/ecfe09a48efb11e4014386901a5c6cce

GPU: 4090

```
orig - no compile : 48 it/sec
orig + compiled   : 227 it/sec

hqq - no compile  : 42 it/sec
hqq + compile     : 308 it/sec
```
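
For context on the HQQ side: the exact torchao-backend setup is in the gist above; what follows is only a hedged sketch of the transformers-side path for loading Whisper with 4-bit HQQ weights (via `HqqConfig`, which requires the `hqq` package) and then applying the same static-cache + compile recipe. The checkpoint name, nbits, and group_size are illustrative, not the exact configuration from the gist.

```python
# Hedged sketch: load Whisper with 4-bit HQQ weight quantization, then reuse the
# static-cache + torch.compile recipe. Requires the `hqq` package to be installed.
import torch
from transformers import AutoProcessor, HqqConfig, WhisperForConditionalGeneration

quant_config = HqqConfig(nbits=4, group_size=64)   # illustrative settings

processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3",
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)

# Same recipe as above: static decoder cache plus fullgraph compilation.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
```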

@kadirnar
Contributor

Will it be merged? @younesbelkada

@amyeroberts
Contributor

cc @sanchit-gandhi

@huseinzol05
Contributor Author

@kadirnar, this PR is not ready to merge, but you can continue to work on it to fulfill points 1, 2, and 3 above. If you want to use it as-is, you have to split the audio into 30-second chunks with overlap and feed them into the encoder-decoder process; feel free to add temperature and top_k sampling as in https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L52
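
To make the "split the audio into 30-second chunks with overlap" step concrete, here is a small standalone helper (not part of this PR); the chunk and overlap lengths are the only knobs, and each returned window can be fed through the encoder-decoder separately.

```python
# Hedged helper (not part of this PR): split a 16 kHz waveform into 30 s windows
# with a fixed overlap, so each window can be transcribed independently.
import numpy as np

def chunk_audio(audio: np.ndarray, sr: int = 16000,
                chunk_s: float = 30.0, overlap_s: float = 5.0):
    chunk = int(chunk_s * sr)
    step = chunk - int(overlap_s * sr)
    chunks = []
    for start in range(0, max(len(audio) - 1, 1), step):
        piece = audio[start:start + chunk]
        if len(piece) == 0:
            break
        chunks.append(piece)
        if start + chunk >= len(audio):
            break
    return chunks
```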

@kadirnar
Contributor

kadirnar commented May 16, 2024

Will it work if I run this code? Also, should I make any changes to the gpt-fast library?

https://gist.github.com/huseinzol05/9aff34ec1427ee8c92240cb4f3cc0c88

@huseinzol05
Contributor Author

Yeah, it should work; I use it in production. But don't forget to warm up the static cache multiple times first.
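
To make "warm up the static cache multiple times" concrete: run a few throwaway generations on input shaped like real traffic before timing or serving, so compilation, graph capture, and the pre-allocated cache buffers are all in place. A hedged sketch, reusing the `model` and `processor` from the sketch under the PR description above:

```python
# Hedged warmup sketch: the first few calls trigger compilation and graph capture,
# so run them ahead of time on dummy audio shaped like production traffic.
import numpy as np
import torch

warmup_audio = np.zeros(16000 * 30, dtype=np.float32)       # 30 s of silence
warmup_features = processor(
    warmup_audio, sampling_rate=16000, return_tensors="pt"
).input_features.to("cuda", dtype=torch.float16)

for _ in range(4):                      # a handful of passes is usually enough
    with torch.no_grad():
        model.generate(warmup_features)
torch.cuda.synchronize()                # make sure the warmup work has finished
```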

@kadirnar
Contributor

> Yeah, it should work; I use it in production. But don't forget to warm up the static cache multiple times first.

I ran the notebook file. It gives this error:

```
File /usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py:484, in WhisperStaticCache.__init__(self, config, dtype, device, existing_cache, batch_size)
    482 torch._dynamo.mark_static_address(e_key_cache)
    483 torch._dynamo.mark_static_address(e_value_cache)
--> 484 e_key_cache[:, :, :, :] = existing_cache[k][2].clone()
    485 e_value_cache[:, :, :, :] = existing_cache[k][3].clone()
    486 self.key_cache.append(new_layer_key_cache)

IndexError: tuple index out of range
```

@huseinzol05
Contributor Author

I just reran it and had no issue, super weird. Which line gives you the error?

Code context for the review thread below (from the proposed `WhisperStaticCache`):

```python
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

class WhisperStaticCache(Cache):
```

Contributor

Thanks for this great first start @huseinzol05! With @gante, we were discussing how the design of the static k/v cache should look for encoder-decoder models, and we distilled the design options down to two possibilities:

  1. Hold a tuple of StaticCache caches, e.g. as proposed here
  2. Add a new Cache class specific to encoder-decoder models, e.g. one with the attributes:
    • key_cache (same as decoder-only self-attn)
    • value_cache (same as decoder-only self-attn)
    • cross_key_cache (new for enc-dec cross-attn)
    • cross_value_cache (new for enc-dec cross-attn)

Option 1 doesn't require any new Cache classes, so it should be easier to maintain! Thus, we were thinking this would be the best design option for Whisper (and other encoder-decoder models in the library, such as BART). Would be curious to hear your opinions here, having had a go at option 2!
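
To make the two layouts easier to picture, here is a schematic sketch of option 2's attribute layout, using the attribute names from the list above; the class name, constructor, and shapes are illustrative only and are not part of this PR or the transformers API.

```python
# Schematic sketch of option 2 above: a single cache object that owns both the
# decoder self-attention buffers and the encoder-decoder cross-attention buffers.
import torch

class EncoderDecoderStaticCacheSketch:
    def __init__(self, num_layers, batch_size, num_heads, self_attn_len,
                 cross_attn_len, head_dim, dtype=torch.float16, device="cpu"):
        def buffers(length):
            return [
                torch.zeros(batch_size, num_heads, length, head_dim,
                            dtype=dtype, device=device)
                for _ in range(num_layers)
            ]
        # Same roles as the decoder-only static cache:
        self.key_cache = buffers(self_attn_len)
        self.value_cache = buffers(self_attn_len)
        # New for encoder-decoder cross-attention (written once from the encoder output):
        self.cross_key_cache = buffers(cross_attn_len)
        self.cross_value_cache = buffers(cross_attn_len)
```

Option 1 would instead keep two ordinary cache objects, e.g. a `(self_attention_cache, cross_attention_cache)` pair inside `past_key_values`, which is why it composes for free with other cache types such as the quantized cache mentioned in the next comment.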

Contributor

@huseinzol05 this is great work!

I'm heavily biased towards option 1, especially now that we are seeing more cache types. For instance, we could easily plug in the quantized cache as the decoder cache with 0 code overhead, if we design Whisper to support a tuple of Cache objects through past_key_values 🤗

Contributor Author

I'm good with anything.

@Jiltseb

Jiltseb commented Jun 3, 2024

We compared the performance of the torch.compile version with the static cache and its HQQ variants (4, 3, 2, and 1.58 bits) on both short-form audio (open_asr_eval) and long-form audio (an internal test benchmark).

Here is the link to the blog post: https://mobiusml.github.io/whisper-static-cache-blog/
Colab Notebook: https://colab.research.google.com/drive/18Zs-oG1Ztco3cfnNexcHDi-Zn9vk2RJ5?usp=sharing

I think the speech community can benefit a lot from this speed-up once integrated into transformers 🤗 !

@mobicham
Contributor

mobicham commented Jun 6, 2024

Any progress on this, folks? Is there a timeline for general static cache support in transformers? We are very excited to see this officially supported!

@github-actions
Contributor

github-actions bot commented Jul 1, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@kadirnar
Contributor

kadirnar commented Jul 7, 2024

Will you merge this pull request? @sanchit-gandhi

@gante
Contributor

gante commented Jul 15, 2024

Closing this PR: Whisper + compilation involved a few sensible design decisions, as shown in the discussion above, so we took charge of adding static caches to Whisper (PR)

Thank you for kickstarting the process and for the discussion 🤗

@gante gante closed this Jul 15, 2024