
Commit a970195

sanchit-gandhi, gante, and ArthurZucker authored
[whisper] static kv cache (#31166)
* make work with cache abstraction
* correct for static cache
* hacks for compile
* make fast
* fix
* fix pos ids
* generate
* fix sdpa
* fix sdpa cache pos
* fix fa2
* clean fa2
* integrate cache into generate
* make style
* copies
* more copies
* update eager
* update sdpa
* update fa2
* simplify
* use cache pos
* always compute cross-cache for debug
* avoid recompiles

Co-authored-by: Arthur Zucker <[email protected]>

* fix fix
* fix fix fix
* more fix
* try encoder-decoder cache (too messy)
* revert encoder-decoder cache
* check cross-attn cache
* use enc-dec dataclass
* use richer enc-dec dataclass
* clean-up
* revert static cache changes
* small fixes
* revert to cpu flag
* fix copies
* add static slow test
* past k/v docstring
* more docstrings
* cache_position docstrings
* add to docs
* add enc-dec cache to docs
* make style
* fix after rebase
* fix beam
* style
* fix generation strategies
* fix most decoder-only tests
* style
* skip test
* more clean up
* small docstrings
* Apply suggestions from code review

Co-authored-by: Joao Gante <[email protected]>

* add todo
* only crop self-attn
* check cache in mixin
* style
* fix re-compile after rebase
* move `is_updated` logic to enc-dec wrapper
* revert back
* revert cache back
* finalise design
* fix
* fix fix
* style
* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* deprecate
* updates
* final updates
* style
* style

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent 57d7594 commit a970195

File tree: 10 files changed, +705 -258 lines changed


docs/source/en/internal/generation_utils.md

Lines changed: 6 additions & 0 deletions
@@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
     - get_seq_length
     - reset
 
+[[autodoc]] EncoderDecoderCache
+    - get_seq_length
+    - to_legacy_cache
+    - from_legacy_cache
+    - reset
+    - reorder_cache
 
 ## Watermark Utils
 

docs/source/en/model_doc/whisper.md

Lines changed: 44 additions & 3 deletions
@@ -52,16 +52,14 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
 >>> # Select an audio file and read it:
 >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
 >>> audio_sample = ds[0]["audio"]
->>> waveform = audio_sample["array"]
->>> sampling_rate = audio_sample["sampling_rate"]
 
 >>> # Load the Whisper model in Hugging Face format:
 >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
 >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
 
 >>> # Use the model and processor to transcribe the audio:
 >>> input_features = processor(
-...     waveform, sampling_rate=sampling_rate, return_tensors="pt"
+...     audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
 ... ).input_features
 
 >>> # Generate token ids
@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
 ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
 ```
 
+Whisper is compatible with the following optimisations:
+- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
+- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
+- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
+
+As an example, the following code snippet enables SDPA and `torch.compile` for up to 5x faster inference:
+
+```python
+>>> import torch
+>>> from datasets import load_dataset
+>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+>>> # Select an audio file and read it:
+>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+>>> audio_sample = ds[0]["audio"]
+
+>>> # Load the Whisper model with SDPA attention
+>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
+
+>>> # Enable static cache and compile the forward pass
+>>> model.generation_config.cache_implementation = "static"
+>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+>>> # Use the model and processor to transcribe the audio:
+>>> input_features = processor(
+...     audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
+... ).input_features
+
+>>> # Compile the forward pass (the first call is slow)
+>>> _ = model.generate(input_features)
+
+>>> # Generate token ids using the compiled graph (fast!)
+>>> predicted_ids = model.generate(input_features)
+
+>>> # Decode token ids to text
+>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+>>> transcription[0]
+' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+```
+
+For more details on each optimisation, refer to the documentation linked above.
+
 ## Resources
 
 A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
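
The new documentation above pairs SDPA with `torch.compile`. For the Flash Attention 2 option it also lists, loading would look roughly like the sketch below; this is an assumption based on the generic `attn_implementation` argument rather than code from this commit, and it needs the `flash-attn` package, a supported CUDA GPU, and half-precision weights.

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")

# Flash Attention 2 only runs on CUDA and with fp16/bf16 weights
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
```

SDPA needs no extra dependency (it ships with `torch>=2.1.1`), which is why the documented example uses it as the default fast path.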

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1212,6 +1212,7 @@
         "Cache",
         "CacheConfig",
         "DynamicCache",
+        "EncoderDecoderCache",
         "HQQQuantizedCache",
         "QuantizedCache",
         "QuantizedCacheConfig",
@@ -5895,6 +5896,7 @@
         Cache,
         CacheConfig,
         DynamicCache,
+        EncoderDecoderCache,
         HQQQuantizedCache,
         QuantizedCache,
         QuantizedCacheConfig,

src/transformers/cache_utils.py

Lines changed: 158 additions & 2 deletions
@@ -858,8 +858,12 @@ def update(
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
 
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
+        if cache_position is None:
+            k_out.copy_(key_states)
+            v_out.copy_(value_states)
+        else:
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
 
         return k_out, v_out
 
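The change above adds a `cache_position is None` branch to the static cache update. A rough illustration with plain tensors of what each branch does follows; the shapes and the cross-attention motivation are assumptions for illustration, not code from the commit.

```python
import torch

# Assumed static cache layout: (batch, num_heads, max_cache_len, head_dim)
batch, num_heads, max_cache_len, head_dim = 1, 2, 8, 4
k_out = torch.zeros(batch, num_heads, max_cache_len, head_dim)

# With cache_position: scatter the new key states into the selected slots
# (the pre-existing behaviour, e.g. for decoder self-attention).
new_keys = torch.ones(batch, num_heads, 2, head_dim)
cache_position = torch.tensor([3, 4])
k_out[:, :, cache_position] = new_keys

# Without cache_position: the incoming states already span the whole buffer,
# e.g. a cross-attention cache computed once from the encoder outputs,
# so the buffer is overwritten in place.
full_keys = torch.randn(batch, num_heads, max_cache_len, head_dim)
k_out.copy_(full_keys)
```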
@@ -971,6 +975,158 @@ def get_max_length(self) -> Optional[int]:
         # no matter how long the sentence is
         return None
 
+    def reset(self):
+        self.key_cache.zero_()
+        self.value_cache.zero_()
+
+
+class EncoderDecoderCache(Cache):
+    """
+    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
+    cross-attention caches.
+    """
+
+    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+        self.self_attention_cache = self_attention_cache
+        self.cross_attention_cache = cross_attention_cache
+
+        self.is_updated = {}
+        for layer_idx in range(len(cross_attention_cache.key_cache)):
+            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (
+                self.self_attention_cache.key_cache[layer_idx],
+                self.self_attention_cache.value_cache[layer_idx],
+                self.cross_attention_cache.key_cache[layer_idx],
+                self.cross_attention_cache.value_cache[layer_idx],
+            )
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.self_attention_cache)
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        if len(self.cross_attention_cache) > 0:
+            for self_attn, cross_attn in zip(
+                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
+            ):
+                legacy_cache += (self_attn + cross_attn,)
+        else:
+            legacy_cache = self.self_attention_cache.to_legacy_cache()
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    ) -> "EncoderDecoderCache":
+        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx][:2]
+                cache.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(past_key_values[layer_idx]) > 2:
+                    key_states, value_states = past_key_values[layer_idx][2:]
+                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
+                    cache.is_updated[layer_idx] = True
+        return cache
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.self_attention_cache.key_cache) <= layer_idx:
+            return 0
+        return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+
+    def reset(self):
+        if hasattr(self.self_attention_cache, "reset"):
+            self.self_attention_cache.reset()
+        if hasattr(self.cross_attention_cache, "reset"):
+            self.cross_attention_cache.reset()
+        elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
+            raise ValueError(
+                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
+                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
+                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
+                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
+            )
+        for layer_idx in self.is_updated:
+            self.is_updated[layer_idx] = False
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        self.self_attention_cache.reorder_cache(beam_idx)
+        self.cross_attention_cache.reorder_cache(beam_idx)
+
+    def check_dynamic_cache(self, method: str):
+        if not (
+            isinstance(self.self_attention_cache, DynamicCache)
+            and isinstance(self.cross_attention_cache, DynamicCache)
+        ):
+            raise ValueError(
+                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
+                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
+            )
+
+    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
+    def crop(self, maximum_length: int):
+        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
+        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+        self.check_dynamic_cache(self.crop.__name__)
+        self.self_attention_cache.crop(maximum_length)
+
+    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
+        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+        `_split_model_inputs()` in `generation.utils`"""
+        self.check_dynamic_cache(self.batch_split.__name__)
+        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
+        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
+
+        out = []
+        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
+            out.append(EncoderDecoderCache(self_attn, cross_attn))
+        return out
+
+    @classmethod
+    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
+        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
+        `generation.utils`"""
+        self_attention_cache = DynamicCache()
+        cross_attention_cache = DynamicCache()
+        for idx in range(len(splits[0])):
+            layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
+            layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
+            self_attention_cache.update(layer_keys, layer_values, idx)
+
+            layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
+            layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
+            cross_attention_cache.update(layer_keys, layer_values, idx)
+        return cls(self_attention_cache, cross_attention_cache)
+
+    def batch_repeat_interleave(self, repeats: int):
+        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
+        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
+        self.self_attention_cache.batch_repeat_interleave(repeats)
+        self.cross_attention_cache.batch_repeat_interleave(repeats)
+
+    def batch_select_indices(self, indices: torch.Tensor):
+        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
+        self.check_dynamic_cache(self.batch_select_indices.__name__)
+        self.self_attention_cache.batch_select_indices(indices)
+        self.cross_attention_cache.batch_select_indices(indices)
+
 
 class HybridCache(Cache):
     def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
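
To make the new class concrete, here is a small usage sketch; it is not part of the commit and assumes a `transformers` build that already contains the `EncoderDecoderCache` export added in `__init__.py` above. The dummy tensor shapes are arbitrary.

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

batch, heads, seq_len, head_dim = 1, 8, 4, 64  # arbitrary dummy dimensions

# One cache for decoder self-attention, one for cross-attention
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
for layer_cache in (self_attention_cache, cross_attention_cache):
    layer_cache.update(
        torch.randn(batch, heads, seq_len, head_dim),  # key states for layer 0
        torch.randn(batch, heads, seq_len, head_dim),  # value states for layer 0
        layer_idx=0,
    )

cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
print(len(cache))                     # 1 layer
print(cache.is_updated[0])            # True: cross-attention cache already filled
print(int(cache.get_seq_length(0)))   # 4 cached self-attention positions

# Round-trip through the legacy format: one 4-tuple per layer
# (self-attn key, self-attn value, cross-attn key, cross-attn value)
legacy = cache.to_legacy_cache()
print(len(legacy), len(legacy[0]))    # 1 4
restored = EncoderDecoderCache.from_legacy_cache(legacy)
print(restored.is_updated[0])         # True
```

Whisper's static-cache path builds the same kind of wrapper around static self- and cross-attention caches when `model.generation_config.cache_implementation = "static"` is set, as in the documentation example above.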

0 commit comments
