
Commit a30c865

Cache: new Cache format in decoder-only models (#31421)
* draft bart with new cache
* add cache for decoder-only models
* revert utils
* modify docstring
* revert bart
* minor fixes
* fix copies (not related)
* revert tests
* remove enc-dec related code
* remove bloom
* remove opt (enc-dec)
* update docstring
* git, codegen, gpt_neo, gpt_neox, gpj
* clean up
* copied from statements
* revert
* tmp
* update warning msg
* forgot git
* add more flags
* run-slow git,codegen,gpt_neo,gpt_neox,gpj
* add cache flag to VLMs
* remove files
* style
* video LLMs also need a flag
* style
* llava will go in another PR
* style
* [run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics
* Update src/transformers/models/gpt_neo/modeling_gpt_neo.py

Co-authored-by: Arthur <[email protected]>

* copy from
* deprecate until v4.45 and warn if not training
* nit
* fix test
* test static cache
* add more tests and fix models
* fix copies
* return sliding window mask
* run slow tests & fix + codestyle
* one more falcon fix for alibi

---------

Co-authored-by: Arthur <[email protected]>
1 parent 6af0854 commit a30c865
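
Background: the "new Cache format" means past_key_values is handled as a Cache object (DynamicCache, StaticCache, ...) instead of the legacy tuple of per-layer (key, value) tensor pairs. Below is a minimal usage sketch, not part of the commit, assuming a transformers release that includes this change; the checkpoint name "EleutherAI/gpt-neo-125m" is only an illustrative stand-in for any of the decoder-only models updated here.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

model_id = "EleutherAI/gpt-neo-125m"  # illustrative; any updated decoder-only model behaves similarly
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The new cache format", return_tensors="pt")

# Pass a Cache object explicitly instead of relying on the legacy tuple format.
out = model.generate(
    **inputs,
    max_new_tokens=10,
    past_key_values=DynamicCache(),
    return_dict_in_generate=True,
)
print(type(out.past_key_values))  # DynamicCache, not a tuple of tuples

Legacy callers keep working for now: per the commit message, the tuple format is deprecated with a warning until v4.45 rather than removed outright.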

File tree

11 files changed: +1890 additions, -756 deletions


src/transformers/cache_utils.py

Lines changed: 3 additions & 1 deletion

@@ -1016,7 +1016,9 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
 
         self.dtype = dtype if dtype is not None else torch.float32
         self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
         )
 
         self.key_cache: List[torch.Tensor] = []
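
Why the getattr: StaticCache previously assumed every config defines num_key_value_heads, which only grouped-query-attention configs do; plain multi-head-attention configs (such as the GPT-Neo/GPT-J family touched by this commit) have no such attribute and would raise AttributeError. A toy sketch of the before/after behavior; the config classes and the resolve_kv_heads helper below are illustrative stand-ins, not transformers code.

class MHAConfig:
    # Multi-head attention style config: no num_key_value_heads attribute at all.
    num_attention_heads = 12

class GQAConfig:
    # Grouped-query attention style config: fewer key/value heads than query heads.
    num_attention_heads = 32
    num_key_value_heads = 8

def resolve_kv_heads(config):
    # Mirrors the patched expression: fall back to num_attention_heads when the
    # attribute is missing or explicitly None.
    return (
        config.num_attention_heads
        if getattr(config, "num_key_value_heads", None) is None
        else config.num_key_value_heads
    )

print(resolve_kv_heads(MHAConfig()))  # 12 (the old expression raised AttributeError here)
print(resolve_kv_heads(GQAConfig()))  # 8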

src/transformers/generation/utils.py

Lines changed: 1 addition & 1 deletion

@@ -1473,7 +1473,7 @@ def _get_cache(
         # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
         # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
         # models. May cause trobles with non-text modalities.
-        cache_dtype = self.lm_head.weight.dtype
+        cache_dtype = self.get_output_embeddings().weight.dtype
 
         cache_kwargs = {
             "config": self.config,

Large diffs are not rendered by default:

src/transformers/models/codegen/modeling_codegen.py: 286 additions & 125 deletions
src/transformers/models/falcon/modeling_falcon.py: 312 additions & 141 deletions
src/transformers/models/git/modeling_git.py: 84 additions & 61 deletions
src/transformers/models/gpt_neo/modeling_gpt_neo.py: 286 additions & 108 deletions
src/transformers/models/gpt_neox/modeling_gpt_neox.py: 315 additions & 127 deletions
src/transformers/models/gptj/modeling_gptj.py: 303 additions & 144 deletions
src/transformers/models/idefics/modeling_idefics.py: 214 additions & 47 deletions
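
The per-model diffs above (not rendered here) rewrite the attention and forward paths around Cache objects and set class-level capability flags such as _supports_cache_class and _supports_static_cache, which generate() and the test added below consult. A quick way to inspect those flags on a release that includes this commit (flag names as used in the test diff; getattr guards against versions where a flag is absent):

from transformers import FalconForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM

for cls in (FalconForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM):
    print(
        cls.__name__,
        getattr(cls, "_supports_cache_class", False),
        getattr(cls, "_supports_static_cache", False),
    )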

tests/generation/test_utils.py

Lines changed: 48 additions & 1 deletion

@@ -59,7 +59,7 @@
         ImageGPTForCausalImageModeling,
         SpeechEncoderDecoderModel,
     )
-    from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
+    from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
     from transformers.generation import (
         BeamSampleDecoderOnlyOutput,
         BeamSampleEncoderDecoderOutput,
@@ -1769,6 +1769,53 @@ def test_new_cache_format(self, num_beams, do_sample):
                 )
             )
 
+    def test_generate_with_static_cache(self):
+        """
+        Tests if StaticCache works if we set attn_implementation=static when generation.
+        This doesn't test if generation quality is good, but tests that models with
+        self._supports_static_cache don't throw an error when generating and return
+        a StaticCache object at the end.
+        """
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_static_cache:
+                self.skipTest(reason="This model does not support the static cache format")
+
+            config, input_ids, attention_mask = self._get_input_ids_and_config()
+            if config.is_encoder_decoder:
+                self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
+
+            config.use_cache = True
+            config.is_decoder = True
+            batch_size, seq_length = input_ids.shape
+            max_new_tokens = 20
+
+            model = model_class(config).to(torch_device).eval()
+            generation_kwargs = {
+                "max_length": None,
+                "max_new_tokens": max_new_tokens,
+                "cache_implementation": "static",
+                "return_dict_in_generate": True,  # Required to return `past_key_values`
+            }
+
+            max_cache_len = seq_length + max_new_tokens
+            head_dim = (
+                model.config.head_dim
+                if hasattr(model.config, "head_dim")
+                else model.config.hidden_size // model.config.num_attention_heads
+            )
+            num_key_value_heads = (
+                model.config.num_attention_heads
+                if getattr(config, "num_key_value_heads", None) is None
+                else model.config.num_key_value_heads
+            )
+            num_hidden_layers = config.num_hidden_layers
+            results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
+
+            cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
+            self.assertTrue(isinstance(results.past_key_values, StaticCache))
+            self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
+            self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
+
     @require_quanto
     def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes: