
Commit d750b50

Config: unified logic to retrieve text config (#33219)
1 parent ebbe8d8 commit d750b50

File tree

10 files changed (+92, -89 lines)

.circleci/parse_test_outputs.py

Lines changed: 1 addition & 1 deletion

@@ -67,4 +67,4 @@ def main():
(one-line change to the `main()` call under the `if __name__ == "__main__":` guard)

src/transformers/configuration_utils.py

Lines changed: 35 additions & 5 deletions

@@ -1019,17 +1019,17 @@ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
         """
         non_default_generation_parameters = {}
         decoder_attribute_name = None
-        default_config = None

         # Composite models don't have a default config, use their decoder config as a fallback for default values
         # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
         try:
             default_config = self.__class__()
         except ValueError:
-            for decoder_attribute_name in ("decoder", "generator", "text_config"):
-                if hasattr(self, decoder_attribute_name):
-                    default_config = getattr(self, decoder_attribute_name).__class__()
-                    break
+            decoder_config = self.get_text_config(decoder=True)
+            if decoder_config is not self:
+                default_config = decoder_config.__class__()
+            else:
+                decoder_config = None

         # If it is a composite model, we want to check the subconfig that will be used for generation
         self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
@@ -1057,6 +1057,36 @@ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:

         return non_default_generation_parameters

+    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+        """
+        Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+        itself. On specific composite models, it is under a set of valid names.
+
+        If `decoder` is set to `True`, then only search for decoder config names.
+        """
+        decoder_possible_text_config_names = ("decoder", "generator", "text_config")
+        encoder_possible_text_config_names = ("text_encoder",)
+        if decoder:
+            possible_text_config_names = decoder_possible_text_config_names
+        else:
+            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+
+        valid_text_config_names = []
+        for text_config_name in possible_text_config_names:
+            if hasattr(self, text_config_name):
+                text_config = getattr(self, text_config_name, None)
+                if text_config is not None:
+                    valid_text_config_names += [text_config_name]
+
+        if len(valid_text_config_names) > 1:
+            raise ValueError(
+                f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
+                "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
+            )
+        elif len(valid_text_config_names) == 1:
+            return getattr(self, valid_text_config_names[0])
+        return self
+

 def get_configuration_file(configuration_files: List[str]) -> str:
     """

src/transformers/generation/configuration_utils.py

Lines changed: 20 additions & 15 deletions

@@ -1192,25 +1192,30 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig"
         """
         config_dict = model_config.to_dict()
         config_dict.pop("_from_model_config", None)
-        config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
+        generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

         # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
-        # generation config.
-        for decoder_name in ("decoder", "generator", "text_config"):
-            if decoder_name in config_dict:
-                default_generation_config = GenerationConfig()
-                decoder_config = config_dict[decoder_name]
-                for attr in config.to_dict().keys():
-                    if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
-                        setattr(config, attr, decoder_config[attr])
+        # generation config (which in turn is defined from the outer attributes of model config).
+        decoder_config = model_config.get_text_config(decoder=True)
+        if decoder_config is not model_config:
+            default_generation_config = GenerationConfig()
+            decoder_config_dict = decoder_config.to_dict()
+            for attr in generation_config.to_dict().keys():
+                is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
+                if attr in decoder_config_dict and is_unset:
+                    setattr(generation_config, attr, decoder_config_dict[attr])

         # If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
-        if config.return_dict_in_generate is False:
-            if any(getattr(config, extra_output_flag, False) for extra_output_flag in config.extra_output_flags):
-                config.return_dict_in_generate = True
-
-        config._original_object_hash = hash(config)  # Hash to detect whether the instance was modified
-        return config
+        if generation_config.return_dict_in_generate is False:
+            if any(
+                getattr(generation_config, extra_output_flag, False)
+                for extra_output_flag in generation_config.extra_output_flags
+            ):
+                generation_config.return_dict_in_generate = True
+
+        # Hash to detect whether the instance was modified
+        generation_config._original_object_hash = hash(generation_config)
+        return generation_config

     def update(self, **kwargs):
         """

src/transformers/integrations/awq.py

Lines changed: 3 additions & 9 deletions

@@ -209,10 +209,7 @@ def get_modules_to_fuse(model, quantization_config):
         current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]

         # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
-        if not hasattr(model.config, "text_config"):
-            config = model.config
-        else:
-            config = model.config.text_config
+        config = model.config.get_text_config(decoder=True)

         # Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
         hidden_size = config.hidden_size
@@ -345,11 +342,8 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_
         previous_device = gate_proj.qweight.device

         # Deal also with the case model has `text_config` attribute
-        hidden_act = (
-            model.config.hidden_act
-            if not hasattr(model.config, "text_config")
-            else model.config.text_config.hidden_act
-        )
+        config = model.config.get_text_config(decoder=True)
+        hidden_act = config.hidden_act
         activation_fn = ACT2FN[hidden_act]
         new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)
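With the helper, the attributes the fusing code reads come from a single place for both flat and multi-modal configs. A small sketch of that lookup (LlavaConfig is used only as an example of the multi-modal case mentioned in the comment above):

    from transformers import LlavaConfig

    text_config = LlavaConfig().get_text_config(decoder=True)
    # The same fields `get_modules_to_fuse` and `_fuse_awq_mlp` rely on.
    print(text_config.hidden_size, text_config.num_attention_heads, text_config.num_key_value_heads)
    print(text_config.hidden_act)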

src/transformers/modeling_utils.py

Lines changed: 2 additions & 5 deletions

@@ -2025,11 +2025,8 @@ def resize_token_embeddings(
         else:
             vocab_size = model_embeds.weight.shape[0]

-        # Update base model and current model config
-        if hasattr(self.config, "text_config"):
-            self.config.text_config.vocab_size = vocab_size
-        else:
-            self.config.vocab_size = vocab_size
+        # Update base model and current model config.
+        self.config.get_text_config().vocab_size = vocab_size
         self.vocab_size = vocab_size

         # Tie weights again if needed
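
A minimal sketch of the unified update (hedged: a tiny randomly initialized LlamaForCausalLM is used only so the example runs quickly; the sizes are arbitrary):

    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig(
        vocab_size=100, hidden_size=32, intermediate_size=64,
        num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
    )
    model = LlamaForCausalLM(config)

    model.resize_token_embeddings(110)
    # `get_text_config()` resolves to the config itself here, and to the nested
    # text config on composite models, so the same line works for both.
    print(model.config.get_text_config().vocab_size)      # 110
    print(model.get_input_embeddings().weight.shape[0])   # 110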

src/transformers/models/clvp/modeling_clvp.py

Lines changed: 1 addition & 1 deletion

@@ -735,7 +735,7 @@ def _init_weights(self, module):
             nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std)
             nn.init.normal_(module.fc2.weight, std=in_proj_std)
         elif isinstance(module, ClvpEncoder):
-            config = self.config.text_config if hasattr(self.config, "text_config") else self.config
+            config = self.config.get_text_config()
             factor = config.initializer_factor
             module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
         elif isinstance(module, ClvpConditioningEncoder):

src/transformers/models/olmoe/modeling_olmoe.py

Lines changed: 4 additions & 2 deletions

@@ -1330,7 +1330,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         position_ids=None,
         use_cache=True,
-        num_logits_to_keep=0,
+        num_logits_to_keep=None,
         **kwargs,
     ):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1381,14 +1381,16 @@ def prepare_inputs_for_generation(
                 batch_size=batch_size,
             )

+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
-                "num_logits_to_keep": num_logits_to_keep,
             }
         )
         return model_inputs
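
A sketch of the forwarding pattern in isolation (hedged: build_inputs is a hypothetical stand-in for prepare_inputs_for_generation, not transformers API): the key is only added when the caller explicitly passed a value, instead of always injecting a default of 0.

    def build_inputs(input_ids, attention_mask, num_logits_to_keep=None):
        model_inputs = {}
        # Only include the key when a value was explicitly provided.
        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep
        model_inputs.update({"input_ids": input_ids, "attention_mask": attention_mask})
        return model_inputs

    build_inputs([1, 2, 3], [1, 1, 1])                         # no "num_logits_to_keep" key
    build_inputs([1, 2, 3], [1, 1, 1], num_logits_to_keep=1)   # key is forwarded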

tests/generation/test_utils.py

Lines changed: 7 additions & 5 deletions

@@ -831,7 +831,7 @@ def test_constrained_beam_search_generate(self):

             # Sample constraints
             min_id = 3
-            max_id = config.vocab_size
+            max_id = config.get_text_config(decoder=True).vocab_size

             force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
             constraints = [
@@ -889,7 +889,7 @@ def test_constrained_beam_search_generate_dict_output(self):

             # Sample constraints
             min_id = 3
-            max_id = model.config.vocab_size
+            max_id = model.config.get_text_config(decoder=True).vocab_size
             force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
             constraints = [
                 PhrasalConstraint(force_tokens),
@@ -2012,18 +2012,20 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
             self.assertTrue(output.past_key_values is None)

     def _check_scores(self, batch_size, scores, length, config):
-        expected_shape = (batch_size, config.vocab_size)
+        vocab_size = config.get_text_config(decoder=True).vocab_size
+        expected_shape = (batch_size, vocab_size)
         self.assertIsInstance(scores, tuple)
         self.assertEqual(len(scores), length)
         self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))

     def _check_logits(self, batch_size, scores, config):
+        vocab_size = config.get_text_config(decoder=True).vocab_size
         self.assertIsInstance(scores, tuple)
         self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
         # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
-        vocab_diff = config.vocab_size - scores[0].shape[-1]
+        vocab_diff = vocab_size - scores[0].shape[-1]
         self.assertTrue(vocab_diff in [0, 1])
-        self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
+        self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))

     def _check_attentions_for_generate(
         self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1

tests/test_modeling_common.py

Lines changed: 16 additions & 41 deletions

@@ -1747,12 +1747,13 @@ def test_resize_position_vector_embeddings(self):
             self.assertTrue(models_equal)

     def test_resize_tokens_embeddings(self):
+        if not self.test_resize_embeddings:
+            self.skipTest(reason="test_resize_embeddings is set to `False`")
+
         (
             original_config,
             inputs_dict,
         ) = self.model_tester.prepare_config_and_inputs_for_common()
-        if not self.test_resize_embeddings:
-            self.skipTest(reason="test_resize_embeddings is set to `False`")

         for model_class in self.all_model_classes:
             config = copy.deepcopy(original_config)
@@ -1764,18 +1765,15 @@ def test_resize_tokens_embeddings(self):
             if self.model_tester.is_training is False:
                 model.eval()

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             # Retrieve the embeddings and clone theme
             model_embed = model.resize_token_embeddings(model_vocab_size)
             cloned_embeddings = model_embed.weight.clone()

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
+
             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@@ -1787,11 +1785,7 @@ def test_resize_tokens_embeddings(self):

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model_embed = model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@@ -1817,21 +1811,13 @@ def test_resize_tokens_embeddings(self):
             model = model_class(config)
             model.to(torch_device)

-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertTrue(new_model_vocab_size + 10, model_vocab_size)

             model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertTrue(model_embed.weight.shape[0] // 64, 0)

             self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
@@ -1852,13 +1838,10 @@ def test_resize_tokens_embeddings(self):
             model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)

     def test_resize_embeddings_untied(self):
-        (
-            original_config,
-            inputs_dict,
-        ) = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.test_resize_embeddings:
             self.skipTest(reason="test_resize_embeddings is set to `False`")

+        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         original_config.tie_word_embeddings = False

         # if model cannot untied embeddings -> leave test
@@ -1874,13 +1857,9 @@ def test_resize_embeddings_untied(self):
                 continue

             # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
-            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            model_vocab_size = config.get_text_config().vocab_size
             model.resize_token_embeddings(model_vocab_size + 10)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
             output_embeds = model.get_output_embeddings()
             self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@@ -1892,11 +1871,7 @@ def test_resize_embeddings_untied(self):

             # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
             model.resize_token_embeddings(model_vocab_size - 15)
-            new_model_vocab_size = (
-                model.config.text_config.vocab_size
-                if hasattr(model.config, "text_config")
-                else model.config.vocab_size
-            )
+            new_model_vocab_size = model.config.get_text_config().vocab_size
             self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
             # Check that it actually resizes the embeddings matrix
             output_embeds = model.get_output_embeddings()
@@ -1988,7 +1963,7 @@ def check_same_values(layer_1, layer_2):
             # self.assertTrue(check_same_values(embeddings, decoding))

             # Check that after resize they remain tied.
-            vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
             model_tied.resize_token_embeddings(vocab_size + 10)
             params_tied_2 = list(model_tied.parameters())
             self.assertEqual(len(params_tied_2), len(params_tied))
@@ -4831,7 +4806,7 @@ def test_forward_with_num_logits_to_keep(self):

             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
             batch_size, sequence_length = inputs["input_ids"].shape
-            vocab_size = config.vocab_size
+            vocab_size = config.get_text_config().vocab_size
             model = model_class(config).to(device=torch_device).eval()

             # num_logits_to_keep=0 is a special case meaning "keep all logits"

tests/test_pipeline_mixin.py

Lines changed: 3 additions & 5 deletions

@@ -675,14 +675,12 @@ def validate_test_components(test_case, task, model, tokenizer, processor):
     # Avoid `IndexError` in embedding layers
     CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"]
     if tokenizer is not None:
-        config_vocab_size = getattr(model.config, "vocab_size", None)
+        # Removing `decoder=True` in `get_text_config` can lead to conflicting values e.g. in MusicGen
+        config_vocab_size = getattr(model.config.get_text_config(decoder=True), "vocab_size", None)
         # For CLIP-like models
         if config_vocab_size is None:
-            if hasattr(model.config, "text_config"):
+            if hasattr(model.config, "text_encoder"):
                 config_vocab_size = getattr(model.config.text_config, "vocab_size", None)
-            elif hasattr(model.config, "text_encoder"):
-                config_vocab_size = getattr(model.config.text_encoder, "vocab_size", None)
-
         if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE:
             raise ValueError(
                 "Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."
